# TR2-D2 / tr2d2-pep / diffusion.py
# Author: Sophia Tang
# Commit: 5e90249 ("Initial commit")
import numpy as np
import sys
import itertools
import time
import torch
from torch import Tensor
import math
import torch.nn.functional as F
import numpy as np
import random as rd
import lightning as L
import torchmetrics
from dataclasses import dataclass
import gc
import utils.utils as utils
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
import noise_schedule
from torch.optim.lr_scheduler import _LRScheduler
import roformer as roformer
from utils.app import PeptideAnalyzer
import pandas as pd
# Placeholder root directory. Diffusion.__init__ builds the default tokenizer
# vocabulary/splits paths from it (f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/...');
# replace with a real location before use.
base_path = '/path/to/your/home'
def _sample_categorical(categorical_probs):
gumbel_norm = (
1e-10
- (torch.rand_like(categorical_probs) + 1e-10).log())
return (categorical_probs / gumbel_norm).argmax(dim=-1).to(dtype=torch.long)
def _sample_categorical_gradient(categorical_probs, temp = 1.0):
gumbel_norm = (
1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
output = torch.nn.functional.softmax((torch.log(categorical_probs)-torch.log(gumbel_norm))/temp, 2)
return output
def _unsqueeze(x, reference):
return x.view(
* x.shape,
* ((1,) * (len(reference.shape) - len(x.shape))))
def sample_batched_categorical(categorical_probs, batch_size):
    """Draw ``batch_size`` sequences from per-position categorical weights.

    Uses the Gumbel-max trick: i.i.d. Gumbel noise is added to the log-weights
    and the argmax over the vocabulary axis is taken. Samples are independent,
    so duplicate sequences are possible (they are NOT guaranteed distinct).

    Args:
        categorical_probs (torch.Tensor): weights of shape
            (1, sequence_length, vocab_size) (or with a leading dim equal to
            ``batch_size``); they need not be normalized. Zero-weight entries
            are never chosen unless a whole row is zero.
        batch_size (int): number of sequences to sample.

    Returns:
        torch.Tensor: int64 tensor of shape (batch_size, sequence_length).
    """
    _, sequence_length, vocab_size = categorical_probs.shape
    # Generate the uniforms directly on the target device rather than on CPU
    # followed by a host-to-device copy on every call.
    uniform = torch.rand(batch_size, sequence_length, vocab_size,
                         device=categorical_probs.device)
    gumbel_noise = -torch.log(-torch.log(uniform + 1e-10) + 1e-10)
    # broadcast over the leading axis: (1, L, V) + (batch_size, L, V)
    noisy_scores = torch.log(categorical_probs) + gumbel_noise
    return noisy_scores.argmax(dim=-1).to(dtype=torch.long)  # (batch_size, L)
def sample_batched_top_k(categorical_probs, batch_size, k):
    """Draw ``batch_size`` sequences, restricting each position to its top-k candidates.

    Gumbel noise is added to the log-weights, the k highest noisy scores per
    position are kept, and one of them is drawn with probability
    ``softmax(top-k noisy scores)``.

    Args:
        categorical_probs (torch.Tensor): weights of shape
            (1, sequence_length, vocab_length); a 2-D (sequence_length,
            vocab_length) tensor is also accepted and treated as a batch of one.
        batch_size (int): number of sequences to sample.
        k (int): number of candidates kept per position (clamped to vocab_length).

    Returns:
        torch.Tensor: int64 tensor of shape (batch_size, sequence_length).
    """
    if categorical_probs.ndim == 2:
        categorical_probs = categorical_probs.unsqueeze(0)
    _, sequence_length, vocab_length = categorical_probs.shape
    device = categorical_probs.device
    k = min(k, vocab_length)

    uniform = torch.rand(batch_size, sequence_length, vocab_length, device=device)
    gumbel_noise = -torch.log(-torch.log(uniform + 1e-10) + 1e-10)
    # Keep the input 3-D so the result is (batch_size, L, V); adding another
    # leading axis would make top_k_indices 4-D and break torch.gather below.
    noisy_scores = torch.log(categorical_probs) + gumbel_noise

    top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1)  # (batch_size, L, k)
    top_k_probs = torch.softmax(top_k_scores, dim=-1)

    # choose one slot among the k candidates at each (sequence, position)
    picks = torch.multinomial(top_k_probs.reshape(-1, k), num_samples=1)
    picks = picks.view(batch_size, sequence_length)

    # map the chosen slot back to its vocabulary index
    sampled = torch.gather(top_k_indices, -1, picks.unsqueeze(-1)).squeeze(-1)
    return sampled.to(device=device, dtype=torch.long)
@dataclass
class Loss:
    """Bundle of loss tensors returned by ``Diffusion._loss``."""

    # scalar token-averaged loss
    loss: torch.FloatTensor
    # per-token losses already multiplied by the attention mask
    nlls: torch.FloatTensor
    # attention mask used for averaging and metric updates
    attn_mask: torch.FloatTensor
class NLL(torchmetrics.aggregation.MeanMetric):
    """Running mean of per-token negative log-likelihoods (plain ``MeanMetric``)."""
class BPD(NLL):
    """NLL metric reported in bits per dimension."""

    def compute(self) -> Tensor:
        """Return the accumulated mean NLL converted from natural log to base 2.

        Returns:
            bpd
        """
        mean_nll = self.mean_value / self.weight
        return mean_nll / math.log(2)
class Perplexity(NLL):
    """NLL metric reported as perplexity, i.e. exp(mean NLL)."""

    def compute(self) -> Tensor:
        """Return the exponential of the accumulated mean NLL.

        Returns:
            Perplexity
        """
        mean_nll = self.mean_value / self.weight
        return torch.exp(mean_nll)
class Diffusion(L.LightningModule):
    """Masked discrete-diffusion LightningModule over peptide SMILES tokens.

    Holds a ``roformer.Roformer`` backbone plus noise schedules, and provides
    training losses, sampling, and reward-scored rollout helpers.
    """
    def __init__(
        self,
        config,
        tokenizer = None,
        mode="finetune",
        device=None,
    ):
        """Build the tokenizer, backbone, noise schedules and metrics.

        Args:
            config: experiment config. Read here: ``sampling.predictor``, ``T``,
                ``time_conditioning``, ``eval.gen_ppl_eval_model_name_or_path``,
                ``optim.lr``, ``training.sampling_eps`` (plus whatever
                ``noise_schedule.get_noise`` and ``roformer.Roformer`` read).
            tokenizer: optional tokenizer; when None a ``SMILES_SPE_Tokenizer``
                is loaded from vocab/splits files under ``base_path``.
            mode: "finetune" calls ``freeze_model()`` then ``unfreeze_n_layers(n=8)``
                on the backbone; "eval" freezes it, disables grads and calls
                ``eval()``; "train" enables grads and calls ``train()``.
            device: forwarded to ``roformer.Roformer``.
        """
        super().__init__()
        self.config = config
        #self.save_hyperparameters()
        # PeptideCLM tokenizer
        if tokenizer is None:
            self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_vocab.txt',
                                                  f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_splits.txt')
        else:
            self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        # token id written into masked positions (prior is all mask_index)
        self.mask_index = self.tokenizer.mask_token_id
        self.sampler = self.config.sampling.predictor
        # validity checker; its is_peptide() filters decoded sequences
        self.analyzer = PeptideAnalyzer()
        # backbone LM PeptideCLM model
        self.backbone = roformer.Roformer(self.config, self.tokenizer, device=device)
        if mode == "finetune":
            self.backbone.freeze_model()
            self.backbone.unfreeze_n_layers(n=8)
        elif mode == "eval":
            self.backbone.freeze_model()
            self.backbone.requires_grad_(False)
            self.backbone.eval()
        elif mode == "train":
            self.backbone.requires_grad_(True)
            self.backbone.train()
        # NOTE(review): any other `mode` value silently leaves the backbone as constructed.
        self.neg_infinity = -1000000.0
        # number of discrete diffusion steps; T == 0 selects the other loss path
        # in _forward_pass_diffusion
        self.T = config.T
        # noise schedule for non-peptide bond tokens (default to log-linear)
        self.noise = noise_schedule.get_noise(config)
        # noise schedule for peptide bonds (log-polynomial)
        self.bond_noise = noise_schedule.LogPolyNoise()
        self.time_conditioning = self.config.time_conditioning
        self.fast_forward_epochs = None
        self.fast_forward_batches = None
        self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path
        self.gen_ppl_metric = Perplexity()
        self.lr = self.config.optim.lr
        self.sampling_eps = self.config.training.sampling_eps
        # one metric template, cloned per split with a logging prefix
        metrics = torchmetrics.MetricCollection({
            'nll': NLL(),
            'bpd': BPD(),
            'ppl': Perplexity(),
        })
        metrics.set_dtype(torch.float64)
        self.train_metrics = metrics.clone(prefix='trainer/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')
