TR2-D2 / tr2d2-dna /diffusion_gosai_cfg.py
zyc4975matholic
Include DNA training code
303c2e0
import itertools
import math
from dataclasses import dataclass
import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
import torchmetrics
from torch import Tensor
import dataloader_gosai
import models
import noise_schedule
import utils
import oracle
# Natural logarithm of 2; BPD.compute divides by it to turn nats into bits.
LOG2 = math.log(2)
# Module-level logger obtained from the project's `utils.get_logger` helper.
LOGGER = utils.get_logger(__name__)
def _sample_categorical(categorical_probs):
gumbel_norm = (
1e-10
- (torch.rand_like(categorical_probs) + 1e-10).log())
return (categorical_probs / gumbel_norm).argmax(dim=-1)
def _unsqueeze(x, reference):
return x.view(
* x.shape,
* ((1,) * (len(reference.shape) - len(x.shape))))
@dataclass
class Loss:
    """Container for the result of one loss computation.

    `Diffusion._loss` fills it with the scalar batch loss, the masked
    per-token negative log-likelihoods and the mask that was applied.
    """
    # Scalar loss: masked NLL sum divided by the number of unmasked tokens.
    loss: torch.FloatTensor
    # Per-token NLLs after multiplication by `token_mask`.
    nlls: torch.FloatTensor
    # Mask used both to zero out tokens and as the metric weight.
    token_mask: torch.FloatTensor
class NLL(torchmetrics.aggregation.MeanMetric):
    """Weighted-mean metric used for negative log-likelihoods.

    Adds nothing on top of `MeanMetric`; it exists as a named base class
    for `BPD` and `Perplexity`.
    """
class BPD(NLL):
    """Bits-per-dimension view of the accumulated mean NLL."""

    def compute(self) -> Tensor:
        """Computes the bits per dimension.

        Returns:
          The weighted mean NLL (``mean_value / weight``) divided by
          ``LOG2``, i.e. converted from nats to bits.
        """
        mean_nll = self.mean_value / self.weight
        return mean_nll / LOG2
class Perplexity(NLL):
    """Perplexity view of the accumulated mean NLL."""

    def compute(self) -> Tensor:
        """Computes the Perplexity.

        Returns:
          ``exp(mean_value / weight)``, the exponential of the weighted
          mean NLL.
        """
        mean_nll = self.mean_value / self.weight
        return torch.exp(mean_nll)
class Diffusion(L.LightningModule):
def __init__(
        self,
        config,
        eval=True):
    """Build the backbone, noise schedule, metrics and optional EMA.

    Args:
      config: nested configuration object; the attributes read here are
        `sampling.predictor`, `training.*`, `parameterization`,
        `backbone`, `model`, `T`, `subs_masking`, `optim.lr`,
        `time_conditioning` and (when `eval` is true) `eval.subset_size`.
      eval: when true, reference evaluation subsets and their statistics
        are pre-computed through the `oracle` module.
    Raises:
      ValueError: if `config.backbone` is not ``'cnn'``.
    """
    super().__init__()
    self.save_hyperparameters()
    self.config = config

    # Four base tokens plus one extra mask token whose id follows them.
    num_base_tokens = 4
    self.mask_index = num_base_tokens
    self.vocab_size = num_base_tokens + 1

    training_cfg = self.config.training
    self.sampler = self.config.sampling.predictor
    self.antithetic_sampling = training_cfg.antithetic_sampling
    self.importance_sampling = training_cfg.importance_sampling
    self.change_of_variables = training_cfg.change_of_variables
    self.parameterization = self.config.parameterization

    if self.config.backbone != 'cnn':
        raise ValueError(
            f'Unknown backbone: {self.config.backbone}')
    self.backbone = models.dnaconv.CNNModel(
        self.config.model, alphabet_size=self.vocab_size, num_cls=2)

    self.T = self.config.T
    self.subs_masking = self.config.subs_masking
    self.softplus = torch.nn.Softplus()

    # metrics are automatically reset at end of epoch
    base_metrics = torchmetrics.MetricCollection({
        'nll': NLL(),
        'bpd': BPD(),
        'ppl': Perplexity(),
    })
    base_metrics.set_dtype(torch.float64)
    self.train_metrics = base_metrics.clone(prefix='train/')
    self.valid_metrics = base_metrics.clone(prefix='val/')
    self.test_metrics = base_metrics.clone(prefix='test/')

    # generative perplexity
    self.gen_ppl_metric = Perplexity()

    self.noise = noise_schedule.get_noise(self.config,
                                          dtype=self.dtype)

    # EMA tracks backbone and noise-schedule parameters together.
    ema_decay = self.config.training.ema
    if ema_decay > 0:
        self.ema = models.ema.ExponentialMovingAverage(
            itertools.chain(self.backbone.parameters(),
                            self.noise.parameters()),
            decay=ema_decay)
    else:
        self.ema = None

    self.lr = self.config.optim.lr
    self.sampling_eps = self.config.training.sampling_eps
    self.time_conditioning = self.config.time_conditioning
    # Large negative stand-in for log(0) in the log-probability tensors.
    self.neg_infinity = -1000000.0
    # Filled by on_load_checkpoint when resuming.
    self.fast_forward_epochs = None
    self.fast_forward_batches = None
    self._validate_configuration()

    # subset of data for evaluation
    if eval:
        eval_subset = oracle.subset_for_eval(n=config.eval.subset_size)
        self.eval_sets_sp = eval_subset
        self.eval_sets_sp_clss = oracle.subset_eval_groundtruth(
            self.eval_sets_sp)
        self.eval_sets_sp_preds = oracle.subset_eval_preds(
            self.eval_sets_sp)
        self.eval_sets_sp_kmers = oracle.subset_eval_kmers(
            self.eval_sets_sp)
        # The PCA is fitted on a separate, fixed-size subset of 40000.
        self.emb_pca = oracle.cal_emb_pca(
            oracle.subset_for_eval(n=40000), n_components=50)
        self.eval_sets_sp_embs_pca = oracle.subset_eval_embs_pca(
            self.eval_sets_sp, self.emb_pca)
