Source code for alpaca.nn.modules.module

from typing import Optional
import torch.nn as nn


class Module:
    """
    Link ``nn.Module`` with the alpaca Module abstraction.

    The class allows the instance dictionary of an existing ``nn.Module``
    to be copied into an instance of this class. Additionally, it introduces
    the ``uncertainty_mode`` flag used to switch between the inference and
    the uncertainty estimation modes.
    """

    def __init__(self, *args, **kwargs):
        # Forward all arguments to the next class in the MRO.
        super().__init__(*args, **kwargs)
        self.uncertainty_mode = False

    def instantiate_with_dropout_params(
        self,
        module: nn.Module,
        dropout_rate: float = 0.0,
        dropout_mask: Optional["BaseMask"] = None,
    ) -> "alpaca.nn.Module":
        """
        Copy the state of an ``nn.Module`` instance and attach dropout settings.

        Parameters
        ----------
        module : nn.Module
            The ``nn.Module`` instance whose ``__dict__`` is copied
        dropout_rate : float
            The dropout rate
        dropout_mask : "BaseMask", optional
            Base mask instance setting the type of mask of the module

        Returns
        -------
        alpaca.nn.Module
            This instance, to allow call chaining
        """
        # Shallow copy: the attribute *values* are shared with ``module``,
        # only the dictionary itself is new.
        self.__dict__ = module.__dict__.copy()
        self.dropout_rate = dropout_rate
        self.dropout_mask = dropout_mask
        # Replacing ``__dict__`` discards whatever ``__init__`` stored on this
        # instance. Restore the mode flag unless the copied state already
        # carries one, so that it is always readable afterwards.
        if not hasattr(self, "uncertainty_mode"):
            self.uncertainty_mode = False
        return self

    def __str__(self) -> str:
        """
        Return a short description with the class name and dropout settings.

        The dropout attributes only exist after
        ``instantiate_with_dropout_params`` has been called; before that they
        are reported as ``None`` / ``NoneType`` instead of raising.
        """
        dropout_rate = getattr(self, "dropout_rate", None)
        dropout_mask = getattr(self, "dropout_mask", None)
        return "ann.{}, dropout_rate: {}, dropout_mask: {}".format(
            self.__class__.__name__,
            dropout_rate,
            dropout_mask.__class__.__name__,
        )

    def ue_mode(self) -> "alpaca.nn.Module":
        """
        Set the alpaca.Module into the uncertainty estimation mode.

        Sets ``uncertainty_mode`` to ``True``. This will enable the dropout
        mask logic calculation with the dropout rate activated.

        Returns
        -------
        alpaca.nn.Module
            This instance, to allow call chaining
        """
        self.uncertainty_mode = True
        return self

    def inf_mode(self) -> "alpaca.nn.Module":
        """
        Set the alpaca.Module into inference mode.

        Sets ``uncertainty_mode`` to ``False``. This will disable
        dropout_rate and dropout_mask of the module.

        Returns
        -------
        alpaca.nn.Module
            This instance, to allow call chaining
        """
        self.uncertainty_mode = False
        return self