### FOR THE EXPANSION AND ROLLOUT STEP ###
def sample_finetuned_with_rnd(self, args, reward_model, pretrained, eps=1e-5):
    """Roll out a batch from the fully masked prior and score the valid peptides.

    At every step ``mcts_reverse_step`` unmasks tokens. The per-sequence sum of
    ``log_pretrained_step - log_policy_step`` is accumulated into ``log_rnd``.
    After filtering with ``self.analyzer.is_peptide``, the summed reward divided
    by ``args.alpha`` is added.

    Args:
        args: namespace; reads ``total_num_steps``, ``batch_size``,
            ``seq_length`` and ``alpha``.
        reward_model: callable taking ``input_seqs=list[str]`` and returning
            per-sequence score vectors (summed over the last axis).
        pretrained: reference model forwarded to ``mcts_reverse_step``.
        eps: final value of the linearly spaced time grid.

    Returns:
        (valid_x_final, log_rnd, scalar_rewards) for the N valid sequences:
        token ids (N, seq_length), log-weights (N,) and rewards (N,). When
        N == 0, empty tensors are returned and ``reward_model`` is not called.
    """
    num_steps = args.total_num_steps
    B = args.batch_size
    x_rollout = self.sample_prior(B, args.seq_length).to(self.device)
    log_rnd = torch.zeros(args.batch_size, device=self.device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
    dt = (1 - eps) / num_steps

    for i in range(num_steps):
        t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
        _, x_rollout, log_policy_step, log_pretrained_step = self.mcts_reverse_step(
            x_rollout, t=t, dt=dt, pretrained=pretrained)
        log_rnd += log_pretrained_step - log_policy_step

    # if a mask token remains anywhere in the batch, force a final full unmasking
    # NOTE(review): tokens unmasked here do not contribute to log_rnd — confirm intended.
    if (x_rollout == self.mask_index).any().item():
        _, x_rollout = self.single_noise_removal(x_rollout, t=t, dt=dt)

    childSequences = self.tokenizer.batch_decode(x_rollout)
    valid_idx = [i for i in range(B) if self.analyzer.is_peptide(childSequences[i])]

    if not valid_idx:
        # nothing to score: return empty tensors instead of crashing on torch.stack([])
        empty_rewards = torch.zeros(0, dtype=torch.float32, device=self.device)
        return x_rollout[:0], log_rnd[:0], empty_rewards

    validSequences = [childSequences[i] for i in valid_idx]
    # CPU long index works for indexing both CPU and CUDA tensors
    index = torch.tensor(valid_idx, dtype=torch.long)
    valid_x_final = x_rollout[index]
    valid_log_rnd = log_rnd[index]

    # compute multi-objective rewards and collapse them to one scalar per sequence
    score_vectors = reward_model(input_seqs=validSequences)
    scalar_rewards = np.sum(score_vectors, axis=-1)
    scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=self.device)
    print(f"scalar reward dim{len(scalar_rewards)}")

    log_rnd = valid_log_rnd + (scalar_rewards / args.alpha)  # scale down by alpha
    return valid_x_final, log_rnd, scalar_rewards
def sample_finetuned(self, args, reward_model, batch_size=None, dataframe=False, eps=1e-5):
    """Sample a batch with the current model and score the valid peptides.

    Args:
        args: namespace; reads ``batch_size`` (when ``batch_size`` is None),
            ``total_num_steps`` and ``seq_length``.
        reward_model: callable taking ``input_seqs=list[str]`` and returning a
            (num_valid, num_objectives) array; objectives 0..4 are read as
            affinity, solubility, hemolysis, nonfouling, permeability.
        batch_size: number of sequences to sample.
        dataframe: when True, also return a pandas DataFrame of the valid
            sequences and their scores.
        eps: final value of the linearly spaced time grid.

    Returns:
        (x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction[, df]).
        ``x_rollout`` holds all sampled ids, including invalid sequences. When no
        sequence is valid, each score is ``[0.0]`` and ``df`` has zero rows.
    """
    torch.cuda.empty_cache()
    self.backbone.eval()
    self.noise.eval()
    print(f"device:{self.device}")
    if batch_size is None:
        batch_size = args.batch_size
    num_steps = args.total_num_steps
    x_rollout = self.sample_prior(
        batch_size,
        args.seq_length).to(self.device, dtype=torch.long)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
    dt = torch.tensor((1 - eps) / num_steps, device=self.device)

    for i in range(num_steps):
        t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
        _, x_rollout = self.single_reverse_step(x_rollout, t=t, dt=dt)
        x_rollout = x_rollout.to(self.device)

    # if any mask token remains in the batch, fully unmask
    if (x_rollout == self.mask_index).any().item():
        _, x_rollout = self.single_noise_removal(x_rollout, t=t, dt=dt)
        x_rollout = x_rollout.to(self.device)

    childSequences = self.tokenizer.batch_decode(x_rollout)
    validSequences = [seq for seq in childSequences if self.analyzer.is_peptide(seq)]
    valid_fraction = len(validSequences) / batch_size

    if validSequences:
        score_vectors = reward_model(input_seqs=validSequences)  # (num_valid, num_objectives)
        per_objective = score_vectors.T
        affinity = per_objective[0]
        sol = per_objective[1]
        hemo = per_objective[2]
        nf = per_objective[3]
        permeability = per_objective[4]
    else:
        affinity = sol = hemo = nf = permeability = [0.0]

    if dataframe:
        # Columns must have equal length. With no valid peptides, leave the score
        # columns empty to match the empty sequence column.
        columns = (affinity, sol, hemo, nf, permeability) if validSequences else ([],) * 5
        df = pd.DataFrame({
            "Peptide Sequence": validSequences,
            "Binding Affinity": columns[0],
            "Solubility": columns[1],
            "Hemolysis": columns[2],
            "Nonfouling": columns[3],
            "Permeability": columns[4],
        })
        return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction, df
    return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction
def compute_log_policy(self, token_array, x_next, t, dt, attn_mask=None):
    """Log-probability under the current model of the tokens unmasked in one step.

    Args:
        token_array: parent token ids, (B, L) or (L,) (1-D is treated as B == 1).
        x_next: child token ids with the same shape convention as ``token_array``.
        t: current time, (B,) or (B, 1); mapped to sigma via ``self.noise``.
        dt: step size; kept for signature parity with the sampling steps but not
            needed to score tokens.
        attn_mask: optional attention mask; defaults to all ones.

    Returns:
        Per sequence, the sum of log p(x_next token) over positions that were masked
        in ``token_array`` and are not masked in ``x_next``. The shape is (B,); with
        B == 1 it is squeezed to a 0-d tensor. Gradients flow through ``self.forward``.
    """
    torch.cuda.empty_cache()
    self.backbone.eval()
    self.noise.eval()
    sigma_t, _ = self.noise(t)
    if token_array.ndim == 1:
        token_array = token_array.unsqueeze(0)
    if x_next.ndim == 1:
        x_next = x_next.unsqueeze(0)
    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1
    if attn_mask is None:
        attn_mask = torch.ones_like(token_array).to(self.device)

    log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
    assert log_p.ndim == 3, log_p.shape  # (B, L, vocab)

    # positions that were masked in the parent and received a real token
    was_masked = token_array == self.mask_index
    unmasked_this_step = (was_masked & (x_next != self.mask_index)).to(log_p.dtype)

    log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)
    log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
    # (B,) -> 0-d when B == 1; squeeze(0) is a no-op for B > 1
    return log_policy_step.squeeze(0)
