Source code for alpaca.nn.modules.dropout

from typing import Optional
import torch
import torch.nn as nn

from alpaca.nn.modules.module import Module

# Public API of this module: only the Dropout layer is exported by
# `from ... import *`.
__all__ = ["Dropout"]


class Dropout(Module, nn.Dropout):
    """
    Subclass of the ``nn.Dropout`` layer with additional ``dropout_mask``
    and ``dropout_rate`` parameterization.

    Parameters
    ----------
    *args
        Positional arguments forwarded unchanged to the parent constructors.
    dropout_rate : float
        Dropout rate of the mask. It is also the probability passed to
        ``torch.nn.functional.dropout`` while the layer is in training mode.
    dropout_mask : "BaseMask", optional
        Base mask instance setting the type of mask of the module. When it
        is ``None`` no mask is applied outside of training.
    **kwargs
        Keyword arguments forwarded unchanged to the parent constructors.
    """

    def __init__(
        self,
        *args,
        dropout_rate: float = 0.0,
        dropout_mask: Optional["BaseMask"] = None,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.dropout_rate = dropout_rate
        self.dropout_mask = dropout_mask

    # NOTE(review): this overrides ``__call__`` rather than ``forward``, so
    # whatever the parent classes do inside their own ``__call__`` is not run
    # for this layer -- confirm this is intentional.
    def __call__(
        self,
        input: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply dropout to ``input`` according to the current mode.

        Parameters
        ----------
        input : torch.Tensor
            Tensor to process.

        Returns
        -------
        torch.Tensor
            * training mode: the result of ``torch.nn.functional.dropout``
              called with ``p=self.dropout_rate`` and ``inplace=self.inplace``;
            * otherwise, if ``self.uncertainty_mode is True`` and a
              ``dropout_mask`` is set: ``input`` multiplied by the value
              returned by the mask;
            * otherwise: ``input`` itself, unchanged.
        """
        if self.training:
            # The probability used here is ``dropout_rate``, not the ``p``
            # attribute configured through ``nn.Dropout``.
            return torch.nn.functional.dropout(
                input, p=self.dropout_rate, inplace=self.inplace
            )

        # ``uncertainty_mode`` is not defined in this class; presumably it is
        # provided by the ``Module`` base -- TODO confirm.
        # The mask is compared against ``None`` explicitly so that a
        # configured mask object which happens to evaluate as falsy (e.g. one
        # defining ``__len__``) is still applied.
        if self.uncertainty_mode is True and self.dropout_mask is not None:
            mask = self.dropout_mask(
                input, dropout_rate=self.dropout_rate, is_train=self.training
            )
            return input * mask

        return input