Source code for alpaca.ue.masks

import abc
from typing import Union
import torch

from alpaca.utils.functions import corrcoef, mc_probability, cov
from dppy.finite_dpps import FiniteDPP

# Public names exported on star-import: only the registry itself
__all__ = ["reg_masks"]

# Registry filled by ``register_mask``: maps each alias found in a class'
# ``_name_collection`` to the name (a string) of that class
reg_masks = {}


def register_mask(cls):
    """
    Registers a mask class in ``reg_masks`` under each of its aliases

    Parameters
    ----------
    cls : type
        Class exposing a ``_name_collection`` iterable of alias strings

    Returns
    -------
    cls : type
        The class that was passed in, unchanged. Returning it lets the
        function be applied as a class decorator without rebinding the
        decorated name to ``None``.

    Raises
    ------
    AttributeError
        If ``cls`` does not define ``_name_collection``
    """
    class_name = cls.__name__
    for key in cls._name_collection:
        # The registry stores the class *name*, not the class object
        reg_masks[key] = class_name
    return cls


class BaseMask(metaclass=abc.ABCMeta):
    """
    The base class for masks

    Concrete masks implement ``__call__``; ``copy`` and ``reset`` are
    provided here.
    """

    def __init__(self):
        # Set once at construction. The layered masks in this module flip it
        # to False after their first call with ``is_train=False``.
        self._init_run = True

    @abc.abstractmethod
    def __call__(
        self,
        x: torch.Tensor,
        dropout_rate: Union[torch.Tensor, float] = 0.5,
        *,
        is_train: bool = True,
    ) -> torch.Tensor:
        """
        Performs masked inference logic

        The signature matches the concrete masks of this module:
        ``dropout_rate`` may be passed positionally or by keyword and
        ``is_train`` is keyword-only.

        Parameters
        ----------
        x : torch.Tensor
            Tensor to be masked
        dropout_rate : float
            Dropout rate of the binary mask
        is_train : bool
            Flag forwarded by the caller; its interpretation is up to the
            concrete mask

        Returns
        -------
        x_ : torch.Tensor
            Masked tensor
        """

    def copy(self) -> "BaseMask":
        """
        Creates the copy of an instance

        The new object is allocated without running ``__init__`` so that
        subclasses whose constructors require arguments can be copied too.
        The attribute dictionary is copied shallowly: attribute values
        (e.g. tensors) are shared between the original and the copy.
        """
        cls = self.__class__
        instance = cls.__new__(cls)
        instance.__dict__ = self.__dict__.copy()
        return instance

    def reset(self):
        """
        Resets the mask state; a no-op in the base class
        """