def single_reverse_step(self, token_array, t, dt, p_x0=None, attn_mask=None):
    """One ancestral denoising step of the masked-diffusion sampler.

    Masked positions are resampled from weights ``p_x0 * (t - s)`` over the
    vocabulary, with weight ``s = t - dt`` on the mask token. Positions that are
    already unmasked are copied unchanged.

    Args:
        token_array: current token ids, (B, L).
        t: current time, (B,) or (B, 1).
        dt: step size (scalar or tensor).
        p_x0: optional precomputed denoiser probabilities (B, L, V).
        attn_mask: optional attention mask; defaults to all ones (long).

    Returns:
        (log_p, x_next). ``log_p`` is the model output when it was computed here,
        otherwise None. ``x_next`` has the dtype of ``token_array``.
    """
    torch.cuda.empty_cache()
    device = self.device
    self.backbone.to(device).eval()
    self.noise.eval()
    t = t.to(device)
    dt = torch.as_tensor(dt, device=device, dtype=t.dtype)
    assert self.config.noise.type == 'loglinear'
    sigma_t, _unused_rate = self.noise(t)
    sigma_t = sigma_t.to(device)

    t_flat = t.squeeze(-1) if t.ndim > 1 else t
    assert t_flat.ndim == 1
    keep_masked_weight = (t_flat - dt)[:, None, None]           # weight s on the mask token
    unmask_scale = t_flat[:, None, None] - keep_masked_weight   # t - s for real tokens
    assert t_flat[:, None, None].ndim == 3, t_flat.shape

    if attn_mask is not None:
        attn_mask = attn_mask.to(device)
    else:
        attn_mask = torch.ones_like(token_array, device=device, dtype=torch.long)

    if p_x0 is not None:
        log_p = None
        p_x0 = p_x0.to(device)
    else:
        log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
        p_x0 = log_p.exp()
    assert 3 == p_x0.ndim

    q_xs = p_x0 * unmask_scale
    q_xs[:, :, self.mask_index] = keep_masked_weight[:, :, 0]
    proposal = _sample_categorical(q_xs)
    if proposal.device != device or proposal.dtype != token_array.dtype:
        proposal = proposal.to(device, dtype=token_array.dtype)

    keep = (token_array != self.mask_index).to(token_array.dtype)
    x_next = keep * token_array + (1 - keep) * proposal
    return log_p, x_next
def single_noise_removal(self, token_array, t, dt, p_x0=None, attn_mask=None):
    """Final denoising step that fills every masked position with a non-mask token.

    The mask token's probability is set to 0 and the rest are renormalized. Every
    masked position is then sampled from that distribution. Positions that are
    already unmasked are copied unchanged.

    Args:
        token_array: current token ids, (B, L).
        t: current time, (B,) or (B, 1).
        dt: step size.
        p_x0: optional precomputed denoiser probabilities (B, L, V).
        attn_mask: optional attention mask; defaults to all ones.

    Returns:
        (log_p, x_next). ``log_p`` is the model output (before the mask token is
        removed) when it was computed here, otherwise None, matching
        ``single_reverse_step``.
    """
    torch.cuda.empty_cache()
    self.backbone.eval()
    self.noise.eval()
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)
    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]
    assert change_prob_t.ndim == 3, change_prob_t.shape
    if attn_mask is None:
        attn_mask = torch.ones_like(token_array).to(self.device)
    if p_x0 is None:
        log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
        p_x0 = log_p.exp()
    else:
        # caller supplied probabilities; no log-probs were computed here
        log_p = None
    assert change_prob_t.ndim == p_x0.ndim
    # forbid the mask token and renormalize over the remaining vocabulary
    p_x0 = p_x0.clone()
    p_x0[:, :, self.mask_index] = 0.0
    p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12)
    q_xs = p_x0 * (change_prob_t - change_prob_s)
    x_changed = _sample_categorical(q_xs)
    copy_flag = (token_array != self.mask_index)
    int_copy_flag = copy_flag.to(token_array.dtype)
    x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
    # log_p (B, L, D) log probabilities under the policy model (or None)
    # x_next (B, L) next sequences
    return log_p, x_next
def mcts_reverse_step(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None):
    """One reverse step that also scores the newly unmasked tokens under two models.

    Samples ``x_next`` with the same transition as ``single_reverse_step``:
    masked positions are drawn from weights ``p_x0 * (t - s)`` with weight
    ``s = t - dt`` on the mask token, and unmasked positions are copied. Then, for
    each sequence, it sums the log-probability of the tokens that went from mask to
    non-mask, once under the policy (``self``) and once under ``pretrained``.

    Args:
        token_array: parent token ids, (B, L).
        t: current time, (B,) or (B, 1).
        dt: step size.
        pretrained: reference model whose ``forward`` output is gathered as
            log-probabilities over the vocabulary.
        p_x0: optional precomputed denoiser probabilities (B, L, V); when given,
            policy log-probs are ``log(p_x0)`` clamped away from zero.
        attn_mask: optional attention mask; defaults to all ones.

    Returns:
        log_p (B, L, V), x_next (B, L), log_policy_step (B,), log_pretrained_step (B,).
    """
    torch.cuda.empty_cache()
    self.backbone.eval()
    self.noise.eval()
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)
    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]
    assert change_prob_t.ndim == 3, change_prob_t.shape
    if attn_mask is None:
        attn_mask = torch.ones_like(token_array).to(self.device)
    if p_x0 is None:
        log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
        p_x0 = log_p.exp()
    else:
        # log_p is needed below; clamp so zero probabilities give a finite log
        # (-inf * 0 in the masked sums would otherwise produce NaN)
        log_p = p_x0.clamp_min(torch.finfo(p_x0.dtype).tiny).log()
    assert change_prob_t.ndim == p_x0.ndim
    q_xs = p_x0 * (change_prob_t - change_prob_s)
    # zero-masking probability
    q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
    x_changed = _sample_categorical(q_xs)
    copy_flag = (token_array != self.mask_index)
    int_copy_flag = copy_flag.to(token_array.dtype)
    x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed

    # pretrained log-probs over vocab at each position, given the *parent* (masked) input
    with torch.no_grad():
        log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
    # log-prob of the *sampled token* at each position
    log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)  # (B, L)
    assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
    changed_mask = (~copy_flag)
    # sum only over sites that were masked in the parent and are unmasked now
    unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
    log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)

    # per-sequence log-probability of the same tokens under the policy model
    # NOTE(review): this keeps the autograd graph of self.forward; callers that
    # accumulate it over many steps may want to run under torch.no_grad().
    log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)  # (B, L)
    log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
    return log_p, x_next, log_policy_step, log_pretrained_step
def mcts_noise_removal(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None):
    """Forced final unmasking that also scores the new tokens under two models.

    Like ``single_noise_removal``, the mask token's probability is set to 0 and
    every masked position is filled. Like ``mcts_reverse_step``, the per-sequence
    log-probability of the newly unmasked tokens is summed under the policy and
    under ``pretrained``.

    Args:
        token_array: parent token ids, (B, L).
        t: current time, (B,) or (B, 1).
        dt: step size.
        pretrained: reference model whose ``forward`` output is gathered as
            log-probabilities.
        p_x0: optional precomputed denoiser probabilities (B, L, V); when given,
            policy log-probs are ``log(p_x0)`` clamped away from zero.
        attn_mask: optional attention mask; defaults to all ones.

    Returns:
        log_p (B, L, V), x_next (B, L), log_policy_step (B,), log_pretrained_step (B,).
    """
    torch.cuda.empty_cache()
    self.backbone.eval()
    self.noise.eval()
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)
    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]
    assert change_prob_t.ndim == 3, change_prob_t.shape
    if attn_mask is None:
        attn_mask = torch.ones_like(token_array).to(self.device)
    if p_x0 is None:
        log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
        p_x0 = log_p.exp()
    else:
        # log_p is needed below; clamp so zero probabilities give a finite log
        log_p = p_x0.clamp_min(torch.finfo(p_x0.dtype).tiny).log()
    assert change_prob_t.ndim == p_x0.ndim
    # forbid the mask token and renormalize over the remaining vocabulary
    p_x0 = p_x0.clone()
    p_x0[:, :, self.mask_index] = 0.0
    p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12)
    q_xs = p_x0 * (change_prob_t - change_prob_s)
    x_changed = _sample_categorical(q_xs)
    copy_flag = (token_array != self.mask_index)
    int_copy_flag = copy_flag.to(token_array.dtype)
    x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed

    # pretrained log-probs over vocab at each position, given the *parent* (masked) input
    with torch.no_grad():
        log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
    log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)  # (B, L)
    assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
    changed_mask = (~copy_flag)
    # sum only over sites that were masked in the parent and are unmasked now
    unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
    log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)

    # per-sequence log-probability of the same tokens under the policy model
    # NOTE(review): x_next was sampled from p_x0 renormalized without the mask token,
    # but it is scored here with the unrenormalized log_p — confirm intended.
    log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)  # (B, L)
    log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
    return log_p, x_next, log_policy_step, log_pretrained_step