def _validate_configuration(self):
assert not (self.change_of_variables
and self.importance_sampling)
assert self.parameterization == 'subs'
def on_load_checkpoint(self, checkpoint):
    """Restore EMA state and remember how far the saved run progressed.

    Args:
      checkpoint: checkpoint dictionary; must contain ``'loops'`` and,
        when EMA is enabled, ``'ema'``.
    """
    if self.ema:
        self.ema.load_state_dict(checkpoint['ema'])
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
    fit_loop_state = checkpoint['loops']['fit_loop']
    epochs_done = fit_loop_state['epoch_progress']['current']['completed']
    batches_done = fit_loop_state[
        'epoch_loop.batch_progress']['current']['completed']
    # on_train_start reads these to fast-forward the sampler.
    self.fast_forward_epochs = epochs_done
    self.fast_forward_batches = batches_done
def on_save_checkpoint(self, checkpoint):
    """Add EMA and sampler state and fix up loop progress counters.

    Args:
      checkpoint: checkpoint dictionary, modified in place.
    """
    if self.ema:
        checkpoint['ema'] = self.ema.state_dict()

    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
    # ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
    # behind, so we're using the optimizer's progress.
    fit_loop_state = checkpoint['loops']['fit_loop']
    optimizer_steps = fit_loop_state[
        'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']
    batch_progress = fit_loop_state['epoch_loop.batch_progress']
    for scope in ('total', 'current'):
        batch_progress[scope]['completed'] = (
            optimizer_steps[scope]['completed']
            * self.trainer.accumulate_grad_batches)

    # _batches_that_stepped tracks the number of global steps, not the number
    # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
    fit_loop_state['epoch_loop.state_dict'][
        '_batches_that_stepped'] = optimizer_steps['total']['completed']

    if 'sampler' not in checkpoint.keys():
        checkpoint['sampler'] = {}
    train_sampler = self.trainer.train_dataloader.sampler
    random_state = None
    if hasattr(train_sampler, 'state_dict'):
        # Re-read the attribute chain so the sampler in use at save time
        # is the one queried.
        sampler_state = self.trainer.train_dataloader.sampler.state_dict()
        random_state = sampler_state.get('random_state', None)
    checkpoint['sampler']['random_state'] = random_state