class MaskLayered(BaseMask):
    """
    The base class for nn layered masks

    Adds two per-instance containers, ``layer_correlations`` and ``norm``,
    on top of :class:`BaseMask`, and a concrete ``reset``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Both start as empty dicts; subclasses in this module later
        # overwrite them with the statistics they compute.
        self.layer_correlations: dict = dict()
        self.norm: dict = dict()

    def reset(self):
        """
        Resets layers info/stat

        Drops the stored correlations and re-arms ``_init_run``.
        ``norm`` is left untouched.
        """
        # NOTE: the attribute is set to None here, not to an empty dict as
        # in __init__; kept as is because callers may rely on it.
        self.layer_correlations = None
        self._init_run = True
class BasicBernoulliMask(BaseMask):
    """
    The implementation of Monte Carlo Dropout (MCD) logic

    More about the behaviours of MCD can be found in:
    https://arxiv.org/pdf/2008.02627.pdf

    Examples
    --------
    >>> estimator = MCDUE(model, nn_runs=100, acquisition='std')
    >>> estimations1 = estimator.estimate(x_batch)
    """

    _name_collection = {"mc_dropout"}

    def __init__(self):
        super().__init__()

    def __call__(
        self,
        x: torch.Tensor,
        dropout_rate: Union[torch.Tensor, float] = 0.5,
        *,
        is_train=True,
    ) -> torch.Tensor:
        """
        Samples an inverted-dropout mask shaped like ``x``

        Parameters
        ----------
        x : torch.Tensor
            Tensor whose shape, dtype and device the mask takes
        dropout_rate : float or torch.Tensor
            Probability of dropping an element; must lie in ``[0, 1]``
        is_train : bool
            Accepted for interface compatibility; not used by this mask

        Returns
        -------
        res : torch.Tensor
            Mask with entries ``0`` or ``1 / (1 - dropout_rate)``; all zeros
            when ``dropout_rate`` equals 1

        Raises
        ------
        ValueError
            If ``dropout_rate`` is outside ``[0, 1]`` (NaN included)
        """
        dropout_rate = torch.as_tensor(dropout_rate)
        p = 1.0 - dropout_rate  # probability of keeping an element
        # Written as a positive range test so that NaN is rejected as well
        if not ((p >= 0.0) & (p <= 1.0)).all():
            raise ValueError(
                "Dropout probability has to be between 0 and 1, "
                "but got {}".format(dropout_rate.tolist())
            )
        # With p == 0 every sample is 0; dividing by 1 instead of 0 yields
        # a zero mask rather than 0 / 0 = NaN.
        divisor = torch.where(p > 0, p, torch.ones_like(p))
        res = (
            torch.bernoulli(p.expand(x.shape))
            .div_(divisor)
            .to(dtype=x.dtype, device=x.device)
        )
        return res
class DecorrelationMask(MaskLayered):
    """
    Mask that keeps units sampled according to decorrelation scores

    On the first call with ``is_train=False`` the mask computes, for every
    unit of the last dimension, the sum of absolute correlation
    coefficients with all units (via ``corrcoef``), takes the reciprocal
    and turns the scores into a probability vector with a softmax.
    Every other call samples ``k = int(mask_len * (1 - dropout_rate))``
    distinct units from that distribution.

    Parameters
    ----------
    scaling : bool
        Rescale the scores to a maximum of 4 before the softmax
    ht_norm : bool
        Weight kept units by the reciprocal of their estimated inclusion
        probability instead of ``1 / (1 - dropout_rate)``
    eps : float
        Amplitude of the uniform noise added to the input before the
        correlations are computed
    """

    _name_collection = {"decorrelating"}

    def __init__(
        self,
        *,
        scaling: bool = False,
        ht_norm: bool = False,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.scaling = scaling  # use adaptive scaling before softmax
        self.ht_norm = ht_norm
        self.eps = eps

    @property
    def layer_correlation(self):
        """Backward-compatible alias of ``layer_correlations``"""
        return self.layer_correlations

    @layer_correlation.setter
    def layer_correlation(self, value):
        self.layer_correlations = value

    def __call__(
        self,
        x: torch.Tensor,
        dropout_rate: Union[torch.Tensor, float] = 0.5,
        *,
        is_train=True,
    ) -> torch.Tensor:
        """
        Returns the initial identity mask or a sampled mask

        Parameters
        ----------
        x : torch.Tensor
            Layer activations; the mask spans the last dimension
        dropout_rate : float or single-element torch.Tensor
            Fraction of units to drop
        is_train : bool
            The first call with a false value triggers the initialisation

        Returns
        -------
        mask : torch.Tensor
            Vector of ones of length ``x.size(-1)`` on the initialising
            call, otherwise the sampled mask expanded to the shape of ``x``
        """
        mask_len = x.size(-1)
        k = int(mask_len * (1.0 - dropout_rate))

        if self._init_run and not is_train:
            self._init_run = False
            return self._init_layers(x, mask_len, k)

        mask = torch.zeros(mask_len, dtype=x.dtype, device=x.device)
        inds = torch.multinomial(self.layer_correlations, k, replacement=False)

        if self.ht_norm:
            mask[inds] = self.norm[inds].to(dtype=mask.dtype, device=mask.device)
        else:
            # A plain Python scalar takes the dtype and device of ``mask``
            mask[inds] = 1.0 / (1.0 - float(dropout_rate))

        return mask.expand_as(x)

    def _init_layers(self, x: torch.Tensor, mask_len: int, k: int) -> torch.Tensor:
        """
        Computes and stores the sampling distribution from ``x``

        ``x`` is transposed with ``transpose(1, 0)`` before being passed to
        ``corrcoef`` -- presumably it is (samples, units); confirm against
        the caller.
        """
        # Small noise is added to the activations before the correlations
        # are computed
        noise = torch.rand(*x.shape, dtype=x.dtype, device=x.device) * self.eps
        corrs = torch.sum(torch.abs(corrcoef((x + noise).transpose(1, 0))), dim=1)
        scores = torch.reciprocal(corrs)
        if self.scaling:
            scores = (
                4.0 * scores / torch.max(scores)
            )  # TODO: remove hard coding or annotate
        # Stored under the inherited name so that ``reset()`` clears it
        self.layer_correlations = torch.nn.functional.softmax(scores, dim=-1)

        if self.ht_norm:
            # Horvitz-Thompson normalization (1 / marginal_prob for each element)
            probabilities = self.layer_correlations
            samples = max(1000, 4 * x.size(-1))  # TODO: why? (explain 1000)
            self.norm = torch.reciprocal(mc_probability(probabilities, k, samples))

        # Initially we should pass identity mask,
        # otherwise we won't get right correlations for all layers
        return torch.ones(mask_len, dtype=x.dtype, device=x.device)
class DecorrelationMaskScaled(DecorrelationMask):
    """
    :class:`DecorrelationMask` with score scaling always switched on

    Deriving from ``DecorrelationMask`` (rather than directly from
    ``MaskLayered``) provides the ``__call__`` implementation; without it
    the class stays abstract and cannot be instantiated.
    """

    _name_collection = {"decorrelating_sc"}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Forced after the parent constructor, so a ``scaling`` keyword
        # argument passed by the caller is overridden
        self.scaling = True
class LeverageScoreMask(MaskLayered):
    """
    Mask that keeps units sampled proportionally to their leverage scores

    On the first call with ``is_train=False`` the mask builds a kernel
    ``K`` from the input (``cov`` or ``corrcoef``), takes the diagonal of
    ``K (K + lambda_ * I)^-1`` and normalises it to sum to one.
    Every other call samples ``k = int(mask_len * (1 - dropout_rate))``
    distinct units from that distribution.

    Parameters
    ----------
    ht_norm : bool
        Weight kept units by the reciprocal of their estimated inclusion
        probability instead of ``1 / (1 - dropout_rate)``
    lambda_ : float
        Coefficient of the identity matrix added to the kernel
    covariance : bool
        Use ``cov`` instead of ``corrcoef`` to build the kernel
    """

    _name_collection = {"leveragescoremask"}

    def __init__(
        self,
        *,
        ht_norm: bool = True,
        lambda_: float = 1,
        covariance: bool = False,
    ):
        super().__init__()
        self.ht_norm = ht_norm
        self.lambda_ = lambda_
        self.covariance = covariance

    def __call__(
        self,
        x: torch.Tensor,
        dropout_rate: Union[torch.Tensor, float] = 0.5,
        *,
        is_train=True,
    ) -> torch.Tensor:
        """
        Returns the initial identity mask or a sampled mask

        Parameters
        ----------
        x : torch.Tensor
            Layer activations; the mask spans the last dimension
        dropout_rate : float or single-element torch.Tensor
            Fraction of units to drop
        is_train : bool
            The first call with a false value triggers the initialisation

        Returns
        -------
        mask : torch.Tensor
            Vector of ones of length ``x.shape[-1]`` on the initialising
            call, otherwise the sampled mask expanded to the shape of ``x``
        """
        mask_len = x.shape[-1]
        k = int(mask_len * (1 - dropout_rate))

        if self._init_run and not is_train:
            self._init_run = False
            return self._init_layers(x, mask_len, k)

        # Same dtype as ``x`` so that applying the mask does not promote
        # the activations to another floating point type
        mask = torch.zeros(mask_len, dtype=x.dtype, device=x.device)
        ids = torch.multinomial(self.layer_correlations, k, replacement=False)

        if self.ht_norm:
            mask[ids] = self.norm[ids].to(dtype=mask.dtype, device=mask.device)
        else:
            # A plain Python scalar takes the dtype and device of ``mask``
            mask[ids] = 1 / (1 - float(dropout_rate))

        return mask.expand_as(x)

    def _init_layers(self, x: torch.Tensor, mask_len: int, k: int):
        """
        Computes and stores the sampling distribution from ``x``

        ``x`` is transposed with ``transpose(1, 0)`` before being passed to
        ``cov`` / ``corrcoef`` -- presumably it is (samples, units); confirm
        against the caller.
        """
        features = x.transpose(1, 0)
        K = cov(features) if self.covariance else corrcoef(features)

        # The identity must live on the device of K and share its dtype
        identity = torch.eye(K.size(0), dtype=K.dtype, device=K.device)
        leverages_matrix = K @ torch.inverse(K + self.lambda_ * identity)
        probabilities = torch.diagonal(leverages_matrix)
        probabilities = probabilities / torch.sum(probabilities)
        self.layer_correlations = probabilities

        if self.ht_norm:
            # Horvitz-Thompson normalization (1 / marginal_prob for each element)
            samples = max(1000, 4 * x.size(-1))  # TODO: why? (explain 1000)
            self.norm = torch.reciprocal(mc_probability(probabilities, k, samples))

        # Initially we should pass identity mask,
        # otherwise we won't get right correlations for all layers
        return torch.ones(mask_len, dtype=x.dtype, device=x.device)
class LeverageScoreMaskCov(LeverageScoreMask):
    """
    Leverage-score mask preset registered under ``cov_leverages``

    Whatever is passed to the constructor, the instance ends up with
    ``ht_norm=True``, ``lambda_=1`` and ``covariance=True``.
    """

    _name_collection = {"cov_leverages"}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fixed configuration of the preset; applied after the parent
        # constructor so it wins over any supplied keyword arguments
        preset = (("ht_norm", True), ("lambda_", 1), ("covariance", True))
        for attribute, value in preset:
            setattr(self, attribute, value)
class DPPMask(MaskLayered):
    """
    Mask whose kept units are drawn from a determinantal point process

    On the first call with ``is_train=False`` the mask builds a likelihood
    kernel ``L`` from the input (``cov`` or ``corrcoef``) and creates a
    ``FiniteDPP`` from it. Every other call draws an exact DPP sample and
    keeps the sampled units. The number of kept units is decided by the
    DPP sample; ``dropout_rate`` is accepted for interface compatibility
    only.

    The class defines no ``_name_collection``.

    Parameters
    ----------
    ht_norm : bool
        Weight kept units by ``1 / diag(L (L + I)^-1)`` instead of
        ``mask_len / number_of_kept_units``
    covariance : bool
        Use ``cov`` instead of ``corrcoef`` to build the kernel
    max_attempts : int
        Maximum number of draws made while looking for a non-empty sample
    """

    def __init__(
        self,
        ht_norm: bool = False,
        covariance: bool = False,
        max_attempts: int = 30,
    ):
        # Sets up ``_init_run``, ``layer_correlations`` and ``norm``
        super().__init__()
        self.ht_norm = ht_norm
        self.covariance = covariance
        self.max_attempts = max_attempts
        self.dpps = None

    def __call__(
        self,
        x: torch.Tensor,
        dropout_rate: Union[torch.Tensor, float] = 0.5,
        *,
        is_train=True,
    ) -> torch.Tensor:
        """
        Returns the initial identity mask or a sampled mask

        Parameters
        ----------
        x : torch.Tensor
            Layer activations; the mask spans the last dimension
        dropout_rate : float or torch.Tensor
            Not used by this mask
        is_train : bool
            The first call with a false value triggers the initialisation

        Returns
        -------
        mask : torch.Tensor
            One-dimensional mask of length ``x.shape[-1]`` with the dtype
            and device of ``x``

        Raises
        ------
        RuntimeError
            If the mask is sampled before it was initialised, or if no
            non-empty sample is drawn within ``max_attempts`` draws
        """
        if self._init_run and not is_train:
            self._init_run = False
            return self._init_layers(x)

        dpp = self.dpps
        if dpp is None:
            raise RuntimeError(
                "DPPMask has to be initialised by a call with is_train=False "
                "before masks can be sampled"
            )

        # sampling nodes ids
        ids = []
        for _ in range(self.max_attempts):
            dpp.sample_exact()
            ids = dpp.list_of_samples[-1]
            if len(ids):  # We should retry if mask is zero-length
                break
        if not len(ids):
            raise RuntimeError(
                "DPP sampling produced only empty samples in "
                "{} attempts".format(self.max_attempts)
            )

        mask_len = x.shape[-1]
        mask = torch.zeros(mask_len, dtype=x.dtype, device=x.device)
        if self.ht_norm:
            mask[ids] = self.norm[ids].to(dtype=mask.dtype, device=mask.device)
        else:
            mask[ids] = mask_len / len(ids)

        return mask

    def _init_layers(self, x: torch.Tensor, eps: float = 1e-12):
        """
        Builds the DPP kernel from ``x`` and returns the identity mask

        ``x`` is transposed with ``transpose(0, 1)`` before being passed to
        ``cov`` / ``corrcoef`` -- presumably it is (samples, units); confirm
        against the caller.
        """
        # Out-of-place, so the caller's activations are not modified; the
        # noise is created with the dtype and device of ``x``
        x_matrix = x + torch.rand(x.shape, dtype=x.dtype, device=x.device) * eps
        features = x_matrix.transpose(0, 1)
        L = cov(features) if self.covariance else corrcoef(features)

        # The kernel is handed to DPPy as a NumPy array (DPPy is NumPy-based)
        self.dpps = FiniteDPP("likelihood", **{"L": L.detach().cpu().double().numpy()})
        self.layer_correlations = L

        if self.ht_norm:
            L = L.double()
            identity = torch.eye(len(L), dtype=L.dtype, device=L.device)
            K = torch.mm(L, torch.inverse(L + identity))

            # Reciprocal of the diagonal of the marginal kernel K
            self.norm = torch.reciprocal(torch.diag(K))
            self.L = L
            self.K = K

        return torch.ones(x.shape[-1], dtype=x.dtype, device=x.device)