# first step in expansion
def batch_mcts_reverse_step(self, token_array, t, dt, batch_size, pretrained, p_x0=None, attn_mask=None):
    """Expand one parent sequence into ``batch_size`` children with one reverse step.

    The policy and pretrained models are each run once on the parent. Children are
    sampled with ``sample_batched_categorical`` (``config.mcts.sampling == 0``) or
    ``sample_batched_top_k`` (k = ``config.mcts.sampling``). Unmasked parent tokens
    are copied into every child.

    NOTE(review): children come from ``token_array.repeat(batch_size, 1)`` and a
    (batch_size, L) sample, so this appears to assume a single parent row (1, L).

    Args:
        token_array: parent token ids, (L,) or (1, L).
        t: current time, (B,) or (B, 1).
        dt: step size.
        batch_size: number of children to draw.
        pretrained: reference model whose ``forward`` output is gathered as
            log-probabilities.
        p_x0: optional precomputed denoiser probabilities; when given, policy
            log-probs are ``log(p_x0)`` clamped away from zero.
        attn_mask: optional attention mask for the parent; defaults to all ones.

    Returns:
        log_p (batch_size, L, V) policy log-probs repeated per child,
        x_children (batch_size, L), log_policy_step (batch_size,),
        log_pretrained_step (batch_size,).
    """
    torch.cuda.empty_cache()
    self.backbone.eval()
    self.noise.eval()
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)
    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]
    assert change_prob_t.ndim == 3, change_prob_t.shape
    if token_array.dim() == 1:
        token_array = token_array.unsqueeze(0)
    if attn_mask is None:
        attn_mask = torch.ones_like(token_array).to(self.device)
    token_array = token_array.to(self.device)
    sigma_t = sigma_t.to(self.device)
    if p_x0 is None:
        log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
        p_x0 = log_p.exp()
    else:
        # log_p is needed below; clamp so zero probabilities give a finite log
        log_p = p_x0.clamp_min(torch.finfo(p_x0.dtype).tiny).log()
    assert change_prob_t.ndim == p_x0.ndim
    q_xs = p_x0 * (change_prob_t - change_prob_s)
    # zero-masking probability
    q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
    # repeat the parent along the first dimension; each copy is unmasked independently
    token_array_expanded = token_array.repeat(batch_size, 1)
    if self.config.mcts.sampling == 0:
        x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
    else:
        x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
    copy_flag = (token_array_expanded != self.mask_index)
    int_copy_flag = copy_flag.to(token_array.dtype)
    x_children = int_copy_flag * token_array_expanded + (1 - int_copy_flag) * x_changed

    # pretrained log-probs over vocab given the *parent* (masked) input
    with torch.no_grad():
        log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
    # expand to match the shape of x_children
    log_pre = log_pre.repeat(batch_size, 1, 1)
    log_pre_token = log_pre.gather(-1, x_children.unsqueeze(-1)).squeeze(-1)  # (batch_size, L)
    assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
    changed_mask = (~copy_flag)
    # sum only over sites that were masked in the parent and are unmasked now
    unmasked_this_step = (changed_mask & (x_children != self.mask_index)).to(log_pre_token.dtype)
    log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)

    # per-child log-probability of the same tokens under the policy model
    log_p = log_p.repeat(batch_size, 1, 1)
    log_policy_token = log_p.gather(-1, x_children.unsqueeze(-1)).squeeze(-1)  # (batch_size, L)
    log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
    return log_p, x_children, log_policy_step, log_pretrained_step
def compute_invalid_loss(self, logits, k=None, temp=None):
    """Per-token penalty for rows whose greedy decoding is not a valid peptide.

    Each row is decoded greedily (argmax over the vocabulary). Rows rejected by
    ``self.analyzer.is_peptide`` get penalty 1 and valid rows get 0. The penalty is
    spread over positions by multiplying it with the softmax probability of the
    argmax token, so gradients reach ``logits``; the validity check itself is not
    differentiable.

    Args:
        logits: tensor of shape [batch_size, seq_len, vocab_size]. Log-probabilities
            work too, since softmax is invariant to a per-row shift.
        k: unused (was consumed by the commented-out ``gumbel_rao`` call); kept
            for backward compatibility.
        temp: unused, for the same reason.

    Returns:
        Tensor of shape [batch_size, seq_len]. It is NOT reduced to a scalar;
        callers sum or average it.
    """
    batch_token_ids = logits.argmax(dim=-1).to(self.device)  # (batch_size, seq_len)
    sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)

    # validity of each decoded sequence (not differentiable)
    penalties = torch.tensor(
        [0.0 if self.analyzer.is_peptide(seq) else 1.0 for seq in sampled_sequences],
        dtype=torch.float32,
        device=self.device,
    )

    # softmax probability of the argmax token at each position, (batch_size, seq_len)
    sampled_probs = torch.softmax(logits, dim=-1).gather(
        dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device)

    return (penalties[:, None] * sampled_probs).to(self.device)
### DIFFUSION LOSS ###
def sample_t(self, n, device):
    """Draw ``n`` diffusion times in [sampling_eps, 1) for a training batch.

    With ``config.training.antithetic_sampling`` the draws are stratified: the
    k-th uniform lands in [k/n, (k+1)/n), which spreads times evenly over [0, 1).
    """
    training_cfg = self.config.training
    u = torch.rand(n, device=device)
    if training_cfg.antithetic_sampling:
        strata_start = torch.arange(n, device=device) / n
        u = ((u / n) + strata_start) % 1
    # affine map of [0, 1) onto [sampling_eps, 1) so t is never exactly 0
    lower = training_cfg.sampling_eps
    return (1 - lower) * u + lower
"""def mask_samples(self, x0, mask_prob):
# generate array of values in range [0, 1] uniformly at random
# will be used to determine which tokens are masked
mask_indices = torch.rand(* x0.shape, device=x0.device) # (batch_size, L)
# select tokens to mask if the random value in mask_indices is less than mask_prob
# this will mask approximately the fraction of tokens indicated by mask_prob
zt = torch.where(mask_indices < mask_prob, self.mask_index, x0)
return zt"""
def q_xt(self, x, mask_prob):
    """Compute the noisy sample x_t by masking tokens of ``x``.

    Each position is proposed for masking independently with probability
    ``mask_prob``. Per row, at most ``int(0.75 * n_nonzero)`` proposals are kept,
    taking the earliest ones left to right; ``n_nonzero`` counts tokens != 0.

    Args:
        x: int tensor of shape (batch_size, seq_len).
        mask_prob: float tensor broadcastable to (batch_size, seq_len), e.g.
            (batch_size, 1) or (batch_size, seq_len).

    Returns:
        Tensor like ``x`` with the kept positions set to ``self.tokenizer.mask_token_id``.
    """
    # NOTE(review): assumes token id 0 is padding when measuring sequence length;
    # proposals are drawn over every position, including those where x == 0.
    actual_seq_length = (x != 0).sum(dim=-1, keepdim=True)
    max_mask_length = (actual_seq_length * 0.75).long()  # (batch_size, 1)
    mask_indices = torch.rand(*x.shape, device=x.device) < mask_prob
    # 1-based rank of each proposal within its row. Keeping rank <= cap selects the
    # first `cap` proposals without a Python loop over rows.
    proposal_rank = mask_indices.long().cumsum(dim=-1)
    restricted_move_indices = mask_indices & (proposal_rank <= max_mask_length)
    return torch.where(restricted_move_indices, self.tokenizer.mask_token_id, x)
