"""Source code for ``alpaca.ue.base``: the abstract base class for uncertainty estimation methods."""

import abc
from collections.abc import Iterable
from typing import Optional, Tuple

import torch
import torch.nn as nn

from alpaca.nn.modules.module import Module


# Names exported by ``from alpaca.ue.base import *``.
__all__ = [
    "UE",
]


class UE(metaclass=abc.ABCMeta):
    """
    Abstract class for all uncertainty estimation method implementations

    Parameters
    ----------
    net: :class:`torch.nn.Module` or an iterable of them
        Neural network based on which the uncertainty is estimated. An
        iterable (e.g. a list) of networks may be passed for ensemble
        estimators.
    nn_runs: int
        Number of iterations (default ``25``)
    keep_runs: bool
        Whether to keep the results of the iterations, so that they can be
        retrieved with :meth:`last_mcd_runs` / :attr:`mcd_runs`

    Examples
    --------
    .. code-block:: python

        # This could be used to create a custom
        # uncertainty estimation strategy
        class CustomUE(UE):
            def __init__(self, ...):
                ...

            def __call__(self, X_pool: torch.Tensor):
                ...

        estimator = CustomUE(model, ...)
        predictions, estimations = estimator(x_batch)
    """

    _name = None
    _default_acquisition = None

    def __init__(
        self,
        net,
        *,
        nn_runs=25,
        keep_runs: bool = False,
    ):
        # ``net`` may also be a list of networks (ensemble estimators)
        self.net = net
        self.nn_runs = nn_runs
        self.keep_runs = keep_runs
        if isinstance(net, nn.Module):
            # put the model into evaluation mode
            self.net.eval()
        self._masks_collect()
        self._reset()

    @abc.abstractmethod
    def __call__(self, X_pool: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Estimate uncertainty

        Parameters
        ----------
        X_pool: torch.Tensor
            Batch of inputs for which the uncertainty is estimated

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Predictions and uncertainty estimations
        """
        pass

    def last_mcd_runs(self) -> Optional[torch.Tensor]:
        """
        Return model predictions stored by the last uncertainty estimation

        Returns
        -------
        Optional[torch.Tensor]
            The stored runs, or ``None`` if no runs have been stored yet

        Raises
        ------
        ValueError
            If the estimator was created with ``keep_runs=False``
        """
        if not self.keep_runs:
            raise ValueError(
                "mcd_runs: You should set `keep_runs=True` to properly use this method"
            )
        # ``_mcd_runs`` only exists once an estimation has stored its runs
        # (presumably done by subclasses); fall back to ``None`` like the
        # ``mcd_runs`` property instead of leaking an AttributeError.
        return self.mcd_runs

    @property
    def mcd_runs(self):
        """Runs stored by the last estimation, or ``None`` if none are stored."""
        # TODO we can add logger here to inform a user
        # that the `keep_runs` flag is False
        return getattr(self, "_mcd_runs", None)

    @property
    def desc(self) -> str:
        """Human-readable description of the estimation approach."""
        return "Uncertainty estimation with {} approach".format(self._name)

    def _masks_collect_helper(self, model):
        """
        Add the ``dropout_mask`` of every alpaca :class:`Module` found in
        ``model`` to ``self.all_masks``.

        ``model`` is a single network or an iterable of networks. Every
        submodule is visited via :meth:`torch.nn.Module.modules`. Layers are
        therefore found at any depth: directly inside ``nn.Sequential`` /
        ``nn.ModuleList`` containers, or nested in custom submodules.
        """
        if isinstance(model, nn.Module) or not isinstance(model, Iterable):
            # a single network; containers such as nn.Sequential are walked
            # as modules rather than iterated (nn.ModuleDict iterates keys)
            models = [model]
        else:
            # e.g. a list of networks for ensemble estimators
            models = model
        for model_ in models:
            for module in model_.modules():
                if isinstance(module, Module):
                    self.all_masks.add(module.dropout_mask)

    def _masks_collect(self):
        """Collect the dropout masks of ``self.net`` into ``self.all_masks``."""
        self.all_masks = set()
        self._masks_collect_helper(self.net)

    def _reset(self):
        """Call ``reset()`` on every collected mask that is truthy."""
        for mask in self.all_masks:
            # falsy entries (presumably ``None`` for layers without a mask)
            # are skipped
            if mask:
                mask.reset()