def on_train_start(self):
    """Swap the training dataloaders for fault-tolerant equivalents.

    Moves EMA shadow parameters to the module's device, then rebuilds
    every dataloader of the fit loop with a sampler from
    `dataloader_gosai`. When running distributed and resuming from a
    checkpoint, the new sampler is fast-forwarded to the saved position.
    """
    if self.ema:
        self.ema.move_shadow_params_to_device(self.device)

    # Adapted from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
    connector = self.trainer._accelerator_connector
    distributed = (
        connector.use_distributed_sampler
        and connector.is_distributed)
    print('distributed:', distributed)
    sampler_cls = (
        dataloader_gosai.FaultTolerantDistributedSampler
        if distributed
        else dataloader_gosai.RandomFaultTolerantSampler)

    # Fast-forwarding only happens in the distributed case and only when
    # on_load_checkpoint recorded a position.
    resume_position = (
        distributed
        and self.fast_forward_epochs is not None
        and self.fast_forward_batches is not None)

    loader_cfg = self.config.loader
    combined_loader = self.trainer.fit_loop._combined_loader
    updated_dls = []
    for dl in combined_loader.flattened:
        sampler_kwargs = {}
        if hasattr(dl.sampler, 'shuffle'):
            sampler_kwargs['shuffle'] = dl.sampler.shuffle
        dl_sampler = sampler_cls(dl.dataset, **sampler_kwargs)
        if resume_position:
            dl_sampler.load_state_dict({
                'epoch': self.fast_forward_epochs,
                'counter': (self.fast_forward_batches
                            * self.config.loader.batch_size)})
        updated_dls.append(
            torch.utils.data.DataLoader(
                dl.dataset,
                batch_size=loader_cfg.batch_size,
                num_workers=loader_cfg.num_workers,
                pin_memory=loader_cfg.pin_memory,
                sampler=dl_sampler,
                shuffle=False,
                persistent_workers=True))
    combined_loader.flattened = updated_dls
def optimizer_step(self, *args, **kwargs):
    """Run the parent optimizer step, then update the EMA if enabled.

    Args:
      *args: forwarded unchanged to the parent `optimizer_step`.
      **kwargs: forwarded unchanged to the parent `optimizer_step`.
    """
    super().optimizer_step(*args, **kwargs)
    if not self.ema:
        return
    tracked_params = itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters())
    self.ema.update(tracked_params)
def _subs_parameterization(self, logits, xt):
logits[:, :, self.mask_index] += self.neg_infinity
logits = logits - torch.logsumexp(logits, dim=-1,
keepdim=True)
unmasked_indices = (xt != self.mask_index)
logits[unmasked_indices] = self.neg_infinity
logits[unmasked_indices, xt[unmasked_indices]] = 0
return logits
def _process_sigma(self, sigma):
if sigma is None:
assert self.parameterization == 'ar'
return sigma
if sigma.ndim > 1:
sigma = sigma.squeeze(-1)
if not self.time_conditioning:
sigma = torch.zeros_like(sigma)
assert sigma.ndim == 1, sigma.shape
return sigma
def forward(self, x, sigma, binary_clss=None):
    """Returns log score.

    Args:
      x: token tensor passed to the backbone.
      sigma: noise-level conditioning, normalised by `_process_sigma`.
      binary_clss: optional class labels, passed to the backbone as
        `cls`.
    Returns:
      The backbone output, post-processed by `_subs_parameterization`
      when the parameterization is ``'subs'``.
    """
    sigma = self._process_sigma(sigma)
    # The backbone is run under an autocast context requesting float32.
    with torch.cuda.amp.autocast(dtype=torch.float32):
        logits = self.backbone(x, sigma, cls=binary_clss)
    if self.parameterization != 'subs':
        return logits
    return self._subs_parameterization(logits=logits, xt=x)
def _compute_loss(self, batch, prefix):
if 'attention_mask' in batch:
attention_mask = batch['attention_mask']
else:
attention_mask = None
# classifier-free guidance
assert self.config.model.cls_free_guidance == True
binary_clss = (batch['clss'][:,0] > self.config.model.cls_free_threshold).long()
random_list = np.random.binomial(1, self.config.model.cls_free_prob, binary_clss.shape[0])
binary_clss[random_list==1] = 2
losses = self._loss(batch['seqs'], attention_mask, binary_clss)
loss = losses.loss
if prefix == 'train':
self.train_metrics.update(losses.nlls, losses.token_mask)
metrics = self.train_metrics
elif prefix == 'val':
self.valid_metrics.update(losses.nlls, losses.token_mask)
metrics = self.valid_metrics
elif prefix == 'test':
self.test_metrics.update(losses.nlls, losses.token_mask)
metrics = self.test_metrics
else:
raise ValueError(f'Invalid prefix: {prefix}')
self.log_dict(metrics,
on_step=False,
on_epoch=True,
sync_dist=True)
return loss
def on_train_epoch_start(self):
    """Put the backbone and the noise schedule into training mode."""
    for module in (self.backbone, self.noise):
        module.train()
