"""Gradio app that projects ESM2 embeddings of metal-binding residues to 2-D.

The user uploads a FASTA file and a CSV of target residues with columns
``seq_id``, ``residue_index`` (1-based) and ``metal``. The final-layer ESM2
embedding of each target residue is reduced with PCA or t-SNE and plotted,
coloured by metal type.
"""
#-----------------------------------------libraries--------------------------------------
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F  # noqa: F401  (unused here; kept from the original script)
from Bio import SeqIO
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from transformers import AutoModel, AutoTokenizer

# Load ESM2 model and tokenizer
MODEL_NAME = "facebook/esm2_t36_3B_UR50D"
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model.eval()
# NOTE: PyTorch grad mode is thread-local, so this only affects the importing
# thread. Gradio runs handlers in worker threads, which is why inference in
# _embed_sequence is additionally wrapped in torch.no_grad().
torch.set_grad_enabled(False)

# Columns of the target-residue CSV that the code reads (other columns, such
# as aa_type, are allowed and ignored).
REQUIRED_COLUMNS = ("seq_id", "residue_index", "metal")


def _read_upload_text(upload, description):
    """Return the decoded text of a ``gr.File`` upload.

    Gradio versions differ in what they hand to the function: a filepath
    string, a tempfile wrapper exposing ``.name``, or a readable binary file
    object. All three are accepted.

    Raises:
        ValueError: if nothing was uploaded.
    """
    if upload is None:
        raise ValueError(f"Please upload a {description} file.")
    if isinstance(upload, (str, os.PathLike)):
        with open(upload, "rb") as handle:
            raw = handle.read()
    elif isinstance(getattr(upload, "name", None), str) and os.path.isfile(upload.name):
        # Re-open by path: a wrapper's own cursor may already sit at EOF.
        with open(upload.name, "rb") as handle:
            raw = handle.read()
    else:
        if hasattr(upload, "seek"):
            upload.seek(0)
        raw = upload.read()
    # utf-8-sig also strips a BOM (e.g. from spreadsheet exports), which would
    # otherwise corrupt the first CSV column name.
    return raw.decode("utf-8-sig") if isinstance(raw, bytes) else raw


def _collect_targets(csv_df, seq_dict):
    """Return valid ``(seq_id, residue_index, metal_label)`` tuples in CSV order.

    Rows are skipped when the ``seq_id`` is not in the FASTA, when the
    ``residue_index`` is missing/non-numeric, or when the index lies outside
    ``1..len(sequence)``. Missing metal values are labelled ``"unknown"``.
    """
    residue_indices = pd.to_numeric(csv_df["residue_index"], errors="coerce")
    targets = []
    for seq_id, res_idx, metal in zip(csv_df["seq_id"], residue_indices, csv_df["metal"]):
        if pd.isna(seq_id) or pd.isna(res_idx):
            continue
        seq_id = str(seq_id).strip()
        sequence = seq_dict.get(seq_id)
        if sequence is None:
            continue
        res_idx = int(res_idx)  # 1-based residue position
        if not 1 <= res_idx <= len(sequence):
            continue
        label = "unknown" if pd.isna(metal) else str(metal).strip()
        targets.append((seq_id, res_idx, label))
    return targets


