# Source code for storch.method.baseline

from __future__ import annotations
from abc import ABC, abstractmethod
import torch
import storch
from storch.tensor import StochasticTensor, CostTensor


class Baseline(ABC, torch.nn.Module):
    """Abstract base class for cost baselines.

    A baseline produces a tensor derived from a cost node; callers in the
    estimator code presumably subtract it from the cost to reduce the
    variance of score-function gradient estimates (TODO confirm against
    the gradient-estimation methods that consume this).

    Subclasses must implement :meth:`compute_baseline`. Inheriting from
    ``torch.nn.Module`` lets subclasses register buffers/parameters that
    move with ``.to(device)`` and appear in ``state_dict``.
    """

    @abstractmethod
    def compute_baseline(
        self, tensor: StochasticTensor, cost_node: CostTensor
    ) -> torch.Tensor:
        """Compute the baseline value for ``cost_node``.

        Args:
            tensor: The stochastic tensor whose cost is being baselined.
            cost_node: The cost tensor to compute a baseline for.

        Returns:
            A tensor with the baseline value(s).
        """
class MovingAverageBaseline(Baseline):
    """Baseline that keeps an exponential moving average of the cost.

    The average is unconditional: each cost is first reduced over all of
    its plate dimensions, then folded into a single scalar running
    average decayed by ``exponential_decay``.
    """

    def __init__(self, exponential_decay: float = 0.95):
        """
        Args:
            exponential_decay: Decay factor; higher values weigh past
                costs more heavily in the running average.
        """
        super().__init__()
        # Registered as buffers (not parameters) so the running state is
        # saved in state_dict and moved by .to()/.cuda(), but never trained.
        self.register_buffer("exponential_decay", torch.tensor(exponential_decay))
        self.register_buffer("moving_average", torch.tensor(0.0))

    def compute_baseline(
        self, tensor: StochasticTensor, cost_node: CostTensor
    ) -> torch.Tensor:
        """Update the moving average with ``cost_node`` and return it.

        Note: this mutates ``self.moving_average`` on every call.
        """
        # Reduce the cost over its plates to a scalar; detach so the
        # baseline never backpropagates into the cost.
        avg_cost = storch.reduce_plates(cost_node).detach()
        # NOTE(review): ._tensor appears to unwrap the storch wrapper back
        # to a plain torch.Tensor before storing it in the buffer — assumes
        # the arithmetic above yields a storch tensor; verify against
        # storch.reduce_plates.
        self.moving_average = (
            self.exponential_decay * self.moving_average
            + (1 - self.exponential_decay) * avg_cost
        )._tensor
        return self.moving_average
class BatchAverageBaseline(Baseline):
    """Leave-one-out baseline: the average cost of the *other* samples.

    For each sample, uses the mean of the costs of the remaining
    ``n - 1`` samples as its baseline.
    Introduced by https://arxiv.org/abs/1602.06725.
    """

    def compute_baseline(
        self, tensor: StochasticTensor, costs: CostTensor
    ) -> torch.Tensor:
        """Return, per sample, the mean cost over all other samples.

        Args:
            tensor: The stochastic tensor; ``tensor.n`` samples were drawn.
            costs: The cost tensor to baseline (detached internally).

        Raises:
            ValueError: If fewer than two samples were drawn, since the
                leave-one-out average is undefined (and ``n == 0`` would
                otherwise divide by ``-1``).
        """
        # Guard widened from == 1 to < 2: n == 0 is equally degenerate.
        if tensor.n < 2:
            raise ValueError(
                "Can only use the batch average baseline if multiple samples are used."
            )
        # Detach: the baseline must not contribute gradients to the cost.
        costs = costs.detach()
        sum_costs = storch.sum(costs, tensor.name)
        # TODO: Should reduce correctly
        # Subtracting each sample's own cost from the total leaves the sum
        # of the other n - 1 samples; divide to get their mean.
        baseline = (sum_costs - costs) / (tensor.n - 1)
        return baseline