def training_step(self, batch, batch_idx):
    """Compute the training loss for one batch and log it per step.

    Args:
      batch: training batch, forwarded to `_compute_loss`.
      batch_idx: index of the batch (unused).
    Returns:
      The loss tensor returned by `_compute_loss`.
    """
    loss = self._compute_loss(batch, prefix='train')
    log_kwargs = dict(
        name='trainer/loss',
        value=loss.item(),
        on_step=True,
        on_epoch=False,
        sync_dist=True)
    self.log(**log_kwargs)
    return loss
def on_validation_epoch_start(self):
    """Swap in EMA weights (if any) and switch to evaluation mode.

    The current parameters are stored in the EMA helper first so that
    they can be restored once validation is over.

    Raises:
      AssertionError: if the validation NLL metric has not been reset.
    """
    def tracked_params():
        # A fresh iterator is needed for each EMA call.
        return itertools.chain(
            self.backbone.parameters(),
            self.noise.parameters())

    if self.ema:
        self.ema.store(tracked_params())
        self.ema.copy_to(tracked_params())
    self.backbone.eval()
    self.noise.eval()

    nll_metric = self.valid_metrics.nll
    assert nll_metric.mean_value == 0
    assert nll_metric.weight == 0
def validation_step(self, batch, batch_idx):
    """Compute the validation loss for one batch.

    Args:
      batch: validation batch, forwarded to `_compute_loss`.
      batch_idx: index of the batch (unused).
    Returns:
      The loss tensor returned by `_compute_loss`.
    """
    val_loss = self._compute_loss(batch, prefix='val')
    return val_loss
def on_validation_epoch_end(self):
    """Optionally score generated samples, then restore training weights.

    When sample generation is enabled (and this is not a skipped sanity
    check), `sampling.num_sample_batches` batches are sampled with class
    label 1, detokenized and scored by the oracle; the mean of the first
    prediction column is logged as ``val/gosai_preds_avg``.

    The EMA weights swapped in by `on_validation_epoch_start` are
    restored in a `finally` clause, so a failure while sampling or
    scoring cannot leave the module holding EMA weights.
    """
    try:
        eval_cfg = self.config.eval
        should_sample = (
            (eval_cfg.compute_perplexity_on_sanity
             or not self.trainer.sanity_checking)
            and eval_cfg.generate_samples
            and self.parameterization != 'ar')
        if should_sample:
            detokenized_samples = []
            for _ in range(self.config.sampling.num_sample_batches):
                samples = self._sample(cls=1).detach().cpu().numpy()
                detokenized_samples.extend(
                    dataloader_gosai.batch_dna_detokenize(samples))
            # Nothing to score when zero sample batches are configured.
            if detokenized_samples:
                generated_preds = oracle.cal_gosai_pred(
                    detokenized_samples, mode='eval')[:, 0]
                avg_generated_preds = np.mean(generated_preds, axis=0)
                current_step = self.trainer.global_step
                LOGGER.info(f'Current step: {current_step}')
                LOGGER.info(f'Generated preds: {avg_generated_preds}')
                self.log('val/gosai_preds_avg',
                         avg_generated_preds,
                         on_step=False,
                         on_epoch=True,
                         sync_dist=True)
    finally:
        if self.ema:
            self.ema.restore(
                itertools.chain(self.backbone.parameters(),
                                self.noise.parameters()))
def configure_optimizers(self):
    """Create the AdamW optimizer and the per-step LR scheduler.

    Returns:
      A pair ``([optimizer], [scheduler_dict])`` where the scheduler is
      instantiated from `config.lr_scheduler` and stepped every
      optimizer step.
    """
    # TODO(yair): Lightning currently giving this warning when using `fp16`:
    #  "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
    #  Not clear if this is a problem or not.
    #  See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
    optim_cfg = self.config.optim
    trainable_params = itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters())
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=optim_cfg.lr,
        betas=(optim_cfg.beta1, optim_cfg.beta2),
        eps=optim_cfg.eps,
        weight_decay=optim_cfg.weight_decay)

    scheduler = hydra.utils.instantiate(
        self.config.lr_scheduler, optimizer=optimizer)
    scheduler_dict = dict(
        scheduler=scheduler,
        interval='step',
        monitor='val/loss',
        name='trainer/lr')
    return [optimizer], [scheduler_dict]
