Source code for alpaca.ue.acquisitions
"""
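Acquisition (uncertainty-scoring) functions computed from ``mcd_runs``
tensors (presumably stacked Monte Carlo dropout runs; confirm with callers).

Every function decorated with ``reg_acquisition`` is stored in the
module-level ``acq_reg`` dict under its own name.

TODO: Docs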
"""
import torch
# Names exported by a star-import of this module.
# NOTE(review): ``std`` is defined and registered in this file but is not
# listed here, while ``bald_normed`` is listed but is not defined in the part
# of the file that was reviewed; confirm both are intentional.
__all__ = [
"var_ratio",
"var_soft",
"bald",
"bald_normed",
]
# Registry mapping each acquisition function's ``__name__`` to the function.
acq_reg = dict()


def reg_acquisition(f):
    """
    Record ``f`` in ``acq_reg`` under its ``__name__`` and return it unchanged.

    Written to be used as a decorator. Registering a second function with the
    same name overwrites the earlier entry.
    """
    acq_reg.update({f.__name__: f})
    return f
@reg_acquisition
def std(mcd_runs: torch.Tensor):
    """
    std_acquisition

    Standard deviation of ``mcd_runs`` taken along its first dimension.

    :param mcd_runs: tensor of stacked predictions; dimension 0 is the one
        reduced (presumably the repeated-runs axis; TODO confirm with callers).
    :return: tensor with dimension 0 reduced away.
    """
    spread = torch.std(mcd_runs, dim=0)
    return spread
@reg_acquisition
def var_ratio(mcd_runs: torch.Tensor, nn_runs: int):
    """
    var_ratio_acquisition

    Variation ratio: one minus the share of runs that voted for the most
    frequently predicted class.

    :param mcd_runs: class scores; the last dimension indexes classes and the
        second-to-last dimension indexes the repeated runs, e.g.
        ``(n_points, n_runs, n_classes)``.
    :param nn_runs: number of runs, used as the denominator of the ratio.
    :return: tensor holding one uncertainty value per point.
    """
    n_classes = mcd_runs.shape[-1]
    # Winning class of every individual run.
    predictions = torch.argmax(mcd_runs, dim=-1)
    # Count the votes each class received, for all points at once. This
    # replaces a per-row ``bincount`` loop and also copes with an empty batch.
    votes = predictions.new_zeros((*predictions.shape[:-1], n_classes))
    votes.scatter_add_(-1, predictions, torch.ones_like(predictions))
    # Number of runs that agreed on the most popular class.
    mode_counts = votes.max(dim=-1).values
    return 1 - mode_counts / nn_runs
@reg_acquisition
def var_soft(mcd_runs: torch.Tensor):
    """
    var_soft_acquisition

    Softmax the scores over the last dimension, take the standard deviation
    of every class probability over the second-to-last dimension, and average
    those deviations over the classes.

    :param mcd_runs: class scores; the last dimension indexes classes and the
        second-to-last dimension is the one the spread is measured over
        (presumably the repeated runs; TODO confirm with callers).
    :return: tensor with the last two dimensions reduced away.
    """
    # ``dim`` is the documented keyword of ``torch.softmax`` and is what
    # ``_bald`` in this module already uses.
    probabilities = torch.softmax(mcd_runs, dim=-1)
    per_class_std = torch.std(probabilities, dim=-2)
    return torch.mean(per_class_std, dim=-1)
@reg_acquisition
def bald(mcd_runs: torch.Tensor):
    """
    bald_acquisition

    Registered public entry point that delegates to ``_bald``: the entropy of
    the averaged softmax distribution minus the average entropy of the
    individual softmax distributions.

    :param mcd_runs: logits; ``_bald`` applies softmax over the last
        dimension and averages along dimension 1.
    :return: the tensor produced by ``_bald``.
    """
    scores = _bald(mcd_runs)
    return scores
def _entropy(x):
return torch.sum(-x * torch.log(torch.clamp(x, 1e-8, 1)), dim=-1)
def _bald(logits):
    """
    Entropy of the averaged softmax distribution minus the average entropy.

    Softmax is applied over the last dimension and both averages are taken
    along dimension 1 (presumably the repeated-runs axis of a
    ``(n_points, n_runs, n_classes)`` tensor; TODO confirm with callers).
    """
    probs = logits.softmax(dim=-1)
    entropy_of_mean = _entropy(probs.mean(dim=1))
    mean_of_entropy = _entropy(probs).mean(dim=1)
    return entropy_of_mean - mean_of_entropy