def sample_prior(self, *batch_dims):
    """Return an int64 tensor of shape ``batch_dims`` filled with the mask token id."""
    all_ones = torch.ones(*batch_dims, dtype=torch.int64)
    return all_ones * self.mask_index
### COMPUTING LOSS ###
def compute_diffusion_loss(self, model_output, xt, x0, t):
    """Per-token discrete-time diffusion term of the ELBO.

    Evaluates how well the model predicts the clean tokens at the masked positions
    of ``xt``, with step size ``dt = 1 / self.T``.

    Args:
        model_output: log-probabilities of shape (B, L, vocab_size).
        xt: corrupted sequence at time t, (B, L); only positions equal to
            ``self.mask_index`` contribute.
        x0: clean sequence, (B, L).
        t: per-sequence time, (B,) or (B, 1) (a scalar also works).

    Returns:
        Tensor of shape (B, L): ``self.T`` times the per-token loss; zero where ``xt``
        is not masked.
    """
    dt = 1 / self.T
    # (B,) -> (B, 1) so t broadcasts over the sequence axis instead of clashing with L
    if torch.is_tensor(t) and t.ndim == 1:
        t = t[:, None]
    s = t - dt

    # alpha scaling terms at times t and s = t - dt, broadcast to (B, L)
    alpha_t = 1 - t + torch.zeros_like(x0)
    alpha_s = 1 - s + torch.zeros_like(x0)

    # log<x_theta, x0>: log-prob of the clean token at each position, (B, L)
    log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None]).squeeze(-1)
    # log<x_theta, m> and <x_theta, m>: (log-)prob assigned to the mask token, (B, L)
    log_x_theta_at_m = model_output[:, :, self.mask_index]
    x_theta_at_m = log_x_theta_at_m.exp()

    # first term of diffusion loss
    term_1_coef = dt / t
    term_1_log_numerator = torch.log((alpha_t * x_theta_at_m) / t + 1)
    term_1_log_denom = log_x_theta_at_x0

    # second term of diffusion loss
    term_2_coef = 1 - (dt / t)
    term_2_log_numerator = term_1_log_numerator
    # At t == dt we have s == 0 and term_2_coef == 0. Clamping s avoids 0/0 = NaN
    # leaking through the zero coefficient.
    s_safe = torch.as_tensor(s).clamp_min(1e-12)
    term_2_log_denom = torch.log((alpha_s * x_theta_at_m) / s_safe + 1)

    L_vb_masked = (term_1_coef * (term_1_log_numerator - term_1_log_denom)
                   + term_2_coef * (term_2_log_numerator - term_2_log_denom))

    # multiply by <zt, m>: only masked positions contribute
    L_vb = L_vb_masked * (xt == self.mask_index)
    return self.T * L_vb
def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
    """Per-token training loss for reconstructing ``x0`` from a masked copy.

    A time t is sampled per sequence (snapped to a multiple of 1/T, plus 1/T, when
    ``self.T > 0``). x0 is masked at t, and the model's log-probabilities are scored.

    Args:
        x0: clean token ids, (batch, seq_length).
        attn_mask: attention mask passed to ``self.forward``.
        bond_mask: optional (batch, seq_length). When ``config.noise.state_dependent``
            is set, positions equal to 1 use the bond noise schedule.
        mask: optional (batch, seq_length). When given, positions with ``mask == 1``
            keep x0 and every other position becomes the mask token; no random
            masking is done.

    Returns:
        (batch, seq_length) per-token losses. With T > 0 this is
        ``compute_diffusion_loss``. With T == 0 it is the weighted
        ``-log p_theta(x0)`` plus ``compute_invalid_loss``.
    """
    # randomly sample time steps to start the denoising process for each x0 in batch
    t = self.sample_t(x0.shape[0], self.device)
    if self.T > 0:
        # snap to a multiple of 1/T, then shift by 1/T so t is never 0
        t = (t * self.T).to(torch.int)
        t = t / self.T
        t += (1 / self.T)

    # noise level and its rate at time t (log-linear: sigma = -log(1-t), dsigma = 1/(1-t))
    sigma, dsigma = self.noise(t)
    time_conditioning = sigma[:, None]
    # masking probability 1 - exp(-sigma); (batch, 1) when sigma is (batch,)
    base_mask_prob = 1 - torch.exp(-sigma[:, None])

    if self.config.noise.state_dependent and (bond_mask is not None):
        # log-polynomial schedule for peptide-bond positions
        bond_sigma, bond_dsigma = self.bond_noise(t)
        bond_sigma = bond_sigma[:, None]
        bond_dsigma = bond_dsigma[:, None]
        sigma = sigma[:, None]
        dsigma = dsigma[:, None]
        bond_mask_prob = 1 - torch.exp(-bond_sigma).to(self.device)
        # (batch, seq_length) tensors using the bond schedule where bond_mask == 1
        mask_prob = torch.where(bond_mask == 1, bond_mask_prob, base_mask_prob).to(self.device)
        dsigma = torch.where(bond_mask == 1, bond_dsigma, dsigma).to(self.device)
        sigma = torch.where(bond_mask == 1, bond_sigma, sigma).to(self.device)
    else:
        mask_prob = base_mask_prob.to(self.device)

    # corrupt x0 at time t
    if mask is None:
        zt = self.q_xt(x0, mask_prob).to(self.device)
    else:
        zt = x0.where(mask == 1, torch.full_like(x0, self.mask_index)).to(self.device)

    model_output = self.forward(zt, attn_mask=attn_mask.to(self.device), sigma=time_conditioning).to(self.device)

    # debugging checks (device-agnostic, so CPU runs work too)
    assert not torch.isnan(model_output).any()
    utils.print_nans(model_output, 'model_output')

    if self.T > 0:
        return self.compute_diffusion_loss(model_output, zt, x0, t)

    # T == 0 path; the invalid-peptide penalty is only used here, so it is computed lazily
    invalid_loss = self.compute_invalid_loss(logits=model_output).to(self.device)  # (B, L)
    # -log p_theta(x0) at each position, (B, L)
    log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1).to(self.device)
    if self.config.noise.state_dependent and (bond_mask is not None):
        # sigma / dsigma are already (B, L) here
        return (-log_p_theta * (dsigma / torch.expm1(sigma)) + invalid_loss).to(self.device)
    return ((-log_p_theta * (dsigma / torch.expm1(sigma))[:, None]) + invalid_loss).to(self.device)
def _loss(self, x0, attn_mask, bond_mask=None, mask=None):
    """Average the per-token diffusion loss over attended positions.

    Returns:
        Loss with the token-averaged scalar, the masked per-token losses, and the
        attention mask, all moved to ``self.device``.
    """
    per_token = self._forward_pass_diffusion(x0, attn_mask, bond_mask, mask)
    masked_nlls = per_token * attn_mask
    token_count = attn_mask.sum()
    mean_nll = masked_nlls.sum() / token_count
    dev = self.device
    return Loss(loss=mean_nll.to(dev), nlls=masked_nlls.to(dev), attn_mask=attn_mask.to(dev))