def q_xt(self, x, move_chance):
    """Computes the noisy sample xt.

    Every token is independently replaced by the mask token with
    probability `move_chance`.

    Args:
      x: int torch.Tensor with shape (batch_size,
          diffusion_model_input_length), input.
      move_chance: float torch.Tensor with shape (batch_size, 1).
    Returns:
      int torch.Tensor shaped like `x` with some entries set to the
      mask index.
    """
    uniform = torch.rand(*x.shape, device=x.device)
    should_mask = uniform < move_chance
    return torch.where(should_mask, self.mask_index, x)
def _sample_prior(self, *batch_dims):
return self.mask_index * torch.ones(
* batch_dims, dtype=torch.int64)
def _ddpm_caching_update(self, x, t, dt, p_x0=None):
    """One reverse step that can reuse a cached clean-token prediction.

    Only valid for the ``'loglinear'`` noise type, where the masking
    probability used here is taken to be `t` itself.

    Args:
      x: int torch.Tensor of current tokens, indexed (batch, position).
      t: float torch.Tensor of current times, 1-D or with a trailing
        singleton dimension.
      dt: step size; the target time is ``t - dt``.
      p_x0: optional cached prediction of clean-token probabilities.
        When None it is computed with `forward` (without class labels).
    Returns:
      Tuple ``(p_x0, x_next)``: the prediction that was used and the
      updated tokens. Positions that were already unmasked are copied
      unchanged.
    """
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)
    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1
    mask_prob_now = t[:, None, None]
    mask_prob_next = (t - dt)[:, None, None]
    assert mask_prob_now.ndim == 3, mask_prob_now.shape
    if p_x0 is None:
        p_x0 = self.forward(x, sigma_t).exp()
    assert mask_prob_now.ndim == p_x0.ndim

    # Unnormalised transition weights: reveal a token, or stay masked.
    q_xs = p_x0 * (mask_prob_now - mask_prob_next)
    q_xs[:, :, self.mask_index] = mask_prob_next[:, :, 0]
    proposal = _sample_categorical(q_xs)

    keep = (x != self.mask_index).to(x.dtype)
    x_next = keep * x + (1 - keep) * proposal
    return p_x0, x_next
def _ddpm_update(self, x, t, dt, cls, w):
    """One reverse diffusion step with classifier-free guidance.

    The model is evaluated twice, once with the unconditional label 2
    and once with label `cls`, and the two log-predictions are combined
    as ``(1 + w) * cond - w * uncond``.

    Args:
      x: int torch.Tensor of current tokens, indexed (batch, position).
      t: float torch.Tensor of current times, passed to the noise
        schedule.
      dt: step size; the target time is ``t - dt``.
      cls: class label used for the conditional branch.
      w: guidance weight.
    Returns:
      int torch.Tensor of updated tokens; positions that were already
      unmasked are copied unchanged.
    """
    sigma_t, _ = self.noise(t)
    sigma_s, _ = self.noise(t - dt)
    if sigma_t.ndim > 1:
        sigma_t = sigma_t.squeeze(-1)
    if sigma_s.ndim > 1:
        sigma_s = sigma_s.squeeze(-1)
    assert sigma_t.ndim == 1, sigma_t.shape
    assert sigma_s.ndim == 1, sigma_s.shape

    # Probability of being masked at the current / next time.
    mask_prob_now = (1 - torch.exp(-sigma_t))[:, None, None]
    mask_prob_next = (1 - torch.exp(-sigma_s))[:, None, None]

    batch_size = x.shape[0]
    uncond = torch.full(
        (batch_size,), 2, dtype=torch.long, device=x.device)
    cond = (cls * torch.ones(batch_size, device=x.device)).long()
    log_p_x0_uncond = self.forward(x, sigma_t, uncond)
    log_p_x0_cond = self.forward(x, sigma_t, cond)
    guided_log_p_x0 = (1 + w) * log_p_x0_cond - w * log_p_x0_uncond
    assert mask_prob_now.ndim == guided_log_p_x0.ndim

    # Unnormalised transition weights: reveal a token, or stay masked.
    q_xs = guided_log_p_x0.exp() * (mask_prob_now - mask_prob_next)
    q_xs[:, :, self.mask_index] = mask_prob_next[:, :, 0]
    proposal = _sample_categorical(q_xs)

    keep = (x != self.mask_index).to(x.dtype)
    return keep * x + (1 - keep) * proposal
