File size: 3,507 Bytes
89b6f9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
#!/usr/bin/env python3
import torch
import fire
import json
from pathlib import Path
import sys
from nGPT_pytorch import nGPT
def exists(v):
    """Return True when ``v`` is anything other than ``None``."""
    if v is None:
        return False
    return True
def decode_token(token):
    """Turn a single token id into its one-character string.

    Ids below 32 are raised to 32, so they are rendered as a space.
    """
    lowest_code = 32
    code_point = max(lowest_code, token)
    return chr(code_point)
def decode_tokens(tokens):
    """Decode an iterable of token ids into one string via ``decode_token``."""
    return "".join(decode_token(token) for token in tokens)
def log(t, eps=1e-20):
    """Natural log of ``t`` after clamping every element to at least ``eps``.

    The clamp keeps zero (or negative) entries from producing ``-inf``/``nan``.
    """
    floored = torch.clamp(t, min=eps)
    return floored.log()
def gumbel_noise(t):
    """Draw Gumbel noise with the same shape, dtype and device as ``t``.

    Uniform samples ``u`` are transformed as ``-log(-log(u))``; the clamped
    ``log`` helper guards against taking the log of zero.
    """
    uniform = torch.zeros_like(t).uniform_(0, 1)
    neg_log_uniform = -log(uniform)
    return -log(neg_log_uniform)
def gumbel_sample(t, temperature=1.0, dim=-1, keepdim=True):
    """Sample indices from logits ``t`` by adding Gumbel noise and taking argmax.

    Args:
        t: Tensor of logits.
        temperature: Divisor applied to the logits; floored at 1e-10 so a
            temperature of zero does not divide by zero.
        dim: Dimension over which the argmax is taken.
        keepdim: Whether the reduced dimension is kept in the result.

    Returns:
        Tensor of sampled indices along ``dim``.
    """
    safe_temperature = max(temperature, 1e-10)
    perturbed = (t / safe_temperature) + gumbel_noise(t)
    return perturbed.argmax(dim=dim, keepdim=keepdim)
def min_p_filter(logits, min_p=0.1):
    """Mask out logits whose probability is below ``min_p`` times the top probability.

    Args:
        logits: Tensor of logits; softmax is taken over the last dimension.
        min_p: Fraction of the maximum probability used as the cutoff.

    Returns:
        A tensor like ``logits`` with the rejected entries set to ``-inf``.
    """
    probabilities = logits.softmax(dim=-1)
    cutoff = min_p * probabilities.amax(dim=-1, keepdim=True)
    rejected = probabilities < cutoff
    return logits.masked_fill(rejected, float("-inf"))
def base_decoding(
    net,
    prompt: torch.Tensor,
    seq_len: int,
    temperature=1.5,
    min_p=1e-1,
    filter_thres=0.9,
):
    """Autoregressively extend ``prompt`` until it is ``seq_len`` tokens long.

    Args:
        net: Callable that maps the token tensor to logits. The result is
            indexed as ``[:, -1]``, so it is presumably (batch, seq, vocab)
            -- confirm against the model.
        prompt: Token ids; the last dimension is the sequence dimension.
        seq_len: Target TOTAL length (prompt plus generated tokens). Nothing
            is generated when the prompt is already at least this long.
        temperature: Sampling temperature forwarded to ``gumbel_sample``.
        min_p: Cutoff forwarded to ``min_p_filter``.
        filter_thres: Accepted but not used by this function.

    Returns:
        Only the newly generated tokens (the prompt prefix is sliced off).
    """
    prompt_seq_len = prompt.shape[-1]
    out = prompt.clone()

    steps_remaining = max(0, seq_len - prompt_seq_len)
    while steps_remaining > 0:
        # Only the prediction for the final position is needed.
        last_logits = net(out)[:, -1]
        filtered = min_p_filter(last_logits, min_p=min_p)
        next_token = gumbel_sample(filtered, temperature=temperature, dim=-1)
        out = torch.cat((out, next_token), dim=-1)
        steps_remaining -= 1

    return out[..., prompt_seq_len:]
def main(
    checkpoint_path: str,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    min_p: float = 0.1,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Generate text using a trained nGPT model.

    Args:
        checkpoint_path: Path to a checkpoint containing ``model_state_dict``
            and, optionally, a ``config`` dict. If the checkpoint has no
            config, a ``config.json`` next to it is used when present.
        prompt: Text to condition on. Each character is encoded as its code
            point, so every character must be below 256 (the vocabulary size).
        max_new_tokens: Number of tokens to generate after the prompt.
        temperature: Sampling temperature.
        min_p: Min-p filtering cutoff.
        device: Torch device string the model and tensors are placed on.

    Returns:
        The generated continuation (without the prompt) as a string.

    Exits with status 1 when the checkpoint is missing or the prompt is
    empty or contains characters outside the vocabulary.
    """
    num_tokens = 256

    # Load checkpoint
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        sys.exit(1)

    # Fire parses command-line values as Python literals, so a prompt such as
    # `123` arrives as an int; normalise it before iterating over characters.
    prompt = str(prompt)

    # Validate the prompt before the (slow) checkpoint load so bad input fails
    # fast with a readable message instead of an indexing error in the model.
    if not prompt:
        print("Error: Prompt must not be empty")
        sys.exit(1)
    if any(ord(c) >= num_tokens for c in prompt):
        print(
            f"Error: Prompt contains characters outside the {num_tokens}-token "
            "vocabulary"
        )
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # Get config from checkpoint or file (`or {}` also covers a stored None)
    config = checkpoint.get("config") or {}
    config_file = checkpoint_path.parent / "config.json"
    if not config and config_file.exists():
        with open(config_file, encoding="utf-8") as f:
            config = json.load(f)
    use_parametrize = config.get("use_parametrize", True)

    # Initialize model
    model = nGPT(
        num_tokens=num_tokens,
        dim=512,
        depth=8,
        tied_embedding=True,
        add_value_residual=True,
        attn_norm_qk=False,
        manual_norm_weights=not use_parametrize,
    ).to(device)

    # Load weights
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    print("\nModel loaded successfully. Generating with:")
    print(f" Temperature: {temperature}")
    print(f" Min-p: {min_p}")
    print(f" Max new tokens: {max_new_tokens}")

    # Convert prompt to tensor with a leading batch dimension of 1
    prompt_tensor = torch.tensor(
        [ord(c) for c in prompt], dtype=torch.long, device=device
    ).unsqueeze(0)

    # Generate. `base_decoding` takes the TOTAL target length and subtracts
    # the prompt length itself, so add the prompt length here to really get
    # `max_new_tokens` new tokens.
    total_len = prompt_tensor.shape[-1] + max_new_tokens
    with torch.no_grad():
        sampled = base_decoding(
            model,
            prompt_tensor,
            seq_len=total_len,
            temperature=temperature,
            min_p=min_p,
        )

    # One bulk transfer to Python ints instead of a per-element device read.
    generated = decode_tokens(sampled[0].tolist())

    print("\nGenerated text:")
    print("-" * 80)
    print(prompt + generated)
    print("-" * 80)
    return generated
if __name__ == "__main__":
    # Expose `main` as a command-line interface when run as a script.
    fire.Fire(main)
|