def _compute_loss(self, batch, prefix, bond_mask=None):
attn_mask = batch['attention_mask'].to(self.device)
if 'mask' in batch:
mask = batch['mask'].to(self.device)
else:
mask = None
if 'bond_mask' in batch:
bond_mask = batch['bond_mask'].to(self.device)
else:
bond_mask = None
losses = self._loss(batch['input_ids'].to(self.device), attn_mask, bond_mask, mask)
loss = losses.loss
if prefix == 'train':
self.train_metrics.update(
losses.nlls.to(self.device),
losses.attn_mask.to(self.device)
)
metrics = self.train_metrics
elif prefix == 'val':
self.valid_metrics.update(
losses.nlls.to(self.device),
losses.attn_mask.to(self.device)
)
metrics = self.valid_metrics
elif prefix == 'test':
self.test_metrics.update(losses.nlls, losses.attn_mask)
metrics = self.test_metrics
else:
raise ValueError(f'Invalid prefix: {prefix}')
self.log_dict(metrics,
on_step=False,
on_epoch=True,
sync_dist=True)
return loss
### SAMPLING ###
def generate_from_masked(self, num_samples=None, seq_length=None, sample_steps=128, eps=1e-5):
    """
    Unmask a batch of prior (fully masked) sequences with ``single_reverse_step``.

    Args:
        num_samples: number of sequences requested from ``sample_prior``.
        seq_length: sequence length; falls back to
            ``config.sampling.seq_length`` when None.
        sample_steps: number of reverse steps; falls back to
            ``config.sampling.steps`` only when explicitly passed as None.
        eps: final time value of the schedule.

    Returns:
        The token tensor after ``sample_steps`` reverse steps.
    """
    steps = self.config.sampling.steps if sample_steps is None else sample_steps
    length = self.config.sampling.seq_length if seq_length is None else seq_length
    tokens = self.sample_prior(num_samples, length).to(self.device)
    # times run from 1 down to eps; only the first ``steps`` entries are used
    schedule = torch.linspace(1, eps, steps + 1, device=self.device)
    step_size = (1 - eps) / steps
    for current in schedule[:steps]:
        t = current * torch.ones(tokens.shape[0], 1, device=self.device)
        tokens = self.single_reverse_step(tokens, t, step_size)
    return tokens
### SAMPLING STEP ###
# NOTE(review): the ``single_reverse_step`` definition below is disabled: it is
# wrapped in a bare string literal, so Python never defines it as a method.
# ``generate_from_masked`` and ``_sample`` (when ``self.sampler == 'ddpm'``)
# both call ``self.single_reverse_step``. Unless the method is defined elsewhere
# in this class (not visible in this section), those paths raise
# AttributeError. Either restore this definition deliberately or remove the
# call sites.
"""
def single_reverse_step(self, zt, t, dt, attn_mask=None):
# get sigma values that determine masking prob
sigma_t, _ = self.noise(t)
sigma_s, _ = self.noise(t - dt)
# reshape sigmas
if sigma_t.ndim > 1:
sigma_t = sigma_t.squeeze(-1)
if sigma_s.ndim > 1:
sigma_s = sigma_s.squeeze(-1)
assert sigma_t.ndim == 1, sigma_t.shape
assert sigma_s.ndim == 1, sigma_s.shape
# compute masking probabilities for each timestep
change_prob_t = 1 - torch.exp(-sigma_t)
change_prob_s = 1 - torch.exp(-sigma_s)
# expand dimensions
change_prob_t = change_prob_t[:, None, None]
change_prob_s = change_prob_s[:, None, None]
# get prodiction model that outputs token probabilities
log_p_x0 = self.forward(zt, attn_mask=attn_mask, sigma=sigma_t)
# check dimensions match
assert change_prob_t.ndim == log_p_x0.ndim
# compute reverse diffusion probability of being unmasked at timestep s
# (sigma_s - sigma_t)*x_theta
q_zs = log_p_x0.exp() * (change_prob_t - change_prob_s)
# compute reverse diffusion probability of remaining masked at timestep s
# (1 - sigma_s)*m
q_zs[:, :, self.mask_index] = change_prob_s[:, :, 0]
# sample sequence at timestep s from categorical distribution of q_zs
z_changed = _sample_categorical(q_zs)
copy_flag = (zt != self.mask_index).to(zt.dtype)
return (copy_flag * zt) + ((1 - copy_flag) * z_changed)"""
def cached_reverse_step(self, x, t, dt, p_x0=None, attn_mask=None):
    """
    Perform one reverse (unmasking) step under the log-linear noise schedule.

    With ``loglinear`` noise this code uses ``t`` itself as the masking
    probability at time ``t`` (and ``t - dt`` at the next time).

    Args:
        x: current token ids.
        t: current time; must be 1-D after squeezing a trailing dimension.
        dt: step size.
        p_x0: cached token probabilities from a previous step; recomputed
            with ``self.forward`` when None.
        attn_mask: attention mask passed to ``self.forward``.

    Returns:
        ``(p_x0, x_next)``: the probabilities used and the new tokens, where
        positions already unmasked in ``x`` are copied unchanged.
    """
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)
    times = t.squeeze(-1) if t.ndim > 1 else t
    assert times.ndim == 1
    prob_now = times[:, None, None]
    prob_next = (times - dt)[:, None, None]
    assert prob_now.ndim == 3, prob_now.shape
    if p_x0 is None:
        log_probs = self.forward(x, attn_mask=attn_mask, sigma=sigma_t)
        p_x0 = log_probs.exp()
    assert prob_now.ndim == p_x0.ndim
    # mass for unmasking into each token between t and t - dt
    q_xs = p_x0 * (prob_now - prob_next)
    # mass for staying masked
    q_xs[:, :, self.mask_index] = prob_next[:, :, 0]
    proposal = _sample_categorical(q_xs)
    keep = (x != self.mask_index).to(x.dtype)
    return p_x0, keep * x + (1 - keep) * proposal
# first step in expansion
def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
    """
    Generate ``batch_size`` candidate children from a single parent sequence.

    Used for the first expansion step of MCTS. The parent's reverse-step
    distribution is computed once, and ``batch_size`` samples are drawn from it.

    Args:
        token_array: parent token ids, shape ``(L,)`` or ``(1, L)``.
        t: current time; must be 1-D after squeezing a trailing dimension.
        dt: step size.
        batch_size: number of children to sample.
        p_x0: cached token probabilities; recomputed with ``self.forward``
            when None.
        attn_mask: attention mask for the forward pass. When None (default),
            an all-ones mask matching ``token_array`` is used.

    Returns:
        ``(p_x0, children)``: ``children`` has ``batch_size`` rows, and
        positions unmasked in the parent are copied unchanged.
    """
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)
    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]
    assert change_prob_t.ndim == 3, change_prob_t.shape

    if token_array.dim() == 1:
        token_array = token_array.unsqueeze(0)
    # Respect a caller-supplied mask; only synthesize an all-ones mask when
    # none was given (previously the argument was always overwritten).
    if attn_mask is None:
        attn_mask = torch.ones_like(token_array)
    attn_mask = attn_mask.to(self.device)

    if p_x0 is None:
        p_x0 = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t).exp()
    assert change_prob_t.ndim == p_x0.ndim
    q_xs = p_x0 * (change_prob_t - change_prob_s)
    # zero-masking probability
    q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]

    # repeat the parent so each row can be unmasked into a distinct child
    token_array = token_array.repeat(batch_size, 1)
    q_xs = q_xs.to(self.device)
    if self.config.mcts.sampling == 0:
        x_changed = sample_batched_categorical(q_xs, batch_size)
    else:
        x_changed = sample_batched_top_k(q_xs, batch_size, self.config.mcts.sampling)
    copy_flag = (token_array != self.mask_index).to(token_array.dtype)
    return p_x0, copy_flag * token_array + (1 - copy_flag) * x_changed
def _process_sigma(self, sigma):
if sigma.ndim > 1:
sigma = sigma.squeeze(-1)
if not self.time_conditioning:
sigma = torch.zeros_like(sigma)
assert sigma.ndim == 1, sigma.shape
return sigma
def forward(self, zt, attn_mask, sigma):
    """
    Return SUBS-parameterized token log-probabilities for ``zt``.

    ``sigma`` is passed through ``_process_sigma`` (which validates its
    shape), but the result is not used. The backbone is called with
    ``(zt, attn_mask)`` only.
    """
    self._process_sigma(sigma)
    with torch.cuda.amp.autocast(dtype=torch.float32):
        raw_logits = self.backbone(zt, attn_mask)
        raw_logits = raw_logits.to(self.device)
    return self.subs_parameterization(raw_logits, zt)