def _ar_sampler(self, bsz):
# precompute token buffer
num_pred_tokens = self.config.model.length - 1
x = torch.zeros(
(bsz, num_pred_tokens + 1),
dtype=torch.long,
device=self.device)
x[:, 0] = self.tokenizer.bos_token_id
# precompute noise
noise = (torch.distributions.Gumbel(0, 1)
.sample((bsz, num_pred_tokens, self.vocab_size))
.to(self.device))
for i in range(num_pred_tokens):
next_logits = self.forward(x[:, :i + 1], None)[:, -1]
y = (next_logits + noise[:, i]).argmax(-1)
x[:, i + 1] = y
return x
@torch.no_grad()
def _sample(self, num_steps=None, eps=1e-5, eval_sp_size=None, cls=1, w=None):
"""Generate samples from the model."""
if w is None:
w = self.config.model.cls_free_weight
if eval_sp_size is None:
batch_size_per_gpu = self.config.loader.eval_batch_size
else:
batch_size_per_gpu = eval_sp_size
if self.parameterization == 'ar':
return self._ar_sampler(batch_size_per_gpu)
if num_steps is None:
num_steps = self.config.sampling.steps
x = self._sample_prior(
batch_size_per_gpu,
self.config.model.length).to(self.device)
timesteps = torch.linspace(
1, eps, num_steps + 1, device=self.device)
dt = (1 - eps) / num_steps
p_x0_cache = None
for i in range(num_steps):
t = timesteps[i] * torch.ones(
x.shape[0], 1, device=self.device)
if self.sampler == 'ddpm':
x = self._ddpm_update(x, t, dt, cls, w)
else:
raise NotImplementedError
if self.config.sampling.noise_removal:
t = timesteps[-1] * torch.ones(x.shape[0], 1,
device=self.device)
unet_conditioning = self.noise(t)[0]
uncond = (2 * torch.ones(x.shape[0], device=x.device)).long()
cond = (cls * torch.ones(x.shape[0], device=x.device)).long()
log_p_x0_uncond = self.forward(x, unet_conditioning, uncond)
log_p_x0_cond = self.forward(x, unet_conditioning, cond)
logits = (1+w) * log_p_x0_cond - w * log_p_x0_uncond
x = logits[:, :, :-1].argmax(dim=-1)
return x
def get_score(self, x, sigma):
    """Return the (exponentiated) score for tokens `x` at noise `sigma`.

    Args:
      x: int torch.Tensor of tokens, indexed (batch, position).
      sigma: float torch.Tensor of noise levels with a trailing
        singleton dimension, i.e. (batch, 1).
    Returns:
      float torch.Tensor indexed (batch, position, vocab): the
      exponential of the model output, converted from log-probabilities
      to log-scores first when the parameterization is ``'subs'``.
    """
    model_output = self.forward(x, sigma)
    if self.parameterization != 'subs':
        return model_output.exp()

    # score(x, t) = p_t(y) / p_t(x)
    # => log score(x, t) = log p_t(y) - log p_t(x)

    # case 1: x = masked
    #   (i) y = unmasked
    #     log score(x, t) = log p_\theta(x)|_y + log k
    #     where k = exp(- sigma) / (1 - exp(- sigma))
    #   (ii) y = masked
    #     log score(x, t) = 0

    # case 2: x = unmasked
    #   (i) y != masked, y != x
    #     log score(x_i, t) = - inf
    #   (ii) y = x
    #     log score(x_i, t) = 0
    #   (iii) y = masked token
    #     log score(x_i, t) = - log k
    #     where k = exp(- sigma) / (1 - exp(- sigma))
    log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
    assert log_k.ndim == 1

    # Case 1: log-scores for positions that are currently masked.
    score_if_masked = model_output + log_k[:, None, None]
    score_if_masked[:, :, self.mask_index] = 0

    # Case 2: log-scores for positions that already hold a token.
    score_if_unmasked = self.neg_infinity * torch.ones_like(
        model_output)
    score_if_unmasked = torch.scatter(
        score_if_unmasked,
        -1,
        x[..., None],
        torch.zeros_like(score_if_unmasked[..., :1]))
    score_if_unmasked[:, :, self.mask_index] = - (
        log_k[:, None] * torch.ones_like(x))

    # Blend the two cases per position with a 0/1 indicator.
    is_masked = (x == self.mask_index).to(
        model_output.dtype)[:, :, None]
    log_score = (
        score_if_masked * is_masked
        + score_if_unmasked * (1 - is_masked))
    return log_score.exp()
