# Direct reward backpropagation: fine-tune a masked discrete diffusion model
# (MDLM) on the Gosai enhancer task, optionally guiding training with MCTS
# (see finetune_dna.finetune and mcts.MCTS).
import argparse
import datetime
import os

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from scipy.stats import pearsonr

from diffusion import Diffusion
from finetune_dna import finetune
from mcts import MCTS
import oracle
from utils import str2bool, set_seed

argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
argparser.add_argument('--base_path', type=str, default="")
argparser.add_argument('--learning_rate', type=float, default=1e-4)
argparser.add_argument('--num_epochs', type=int, default=100)
argparser.add_argument('--num_accum_steps', type=int, default=4)
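# Backpropagate the reward through only the last `truncate_steps` denoising
# steps (and, if truncate_kl, truncate the KL term the same way); the actual
# truncation is handled in finetune_dna.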
argparser.add_argument('--truncate_steps', type=int, default=50)
argparser.add_argument("--truncate_kl", type=str2bool, default=False)
argparser.add_argument('--gumbel_temp', type=float, default=1.0)
argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
argparser.add_argument('--batch_size', type=int, default=32)
argparser.add_argument('--name', type=str, default='debug')
argparser.add_argument('--total_num_steps', type=int, default=128)
argparser.add_argument('--copy_flag_temp', type=float, default=None)
argparser.add_argument('--save_every_n_epochs', type=int, default=10)
argparser.add_argument('--eval_every_n_epochs', type=int, default=200)
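# alpha weights the KL penalty keeping the policy close to the pretrained
# model; alpha_schedule_warmup presumably ramps it in over the first steps.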
argparser.add_argument('--alpha', type=float, default=0.001)
argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
argparser.add_argument("--seed", type=int, default=0)
# new
argparser.add_argument('--run_name', type=str, default='drakes')
argparser.add_argument("--device", default="cuda:0", type=str)
argparser.add_argument("--save_path_dir", default=None, type=str)
argparser.add_argument("--no_mcts", action='store_true', default=False)
argparser.add_argument("--centering", action='store_true', default=False)
argparser.add_argument("--reward_clip", action='store_true', default=False)
argparser.add_argument("--reward_clip_value", type=float, default=15.0)
argparser.add_argument("--select_topk", action='store_true', default=False)
argparser.add_argument('--select_topk_value', type=int, default=10)
argparser.add_argument("--restart_ckpt_path", type=str, default=None)
# mcts
argparser.add_argument('--num_sequences', type=int, default=10)
argparser.add_argument('--num_children', type=int, default=50)
argparser.add_argument('--num_iter', type=int, default=30) # iterations of mcts
argparser.add_argument('--seq_length', type=int, default=200)
argparser.add_argument('--time_conditioning', action='store_true', default=False)
argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
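# For option 0, sampling amounts to the Gumbel-max trick, roughly (the real
# implementation lives in mcts.py):
#   g = -torch.log(-torch.log(torch.rand_like(logits)))  # Gumbel(0, 1) noise
#   tokens = torch.argmax(logits + g, dim=-1)            # one draw per position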
argparser.add_argument('--buffer_size', type=int, default=100)
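# Number of noised replicates per buffered sequence used by the weighted
# denoising cross-entropy (WDCE) loss (see finetune_dna).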
argparser.add_argument('--wdce_num_replicates', type=int, default=16)
argparser.add_argument('--noise_removal', action='store_true', default=False)
argparser.add_argument('--grad_clip', action='store_true', default=False)
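# Search scheduling: rollouts are resampled every resample_every_n_step steps,
# exploration is likely the UCT-style exploration coefficient, and reset_tree
# rebuilds the search tree between refreshes (semantics per mcts.py).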
argparser.add_argument('--resample_every_n_step', type=int, default=10)
argparser.add_argument('--exploration', type=float, default=0.1)
argparser.add_argument('--reset_tree', action='store_true', default=False)
# eval
args = argparser.parse_args()
print(args)
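# Example invocation (hypothetical filename and paths, for illustration):
#   python finetune_mcts_dna.py --base_path /path/to/data \
#       --save_path_dir /path/to/results --alpha 0.001 --num_iter 30 --centering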
# pretrained model path
CKPT_PATH = os.path.join(args.base_path, 'mdlm/outputs_gosai/pretrained.ckpt')
log_base_dir = os.path.join(args.save_path_dir, 'mdlm/reward_bp_results_final')
# reinitialize Hydra
GlobalHydra.instance().clear()
# Initialize Hydra and compose the configuration
initialize(config_path="configs_gosai", job_name="load_model")
cfg = compose(config_name="config_gosai.yaml")
cfg.eval.checkpoint_path = CKPT_PATH
curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
if args.no_mcts:
    run_name = f'MDNS_buffer{args.buffer_size}_alpha{args.alpha}_resample{args.resample_every_n_step}_centering{args.centering}_{curr_time}'
else:
    run_name = f'MCTS_buffer{args.buffer_size}_alpha{args.alpha}_resample{args.resample_every_n_step}_num_iter{args.num_iter}_centering{args.centering}_select_topk{args.select_topk}_select_topk_value{args.select_topk_value}_{curr_time}'
args.save_path = os.path.join(args.save_path_dir, run_name)
os.makedirs(args.save_path, exist_ok=True)
# wandb init
wandb.init(project='search-rl', name=run_name, config=args, dir=args.save_path)
log_path = os.path.join(args.save_path, 'log.txt')
set_seed(args.seed, use_cuda=True)
# Initialize the model
if args.restart_ckpt_path is not None:
    # Resume from a saved ckpt; recover the epoch number from the filename.
    restart_ckpt_path = os.path.join(args.base_path, args.restart_ckpt_path)
    restart_epoch = restart_ckpt_path.split('_')[-1].split('.')[0]
    args.restart_epoch = restart_epoch
    policy_model = Diffusion(cfg).to(args.device)
    policy_model.load_state_dict(torch.load(restart_ckpt_path, map_location=args.device))
else:
    # Start from the pretrained model.
    policy_model = Diffusion.load_from_checkpoint(cfg.eval.checkpoint_path, config=cfg, map_location=args.device)
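# Frozen copy of the pretrained model, passed to MCTS and finetune() as the
# reference model (e.g., for the KL term controlled by --alpha).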
pretrained = Diffusion.load_from_checkpoint(cfg.eval.checkpoint_path, config=cfg, map_location=args.device)
reward_model = oracle.get_gosai_oracle(mode='train', device=args.device)
#reward_model_eval = oracle.get_gosai_oracle(mode='eval').to(args.device)
reward_model.eval()
pretrained.eval()
#reward_model_eval.eval()
# Build the MCTS sampler around the trainable policy, the frozen pretrained
# reference model, and the reward oracle.
mcts = MCTS(args, cfg, policy_model, pretrained, reward_model)
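# Evaluation-only assets: high-expression k-mer statistics (for a k-mer
# frequency correlation metric; cf. the pearsonr import), an ATAC
# accessibility predictor, and a separately trained Gosai oracle (mode='eval').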
_, _, highexp_kmers_999, n_highexp_kmers_999, _, _, _ = oracle.cal_highexp_kmers(return_clss=True)
cal_atac_pred_new_mdl = oracle.get_cal_atac_orale(device=args.device)
cal_atac_pred_new_mdl.eval()
gosai_oracle = oracle.get_gosai_oracle(mode='eval', device=args.device)
gosai_oracle.eval()
print("args.device:", args.device)
print("policy_model device:", policy_model.device)
print("pretrained device:", pretrained.device)
print("reward_model device:", reward_model.device)
print("mcts device:", mcts.device)
print("gosai_oracle device:", gosai_oracle.device)
print("cal_atac_pred_new_mdl device:", cal_atac_pred_new_mdl.device)
eval_model_dict = {
    "gosai_oracle": gosai_oracle,
    "highexp_kmers_999": highexp_kmers_999,
    "n_highexp_kmers_999": n_highexp_kmers_999,
    "cal_atac_pred_new_mdl": cal_atac_pred_new_mdl,
}
finetune(args=args, cfg=cfg, policy_model=policy_model,
         reward_model=reward_model, mcts=mcts,
         pretrained_model=pretrained,
         eval_model_dict=eval_model_dict)