def _embed_sequence(sequence):
    """Return the final-layer ESM2 hidden states of *sequence* as a 2-D array.

    Row ``i`` corresponds to token ``i`` of the tokenized input.
    """
    inputs = tokenizer(sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # final layer, batch item 0 -> (num_tokens, hidden_dim)
    return outputs.hidden_states[-1][0].cpu().numpy()


def _embed_targets(targets, seq_dict):
    """Return one embedding vector per target, in the same order as *targets*.

    Each distinct sequence is run through the model exactly once; only the
    requested residue rows are kept, so memory does not scale with the full
    per-token output of every sequence.
    """
    positions_by_seq = {}
    for seq_id, res_idx, _ in targets:
        positions_by_seq.setdefault(seq_id, set()).add(res_idx)

    selected = {}
    for seq_id, positions in positions_by_seq.items():
        token_states = _embed_sequence(seq_dict[seq_id])
        for pos in positions:
            # The tokenizer prepends a CLS token, so the 1-based residue
            # position is also the token row index.
            selected[(seq_id, pos)] = token_states[pos].copy()
    return [selected[(seq_id, res_idx)] for seq_id, res_idx, _ in targets]


def get_representation_vectors(fasta_file, target_csv_file, method="PCA"):
    """Embed target residues with ESM2 and plot a 2-D projection by metal type.

    Args:
        fasta_file: uploaded FASTA file (path or file object, as given by gr.File).
        target_csv_file: uploaded CSV with ``seq_id``, ``residue_index``
            (1-based) and ``metal`` columns.
        method: ``"t-SNE"`` for t-SNE; any other value uses PCA.

    Returns:
        The matplotlib Figure containing the scatter plot.

    Raises:
        ValueError: if an upload is missing, the CSV lacks required columns,
            or fewer than two valid target residues are found.
    """
    # Read and parse FASTA
    fasta_text = _read_upload_text(fasta_file, "FASTA")
    seq_dict = {
        rec.id: str(rec.seq) for rec in SeqIO.parse(io.StringIO(fasta_text), "fasta")
    }

    # Read CSV (expected: seq_id,residue_index,aa_type,metal). seq_id is read as
    # str so purely numeric IDs still match the FASTA record IDs.
    csv_text = _read_upload_text(target_csv_file, "target residue CSV")
    csv_df = pd.read_csv(io.StringIO(csv_text), dtype={"seq_id": str})
    missing = [col for col in REQUIRED_COLUMNS if col not in csv_df.columns]
    if missing:
        raise ValueError(
            f"Target residue CSV is missing required column(s): {', '.join(missing)}"
        )

    targets = _collect_targets(csv_df, seq_dict)
    if not targets:
        raise ValueError("No valid target residues found. Please check your input files.")
    if len(targets) < 2:
        raise ValueError(
            "At least two valid target residues are needed for a 2-D projection."
        )

    X = np.stack(_embed_targets(targets, seq_dict))
    metal_labels = [metal for _, _, metal in targets]

    # Perform dimensionality reduction
    if method == "t-SNE":
        # scikit-learn requires perplexity < n_samples; its default is 30.
        reducer = TSNE(n_components=2, perplexity=min(30.0, len(X) - 1), random_state=0)
    else:
        reducer = PCA(n_components=2)
    X_2d = reducer.fit_transform(X)

    # Plot; sorted labels keep legend order and colours stable between runs.
    fig, ax = plt.subplots(figsize=(8, 6))
    for metal in sorted(set(metal_labels)):
        idxs = [i for i, m in enumerate(metal_labels) if m == metal]
        ax.scatter(X_2d[idxs, 0], X_2d[idxs, 1], label=metal)
    ax.set_title(f"Target Residue Embeddings ({method})")
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    ax.legend()
    fig.tight_layout()
    # Drop the figure from pyplot's global registry so repeated requests do not
    # accumulate open figures; the Figure object itself stays renderable.
    plt.close(fig)
    return fig


# Gradio interface
demo = gr.Interface(
    fn=get_representation_vectors,
    inputs=[
        gr.File(label="FASTA File"),
        gr.File(label="Target Residue CSV"),
        gr.Radio(choices=["PCA", "t-SNE"], value="PCA", label="Dimensionality Reduction Method"),
    ],
    outputs=gr.Plot(label="Embedding Space"),
    title="Metal-Binding Residue Embedding Visualizer",
    description="Upload a FASTA file and a CSV file of metal-binding residues. This tool will visualize the high-dimensional embedding of the target residues using PCA or t-SNE.",
)

if __name__ == "__main__":
    demo.launch()