def _staggered_score(self, score, dsigma):
score = score.clone()
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
score *= dsigma.exp()[:, None]
score[..., self.mask_index] += extra_const
return score
def _analytic_update(self, x, t, step_size):
    """Sample the tokens at time ``t - step_size`` from the score.

    Args:
      x: int torch.Tensor of current tokens.
      t: float torch.Tensor of current times, passed to the noise
        schedule.
      step_size: amount of time to step back.
    Returns:
      int torch.Tensor of sampled tokens.
    """
    sigma_now, _ = self.noise(t)
    sigma_next, _ = self.noise(t - step_size)
    sigma_gap = sigma_now - sigma_next
    staggered = self._staggered_score(
        self.get_score(x, sigma_now), sigma_gap)
    transition = self._transp_transition(x, sigma_gap)
    return _sample_categorical(staggered * transition)
def _denoiser_update(self, x, t):
    """Sample fully denoised tokens from the score at time `t`.

    Same construction as `_analytic_update` with the whole of sigma as
    the step, except that the mask token is given zero probability.

    Args:
      x: int torch.Tensor of current tokens.
      t: float torch.Tensor of current times, passed to the noise
        schedule.
    Returns:
      int torch.Tensor of sampled tokens.
    """
    sigma, _ = self.noise(t)
    staggered = self._staggered_score(self.get_score(x, sigma), sigma)
    probs = staggered * self._transp_transition(x, sigma)
    # The output must not contain mask tokens.
    probs[..., self.mask_index] = 0
    return _sample_categorical(probs)
def _transp_transition(self, i, sigma):
    """Build per-token transition weights for tokens `i` at noise `sigma`.

    Args:
      i: int torch.Tensor of token ids.
      sigma: float torch.Tensor of noise levels; trailing singleton
        dimensions are appended until it has one more dimension than
        `i`.
    Returns:
      float torch.Tensor with a trailing vocab dimension:
      ``exp(-sigma)`` at each token's own id, plus ``1 - exp(-sigma)``
      added to every vocab entry of positions holding the mask token.
    """
    sigma = _unsqueeze(sigma, reference=i[..., None])
    stay_weight = torch.exp(-sigma)
    edge = stay_weight * F.one_hot(i, num_classes=self.vocab_size)
    mask_bonus = torch.where(
        i == self.mask_index,
        1 - stay_weight.squeeze(-1),
        0)
    edge += mask_bonus[..., None]
    return edge
def _sample_t(self, n, device):
_eps_t = torch.rand(n, device=device)
if self.antithetic_sampling:
# for variance reduction
offset = torch.arange(n, device=device) / n
_eps_t = (_eps_t / n + offset) % 1
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
if self.importance_sampling:
return self.noise.importance_sampling_transformation(t)
return t
def _maybe_sub_sample(self, x0, attention_mask):
seqlen = x0.shape[1]
if seqlen > self.config.model.length:
raise NotImplementedError('Sub-sampling not implemented')
elif self.parameterization == 'ar':
input_tokens = x0[:, :-1]
output_tokens = x0[:, 1:]
new_attention_mask = attention_mask[:, 1:]
else:
input_tokens = x0
output_tokens = None
new_attention_mask = attention_mask
return input_tokens, output_tokens, new_attention_mask
def _reconstruction_loss(self, x0):
t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
device=self.device)
assert self.config.noise.type == 'loglinear'
# The above assert is for d3pm parameterization
unet_conditioning = self.noise(t0)[0][:, None]
model_output_t0 = self.forward(x0, unet_conditioning)
return - torch.gather(input=model_output_t0,
dim=-1,
index=x0[:, :, None]).squeeze(-1)
def _forward_pass_diffusion(self, x0, binary_clss=None):
    """Noise `x0` at random times and return the per-token loss.

    Args:
      x0: int torch.Tensor of clean tokens, indexed (batch, position).
      binary_clss: optional class labels forwarded to the model.
    Returns:
      float torch.Tensor of per-token losses shaped like `x0`.

    NOTE(review): the ``T > 0`` branch calls `self._d3pm_loss`, which is
    not defined in the visible part of this file, and the ``'sedd'``
    branch needs `sigma`/`dsigma`, which are only set when
    change-of-variables is off - confirm these paths are exercised.
    """
    t = self._sample_t(x0.shape[0], x0.device)
    discrete_time = self.T > 0
    if discrete_time:
        # else ts are between 0 and 1
        t = (t * self.T).to(torch.int)
        t = t / self.T
        # t \in {1/T, 2/T, ..., 1}
        t += (1 / self.T)

    if self.change_of_variables:
        # Condition on t directly and interpolate log(1 - exp(-sigma))
        # between its values at sigma_min and sigma_max.
        unet_conditioning = t[:, None]
        f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
        move_chance = torch.exp(f_0 + t * (f_T - f_0))
        move_chance = move_chance[:, None]
    else:
        # total noise, rate noise
        sigma, dsigma = self.noise(t)
        unet_conditioning = sigma[:, None]
        move_chance = 1 - torch.exp(-sigma[:, None])

    # q(xt|x0)
    xt = self.q_xt(x0, move_chance)
    model_output = self.forward(
        xt, unet_conditioning, binary_clss=binary_clss)
    utils.print_nans(model_output, 'model_output')

    if self.parameterization == 'sedd':
        entropy = self._score_entropy(
            model_output, sigma[:, None], xt, x0)
        return dsigma[:, None] * entropy

    if discrete_time:
        diffusion_loss = self._d3pm_loss(
            model_output=model_output, xt=xt, x0=x0, t=t)
        if self.parameterization == 'd3pm':
            reconstruction_loss = self._reconstruction_loss(x0)
        elif self.parameterization == 'subs':
            reconstruction_loss = 0
        return reconstruction_loss + diffusion_loss

    # SUBS parameterization, continuous time.
    log_p_theta = torch.gather(
        input=model_output,
        dim=-1,
        index=x0[:, :, None]).squeeze(-1)

    if self.change_of_variables or self.importance_sampling:
        weight = torch.log1p(
            - torch.exp(- self.noise.sigma_min))
        return log_p_theta * weight

    # Weight each sequence's NLL by dsigma / (exp(sigma) - 1).
    return - log_p_theta * (
        dsigma / torch.expm1(sigma))[:, None]
