Source code for alpaca.ue.ensemble

from typing import Tuple, Optional, Union, Callable
import torch
from tqdm import tqdm
from functools import partial

from alpaca.ue.base import UE
from alpaca.models import Ensemble
from alpaca.ue import acquisitions
from alpaca.utils.model_builder import uncertainty_mode, inference_mode

__all__ = ["EnsembleMCDUE"]


class EnsembleMCDUE(UE):
    """
    Estimate uncertainty for samples with Ensemble and MCDUE approach
    """

    _name = "EnsembleMCDUE"
    # Wrapped in ``partial`` so that looking it up through ``self`` returns the
    # function itself rather than a bound method (partial objects are not
    # descriptors, so they are not bound on instance attribute access).
    _default_acquisition = partial(acquisitions.std)

    def __init__(
        self,
        *args,
        acquisition: Optional[Union[str, Callable]] = None,
        reduction=None,
        **kwargs
    ):
        """
        Args:
            *args, **kwargs: forwarded unchanged to ``UE.__init__``.
            acquisition: ``None`` to use the class-level
                ``_default_acquisition``, a callable applied to the stacked
                runs, or a key looked up in ``acquisitions.acq_reg``.
            reduction: passed as the ``reduction`` keyword argument to every
                forward call of the wrapped ensemble.

        Raises:
            ValueError: if ``acquisition`` is a key missing from
                ``acquisitions.acq_reg``.
        """
        super().__init__(*args, **kwargs)
        self._create_model_from_list()
        self.reduction = reduction
        self._acquisition = self._resolve_acquisition(acquisition)

    def _resolve_acquisition(self, acquisition):
        """Map the ``acquisition`` constructor argument to a callable."""
        if acquisition is None:
            # set default acquisition strategy if not given;
            # defined as the attribute for each subclass
            return self._default_acquisition
        if callable(acquisition):
            return acquisition
        try:
            return acquisitions.acq_reg[acquisition]
        except KeyError as err:
            # TODO: move this to exceptions list
            raise ValueError("The given acquisition strategy doesn't exist") from err

    def __call__(self, X_pool: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run ``self.nn_runs`` forward passes of the ensemble (in uncertainty
        mode) on ``X_pool`` and aggregate them.

        Returns:
            ``(predictions, uncertainty)``. ``predictions`` is the per-element
            mean over runs of the flattened outputs. ``uncertainty`` is the
            acquisition function applied to the stacked runs tensor of shape
            ``(nn_runs, number_of_flattened_output_elements)``.

        Raises:
            ValueError: if ``self.nn_runs`` is less than 1, since there would
                be no runs to aggregate.
        """
        if self.nn_runs < 1:
            raise ValueError("nn_runs must be at least 1 to estimate uncertainty")

        self.net = uncertainty_mode(self.net)
        try:
            with torch.no_grad():
                # Some mask needs first run without dropout, i.e. decorrelation mask
                self.net(X_pool, reduction=self.reduction)

                # Get mcdue estimation. Runs are collected in a list and
                # stacked once; concatenating inside the loop would re-copy
                # all previous runs on every iteration.
                runs = [
                    self.net(X_pool, reduction=self.reduction).flatten().cpu()
                    for _ in tqdm(
                        range(self.nn_runs), total=self.nn_runs, desc=self.desc
                    )
                ]
        finally:
            # Restore inference mode even if a forward pass raised.
            self.net = inference_mode(self.net)

        mcd_runs = torch.stack(runs, dim=0)
        predictions = mcd_runs.mean(dim=0)

        # save `mcd_runs` stats
        if self.keep_runs is True:
            self._mcd_runs = mcd_runs

        return predictions, self._acquisition(mcd_runs)

    def _create_model_from_list(self):
        """
        Wrap ``self.net`` in an ``Ensemble`` and switch it to eval mode.

        NOTE(review): the method name suggests ``self.net`` is a list of models
        at this point; confirm against ``UE.__init__``.
        """
        self.net = Ensemble(self.net)
        self.net.eval()