def subs_parameterization(self, logits, zt):
    """
    Apply the SUBS parameterization to raw backbone logits.

    - Zero masking probability: ``self.neg_infinity`` is added to the mask
      token's logit, so the mask token is never predicted.
    - Carry-over unmasking: at positions where ``zt`` is already unmasked,
      the output is a point mass on the current token (log-prob 0 for that
      token, ``self.neg_infinity`` for every other token).

    Args:
        logits: raw backbone logits indexed as ``[batch, position, vocab]``.
            The mask-token column is modified in place.
        zt: partially masked token ids at the current timestep, indexed as
            ``[batch, position]``.

    Returns:
        Normalized log-probabilities with the same shape as ``logits``.
    """
    logits[:, :, self.mask_index] += self.neg_infinity
    # normalize to log-probabilities over the vocabulary
    logits = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).to(self.device)
    # boolean (batch, position) mask of tokens that are already decoded
    unmasked = (zt != self.mask_index).to(self.device)
    logits[unmasked] = self.neg_infinity
    logits[unmasked, zt[unmasked]] = 0
    return logits.to(self.device)
"""SAMPLING"""
@torch.no_grad()
def _sample(self, num_steps=None, eps=1e-5, x_input=None):
"""
Generate samples
"""
batch_size_per_gpu = self.config.eval.perplexity_batch_size
if num_steps is None:
num_steps = self.config.sampling.steps
if x_input is not None:
x = x_input['input_ids'].to(self.device)
attn_mask = x_input['attention_mask'].to(self.device)
else:
x = self.sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
attn_mask = torch.ones_like(x).to(self.device)
timesteps = torch.linspace(1, eps, num_steps+1, device=self.device)
dt = (1 - eps) / num_steps
p_x0_cache = None
generation_history = [] # used to track which tokens are unmasked
for i in range(num_steps):
t = timesteps[i] * torch.ones(x.shape[0], 1, device = self.device)
if self.sampler == 'ddpm':
x = self.single_reverse_step(x, t, dt).to(self.device)
elif self.sampler == 'ddpm_cache':
p_x0_cache, x_next = self.cached_reverse_step(x, t, dt, p_x0=p_x0_cache, attn_mask=attn_mask)
if (not torch.allclose(x_next, x) or self.time_conditioning):
# Disable caching
p_x0_cache = None
x = x_next.to(self.device)
#print(self.tokenizer.decode(x.squeeze()))
else:
x = self._analytic_update(x, t, dt, attn_mask).to(self.device)
if self.config.sampling.noise_removal:
t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
if self.sampler == 'analytic':
x = self._denoiser_update(x, t).to(self.device)
else:
time_conditioning = self.noise(t)[0].to(self.device)
x = self.forward(x, attn_mask=attn_mask, sigma=time_conditioning).argmax(dim=-1).to(self.device)
#print(self.tokenizer.decode(x.squeeze()))
return x.to(self.device)
def restore_model_and_sample(self, num_steps, eps=1e-5):
    """
    Generate samples with the backbone and noise modules in eval mode.

    Both modules are switched to eval mode for sampling and put back into
    train mode afterwards. This also happens when sampling raises, so an
    exception cannot leave the model stuck in eval mode.

    Args:
        num_steps: number of reverse steps passed to ``_sample``.
        eps: final time value passed to ``_sample``.

    Returns:
        The samples returned by ``_sample``.
    """
    self.backbone.eval()
    self.noise.eval()
    try:
        return self._sample(num_steps=num_steps, eps=eps)
    finally:
        self.backbone.train()
        self.noise.train()
def get_score(self, zt, sigma, attn_mask=None):
    """
    Return the score ratios p_t(y) / p_t(x) for every position and vocabulary entry.

    Args:
        zt: current token ids.
        sigma: noise level passed to ``self.forward``; ``-log(expm1(sigma))``
            must be 1-D after squeezing its last dimension (asserted below).
        attn_mask: attention mask forwarded to ``self.forward``.

    Returns:
        Tensor shaped like the ``self.forward`` output, holding the
        exponentiated log-scores computed below.
    """
    # score(x, t) = p_t(y) / p_t(x)
    # => log score(x, t) = log p_t(y) - log p_t(x)
    # case 1: x = masked
    #   (i) y = unmasked
    #     log score(x, t) = log p_\theta(x)|_y + log k
    #     where k = exp(- sigma) / (1 - exp(- sigma))
    #   (ii) y = masked
    #     log score(x, t) = 0
    # case 2: x = unmasked
    #   (i) y != masked, y != x
    #     log score(x_i, t) = - inf
    #   (ii) y = x
    #     log score(x_i, t) = 0
    #   (iii) y = masked token
    #     log score(x_i, t) = - log k
    #     where k = exp(- sigma) / (1 - exp(- sigma))
    model_output = self.forward(zt, attn_mask=attn_mask, sigma=sigma)
    # k = exp(-sigma) / (1 - exp(-sigma)) = 1 / expm1(sigma), so log k = -log(expm1(sigma))
    log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
    assert log_k.ndim == 1
    # case 1 (positions currently masked): log p_theta + log k, and 0 for the mask token
    masked_score = model_output + log_k[:, None, None]
    masked_score[:, :, self.mask_index] = 0
    # case 2 (positions currently unmasked): -inf for every other token ...
    unmasked_score = self.neg_infinity * torch.ones_like(model_output)
    # ... 0 at the current token (y = x) ...
    unmasked_score = torch.scatter(
        unmasked_score, -1,
        zt[..., None],
        torch.zeros_like(unmasked_score[..., :1]))
    # ... and -log k for moving to the mask token
    unmasked_score[:, :, self.mask_index] = - (log_k[:, None] * torch.ones_like(zt))
    # choose case 1 or case 2 per position
    masked_indices = (zt == self.mask_index).to(model_output.dtype)[:, :, None]
    model_output = (masked_score * masked_indices + unmasked_score * (1 - masked_indices))
    return model_output.exp()
def _staggered_score(self, score, dsigma):
score = score.clone()
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
score *= dsigma.exp()[:, None]
score[..., self.mask_index] += extra_const
return score
def _analytic_update(self, x, t, step_size, attn_mask=None):
    """
    Perform one analytic reverse step from time ``t`` to ``t - step_size``.

    Args:
        x: current token ids.
        t: current time, passed to ``self.noise``.
        step_size: time decrement for this step.
        attn_mask: attention mask forwarded to ``get_score``.

    Returns:
        Token ids sampled from the staggered-score transition probabilities.
    """
    curr_sigma, _ = self.noise(t)
    next_sigma, _ = self.noise(t - step_size)
    dsigma = curr_sigma - next_sigma
    # get_score's signature is (zt, sigma, attn_mask). The mask was previously
    # passed in the sigma slot; pass it by keyword so they cannot be swapped.
    score = self.get_score(x, curr_sigma, attn_mask=attn_mask)
    stag_score = self._staggered_score(score, dsigma)
    probs = stag_score * self._transp_transition(x, dsigma)
    return _sample_categorical(probs)
def _denoiser_update(self, x, t):
    """
    Perform the final denoising step at time ``t``; the mask token is never sampled.

    The staggered score is built with the full noise level ``sigma`` at
    ``t``, multiplied by the transposed transition, and the mask-token
    probability is zeroed before sampling.
    """
    sigma, _unused = self.noise(t)
    staggered = self._staggered_score(self.get_score(x, sigma), sigma)
    probs = staggered * self._transp_transition(x, sigma)
    # forbid the mask token in the final output
    probs[..., self.mask_index] = 0
    return _sample_categorical(probs)
def _transp_transition(self, i, sigma):
    """
    Build transposed forward-transition weights for token ids ``i``.

    The result is ``exp(-sigma)`` times ``one_hot(i)``. Rows whose token is
    the mask token additionally get ``1 - exp(-sigma)`` added to every
    vocabulary entry.
    """
    # append trailing singleton dims so sigma broadcasts against i[..., None]
    extra_dims = (i.ndim + 1) - sigma.ndim
    sigma = sigma.view(*sigma.shape, *([1] * extra_dims))
    keep = torch.exp(-sigma)
    edge = keep * F.one_hot(i, num_classes=self.vocab_size)
    is_mask = i == self.mask_index
    leak = torch.where(is_mask, 1 - keep.squeeze(-1), 0)
    edge += leak[..., None]
    return edge
"""TRAINING from https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py"""
def on_train_epoch_start(self):
    """Release cached CUDA memory and put the backbone and noise modules in train mode."""
    torch.cuda.empty_cache()
    for module in (self.backbone, self.noise):
        module.train()