def _loss(self, x0, attention_mask, binary_clss):
    """Compute the masked per-token NLLs and their token-level mean.

    Args:
      x0: int torch.Tensor of clean tokens, indexed (batch, position).
      attention_mask: mask aligned with `x0`, or None. `_compute_loss`
        passes None when the batch has no ``'attention_mask'`` entry; an
        all-ones mask is used then, so every token counts.
      binary_clss: class labels forwarded to the model.
    Returns:
      `Loss` with the scalar mean NLL over unmasked tokens, the masked
      per-token NLLs and the mask that was applied.
    """
    if attention_mask is None:
        # Without this default the multiplication and `.sum()` below
        # would fail on None.
        attention_mask = torch.ones_like(x0, dtype=torch.float32)

    (input_tokens, output_tokens,
     attention_mask) = self._maybe_sub_sample(
         x0, attention_mask)

    if self.parameterization == 'ar':
        logprobs = self.backbone(input_tokens, None, cls=binary_clss)
        loss = - logprobs.gather(
            -1, output_tokens[:, :, None])[:, :, 0]
    else:
        loss = self._forward_pass_diffusion(
            input_tokens, binary_clss=binary_clss)

    nlls = loss * attention_mask
    count = attention_mask.sum()

    batch_nll = nlls.sum()
    token_nll = batch_nll / count

    return Loss(loss=token_nll,
                nlls=nlls,
                token_mask=attention_mask)
def _score_entropy(self, log_score, sigma, xt, x0):
"""Computes the SEDD loss.
Args:
log_score: float torch.Tensor with shape (batch_size,
diffusion_model_input_length, vocab_size),
log score, output of the denoising network.
xt: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
x0: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
sigma: float torch.Tensor with shape (batch_size, 1).
Returns:
loss with shape (batch_size, diffusion_model_input_length)
"""
# seems that it takes y=x0,xt=M case
# what is the const term for, seems to be y=M,xt=x0 case and x0 is known so score estimation is precise
masked_indices = xt == self.mask_index
expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
q_ratio = 1 / expsig_minus_1[masked_indices]
words_that_were_masked = x0[masked_indices]
neg_term = q_ratio * torch.gather(
log_score[masked_indices],
-1,
words_that_were_masked[..., None]).squeeze(-1)
score = log_score[masked_indices].exp()
if self.mask_index == self.vocab_size - 1:
pos_term = score[:, :-1].sum(dim=-1)
else:
pos_term = score[:, : self.mask_index].sum(
dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
const = q_ratio * (q_ratio.log() - 1)
entropy = torch.zeros(* xt.shape, device=xt.device)
entropy[masked_indices] += pos_term - neg_term + const
return entropy