Source code for alpaca.ue.mcdue

from typing import Tuple, Optional, Union, Callable
import torch
from tqdm import tqdm
from functools import partial

from alpaca.ue.base import UE
from alpaca.ue import acquisitions
from alpaca.utils.model_builder import uncertainty_mode, inference_mode


__all__ = ["MCDUE"]


[docs]class MCDUE: r""" MCDUE constructor. Depending on the provided `num_classes` argument, the constructor will initialize **MCDUE_regression** or **MCDUE_classification** classes. Other Parameters ---------------- num_classes : int Integer that sets the number of classes for prediction Examples -------- >>> import alpaca >>> model : nn.Module = ... # define a torch nn.Model >>> model = train_model(...) # train the model >>> estimator = MCDUE(model, nn_runs=100, num_classes=10) >>> predictions, estimations = estimator(x_batch) """ _name = "MCDUE" _default_acquisition = None def __new__(cls, *args, num_classes=0, **kwargs): if num_classes == 0: return MCDUE_regression(*args, **kwargs) elif num_classes > 0: return MCDUE_classification(*args, num_classes=num_classes, **kwargs) else: raise ValueError("`num_classes` can't take the negative value")
class MCDUE_regression(UE): """ MCDUE implementation for regression task Default attributes ------------------ _name : "MCDUE_regression" _default_acquisition : :method:`alpaca.ue.acquisitions.std` """ _name = "MCDUE_regression" _default_acquisition = partial(acquisitions.std) def __init__( self, *args, acquisition: Optional[Union[str, Callable]] = None, **kwargs ): super().__init__(*args, **kwargs) # set acquisition strategy if acquisition is None: # set default acquisiiton strategy if not given # defined as the attribute for each subclass self._acquisition = self._default_acquisition elif callable(acquisition): self._acquisition = acquisition else: try: self._acquisition = acquisitions.acq_reg[acquisition] except KeyError: # TODO: move this to exceptions list raise ValueError("The given acquisition strategy doesn't exist") def __call__(self, X_pool: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: self.net = uncertainty_mode(self.net) mcd_runs = None with torch.no_grad(): self.net(X_pool) for nn_run in tqdm(range(self.nn_runs), total=self.nn_runs, desc=self.desc): prediction = self.net(X_pool) mcd_runs = ( prediction.flatten().cpu()[None, ...] if mcd_runs is None else torch.cat( [mcd_runs, prediction.flatten().cpu()[None, ...]], dim=0 ) ) predictions = mcd_runs.mean(dim=0) # save `mcf_runs` stats if self.keep_runs is True: self._mcd_runs = mcd_runs self.net = inference_mode(self.net) return predictions, self._acquisition(mcd_runs) class MCDUE_classification(UE): """ MCDUE implementation for a classification task Default attributes ------------------ _name : "MCDUE_classification" _default_acquisition : :method:`alpaca.ue.acquisitions.bald` Parameters ---------- num_classes Integer that sets the number of classes for prediction """ _name = "MCDUE_classification" _default_acquisition = partial(acquisitions.bald) def __init__( self, *args, num_classes, acquisition: Optional[Union[str, Callable]] = None, **kwargs ): super().__init__(*args, **kwargs) self.num_classes = num_classes # set acquisition strategy if acquisition is None: # set default acquisiiton strategy if not given # defined as the attribute for each subclass self._acquisition = self._default_acquisition elif callable(acquisition): self._acquisition = acquisition else: try: self._acquisition = acquisitions.acq_reg[acquisition] except KeyError: # TODO: move this to exceptions list raise ValueError("The given acquisition strategy doesn't exist") def __call__(self, X_pool: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: mcd_runs = None self.net = uncertainty_mode(self.net) with torch.no_grad(): self.net(X_pool) # Get mcdue estimation for nn_run in tqdm(range(self.nn_runs), total=self.nn_runs, desc=self.desc): prediction = self.net(X_pool) mcd_runs = ( prediction.cpu()[None, ...] if mcd_runs is None else torch.cat([mcd_runs, prediction.cpu()[None, ...]], dim=0) ) mcd_runs = mcd_runs.permute((1, 0, 2)) predictions = mcd_runs.mean(dim=1) if self.keep_runs is True: self._mcd_runs = mcd_runs self.net = inference_mode(self.net) return predictions, self._acquisition(mcd_runs)