def training_step(self, batch, batch_idx):
    """
    Compute the training loss for ``batch`` and log it with a throughput estimate.

    Args:
        batch: mapping with ``'input_ids'``. When ``config.vocab`` is
            ``'old_smiles'`` or ``'new_smiles'`` it must also contain
            ``'bond_mask'``.
        batch_idx: index of the batch (unused here).

    Returns:
        The scalar loss returned by ``_compute_loss``.

    The logged ``trainer/throughput`` is ``input_ids.numel()`` divided by the
    time spent inside this method only.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump with
    # wall-clock adjustments and is coarse on some platforms.
    start_time = time.perf_counter()
    if self.config.vocab in ('old_smiles', 'new_smiles'):
        loss = self._compute_loss(batch, prefix='train', bond_mask=batch['bond_mask'])
    else:
        loss = self._compute_loss(batch, prefix='train')
    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)

    # clamp so a zero-length interval cannot raise ZeroDivisionError
    elapsed_time = max(time.perf_counter() - start_time, 1e-9)
    total_tokens = batch['input_ids'].numel()
    throughput = total_tokens / elapsed_time
    self.log(name='trainer/throughput',
             value=throughput,
             on_step=True,
             on_epoch=False,
             sync_dist=True)
    return loss
def on_load_checkpoint(self, checkpoint):
    """Record the completed epoch and batch counts stored in the checkpoint's fit loop."""
    fit_state = checkpoint['loops']['fit_loop']
    self.fast_forward_epochs = fit_state['epoch_progress']['current']['completed']
    self.fast_forward_batches = fit_state['epoch_loop.batch_progress']['current']['completed']
### VALIDATION ###
def on_validation_epoch_start(self):
    """
    Prepare for validation.

    Frees memory, switches the backbone and noise modules to eval mode, and
    checks that the validation NLL metric starts empty.
    """
    gc.collect()
    torch.cuda.empty_cache()
    for module in (self.backbone, self.noise):
        module.eval()
    nll_metric = self.valid_metrics.nll
    assert nll_metric.mean_value == 0
    assert nll_metric.weight == 0
def validation_step(self, batch, batch_idx):
    """Compute the validation loss for ``batch``, log it per step, and return it."""
    uses_bonds = self.config.vocab in ('old_smiles', 'new_smiles')
    extra = {'bond_mask': batch['bond_mask']} if uses_bonds else {}
    loss = self._compute_loss(batch, prefix='val', **extra)
    self.log(
        name='trainer/val_loss',
        value=loss.item(),
        on_step=True,
        on_epoch=False,
        prog_bar=True,
        sync_dist=True,
    )
    return loss
def on_validation_epoch_end(self):
    """Run the garbage collector, then release cached CUDA memory."""
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
### OPTIMIZATION ###
def optimizer_step(self, *args, **kwargs):
    """Delegate to the parent optimizer step, then free Python garbage and cached CUDA memory."""
    super().optimizer_step(*args, **kwargs)
    # reclaim memory after each optimizer step
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
def configure_optimizers(self):
    """
    Build AdamW over the backbone and noise parameters with a CosineWarmup schedule.

    Side effect: stores ``config.trainer.max_steps`` in ``self.total_steps``.

    Returns:
        ``([optimizer], [scheduler_dict])``. The scheduler dict is configured
        with ``interval='step'`` and ``frequency=1``.
    """
    optim_cfg = self.config.optim
    params = itertools.chain(self.backbone.parameters(), self.noise.parameters())
    optimizer = torch.optim.AdamW(
        params,
        lr=optim_cfg.lr,
        betas=(optim_cfg.beta1, optim_cfg.beta2),
        eps=optim_cfg.eps,
        weight_decay=optim_cfg.weight_decay,
    )
    self.total_steps = self.config.trainer.max_steps
    scheduler = CosineWarmup(
        optimizer,
        warmup_steps=self.config.lr_scheduler.num_warmup_steps,
        total_steps=self.total_steps,
    )
    return [optimizer], [{
        'scheduler': scheduler,
        'interval': 'step',
        'frequency': 1,
        'monitor': 'val/loss',
        'name': 'trainer/lr',
    }]
@torch.no_grad()
def compute_masked_perplexity(self, generated_ids, input_ids):
    """
    Compute a masked pseudo-perplexity of ``generated_ids`` given ``input_ids``.

    The backbone is run once on ``input_ids``. For each sequence in
    ``generated_ids``, the cross-entropy of its tokens is summed over the
    positions where ``input_ids`` holds ``self.mask_index``. The summed NLL
    is divided by the number of non-padding tokens in ``input_ids`` (counted
    once per generated sequence, BOS/EOS included) and exponentiated.

    Args:
        generated_ids: iterable of token-id sequences, each with as many
            tokens as ``input_ids``.
        input_ids: the masked input sequence fed to the backbone.

    Returns:
        The pseudo-perplexity as a Python float. It is also passed to
        ``self.gen_ppl_metric.update``.

    Raises:
        ValueError: if ``self.config.mode`` is not ``'train'``,
            ``'ppl_eval'`` or ``'sample_eval'``.
    """
    input_ids = torch.tensor(input_ids).to(self.device)
    attn_mask = torch.ones_like(input_ids).to(self.device)

    # The backbone only sees ``input_ids``, so its logits do not depend on the
    # generated sequence; compute them once instead of once per sequence.
    # NOTE(review): identical to the old per-sequence forward only if the
    # backbone's forward is deterministic (e.g. eval mode / no dropout).
    if self.config.mode in ['train', 'ppl_eval']:
        outputs = self.backbone.forward(input_ids=input_ids, attn_mask=attn_mask)
    elif self.config.mode == 'sample_eval':
        outputs = self.backbone.forward(input_ids=input_ids)
    else:
        raise ValueError(f'Unsupported mode for masked perplexity: {self.config.mode}')
    # flatten to (num_positions, vocab_size) for cross_entropy
    logits = outputs.view(-1, outputs.size(-1))

    # only masked positions contribute; all others get ignore_index -100
    is_masked = input_ids == self.mask_index
    # non-padding tokens (including bos/eos), identical for every sequence
    tokens_per_sequence = input_ids.ne(self.tokenizer.pad_token_id).sum().item()

    total_nll = 0
    total_tokens = 0
    for sequence in generated_ids:
        gt_ids = torch.tensor(sequence).to(self.device).view(-1)
        labels = gt_ids.where(is_masked, torch.full_like(gt_ids, -100)).view(-1)
        loss = F.cross_entropy(logits, labels, reduction='sum')
        total_nll += loss.item()
        total_tokens += tokens_per_sequence

    pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
    self.gen_ppl_metric.update(pseudo_perplexity)
    return pseudo_perplexity.item()
def unsqueeze(x, reference):
    """Return a view of ``x`` with trailing singleton dims appended up to ``reference``'s rank."""
    pad = len(reference.shape) - len(x.shape)
    trailing = (1,) * pad
    return x.view(*x.shape, *trailing)
class CosineWarmup(_LRScheduler):
    """
    Linear warmup followed by cosine decay down to ``eta_ratio * base_lr``.

    For step ``s`` (``self.last_epoch``):

    * ``s < warmup_steps``: ``lr = base_lr * s / warmup_steps``.
    * otherwise: ``lr = base_lr * (eta_ratio + (1 - eta_ratio) * 0.5 * (1 + cos(pi * p)))``,
      with ``p = (s - warmup_steps) / (total_steps - warmup_steps)`` clamped to
      at most 1. The learning rate therefore stays at its floor after
      ``total_steps`` instead of climbing back up along the cosine.

    Args:
        optimizer: wrapped optimizer.
        warmup_steps: number of linear warmup steps.
        total_steps: step at which the cosine decay reaches its floor.
        eta_ratio: ratio of the minimum to the maximum learning rate.
        last_epoch: index of the last step, as in ``_LRScheduler``.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_ratio = eta_ratio  # ratio of minimum to maximum learning rate
        super(CosineWarmup, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]
        # guard against total_steps <= warmup_steps (zero or negative span)
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # clamp so the cosine does not wrap around past total_steps
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
        return [decayed_lr * base_lr for base_lr in self.base_lrs]