import gradio as gr
import numpy as np
import json
import pandas as pd
from openai import OpenAI
from typing import Optional
from topk_sae import FastAutoencoder
import torch
import plotly.express as px
from collections import Counter
from huggingface_hub import hf_hub_download
import os

print(os.getenv('MODEL_REPO_ID'))  # debug: confirm the Space's env vars are visible
# Constants
EMBEDDING_MODEL = "text-embedding-3-small"
d_model = 1536        # embedding dimension of text-embedding-3-small
n_dirs = d_model * 6  # number of SAE dictionary directions (9216)
k = 64                # active latents per sample (top-k sparsity)
auxk = 128            # auxiliary k (training-time parameter; unused at inference)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_grad_enabled(False)  # inference only
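# How the pieces below fit together (a sketch of the top-k SAE pipeline):
#   1. An abstract or query is embedded into R^d_model with the OpenAI API.
#   2. The SAE encoder scores all n_dirs dictionary directions and keeps only
#      the k largest activations (topk_indices, topk_values).
#   3. The decoder maps that sparse code back to an embedding, so editing the
#      activations ("interventions") yields a steered query vector.
# The exact encoder/decoder internals live in topk_sae.FastAutoencoder.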
# Download all necessary files from the Hugging Face Hub
def download_all_files():
    files_to_download = [
        "astroPH_paper_metadata.csv",
        "csLG_feature_analysis_results_64.json",
        "astroPH_topk_indices_64_9216_int32.npy",
        "astroPH_64_9216.pth",
        "astroPH_topk_values_64_9216_float16.npy",
        "csLG_abstract_texts.json",
        "csLG_topk_values_64_9216_float16.npy",
        "csLG_abstract_embeddings_float16.npy",
        "csLG_paper_metadata.csv",
        "csLG_64_9216.pth",
        "astroPH_abstract_texts.json",
        "astroPH_feature_analysis_results_64.json",
        "csLG_topk_indices_64_9216_int32.npy",
        "astroPH_abstract_embeddings_float16.npy",
    ]

    os.makedirs("data", exist_ok=True)
    for file in files_to_download:
        hf_hub_download(repo_id="charlieoneill/saerch-ai-data", filename=file, local_dir="data")
        print(f"Downloaded {file}")
# Fetch data, then initialize the OpenAI client from the environment
download_all_files()

api_key = os.getenv('openai_key')
if not api_key:
    raise ValueError("The environment variable 'openai_key' is not set.")

client = OpenAI(api_key=api_key)
# Load all artifacts (embeddings, texts, feature analyses, metadata, SAE weights) for one subject
def load_subject_data(subject):
    embeddings_path = f"data/{subject}_abstract_embeddings_float16.npy"
    texts_path = f"data/{subject}_abstract_texts.json"
    feature_analysis_path = f"data/{subject}_feature_analysis_results_{k}.json"
    metadata_path = f"data/{subject}_paper_metadata.csv"
    topk_indices_path = f"data/{subject}_topk_indices_{k}_{n_dirs}_int32.npy"
    topk_values_path = f"data/{subject}_topk_values_{k}_{n_dirs}_float16.npy"

    # Embeddings and activation values are stored as float16 on disk; convert to float32 for compute
    abstract_embeddings = np.load(embeddings_path).astype(np.float32)
    with open(texts_path, 'r') as f:
        abstract_texts = json.load(f)
    with open(feature_analysis_path, 'r') as f:
        feature_analysis = json.load(f)
    df_metadata = pd.read_csv(metadata_path)
    topk_indices = np.load(topk_indices_path)  # already int32, no conversion needed
    topk_values = np.load(topk_values_path).astype(np.float32)

    model_path = os.path.join("data", f"{subject}_64_9216.pth")
    weights = torch.load(model_path, map_location=device)

    ae = FastAutoencoder(n_dirs, d_model, k, auxk, multik=0).to(device)
    ae.load_state_dict(weights)
    ae.eval()

    # Keep a NumPy copy of the decoder matrix (d_model x n_dirs) for similarity analyses
    decoder = weights['decoder.weight'].cpu().numpy()
    del weights

    return {
        'abstract_embeddings': abstract_embeddings,
        'abstract_texts': abstract_texts,
        'feature_analysis': feature_analysis,
        'df_metadata': df_metadata,
        'topk_indices': topk_indices,
        'topk_values': topk_values,
        'ae': ae,
        'decoder': decoder,
    }
# Load data for both subjects up front
subject_data = {
    'astroPH': load_subject_data('astroPH'),
    'csLG': load_subject_data('csLG')
}
# Embed query text with the OpenAI API; returns None on failure so callers can bail out gracefully
def get_embedding(text: Optional[str], model: str = EMBEDDING_MODEL) -> Optional[np.ndarray]:
    try:
        embedding = client.embeddings.create(input=[text], model=model).data[0].embedding
        return np.array(embedding, dtype=np.float32)
    except Exception as e:
        print(f"Error getting embedding: {e}")
        return None
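# Note: text-embedding-3 vectors come back unit-normalized, so ranking documents
# by a plain inner product against these embeddings (as the search code below
# does) is equivalent to ranking by cosine similarity.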
# Decode (possibly user-edited) sparse activations back into an embedding
def intervened_hidden_to_intervened_embedding(topk_indices, topk_values, ae):
    with torch.no_grad():
        return ae.decode_sparse(topk_indices, topk_values)
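# decode_sparse is provided by topk_sae.FastAutoencoder; conceptually (an
# assumption about its internals, not verified here) it computes something like
#   embedding = sum_i topk_values[i] * decoder[:, topk_indices[i]] + pre_bias
# i.e. a weighted sum of the selected dictionary directions.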
# Feature activation, co-occurrence, and styling helpers
def get_feature_activations(subject, feature_index, m=5, min_length=100):
    abstract_texts = subject_data[subject]['abstract_texts']
    topk_indices = subject_data[subject]['topk_indices']
    topk_values = subject_data[subject]['topk_values']

    doc_ids = abstract_texts['doc_ids']
    abstracts = abstract_texts['abstracts']

    # Per-document activation of this feature (0 where the feature is not in the top k)
    feature_mask = topk_indices == feature_index
    activated_indices = np.where(feature_mask.any(axis=1))[0]
    activation_values = np.where(feature_mask, topk_values, 0).max(axis=1)

    sorted_activated_indices = activated_indices[np.argsort(-activation_values[activated_indices])]

    # Collect the m highest-activating abstracts, skipping very short ones
    top_m_abstracts = []
    for i in sorted_activated_indices:
        if len(abstracts[i]) > min_length:
            top_m_abstracts.append((doc_ids[i], abstracts[i], activation_values[i]))
        if len(top_m_abstracts) == m:
            break

    return top_m_abstracts
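# Returns a list of (doc_id, abstract_text, activation_value) tuples sorted by
# activation, e.g. [('2101.12345', 'Title\n\nAbstract...', 3.1416), ...]; the
# title is recovered later by splitting the abstract text on '\n\n'.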
def calculate_co_occurrences(subject, target_index, n_features=9216):
    topk_indices = subject_data[subject]['topk_indices']

    # Rows (documents) where the target feature is active
    mask = np.any(topk_indices == target_index, axis=1)
    co_occurring_indices = topk_indices[mask].flatten()

    co_occurrences = Counter(co_occurring_indices)
    del co_occurrences[target_index]

    result = np.zeros(n_features, dtype=int)
    result[list(co_occurrences.keys())] = list(co_occurrences.values())
    return result
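# Worked example: if feature 12 is active in 100 abstracts, and feature 7 also
# appears in the top-k set of 40 of those abstracts, then
# calculate_co_occurrences(subject, 12)[7] == 40.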
def style_dataframe(df: pd.DataFrame, is_top: bool) -> pd.DataFrame:
    cosine_values = df['Cosine similarity'].astype(float)
    min_val = cosine_values.min()
    max_val = cosine_values.max()
    val_range = max(max_val - min_val, 1e-9)  # avoid division by zero when all values are equal

    def color_similarity(val):
        val = float(val)
        # Normalize the value between 0 and 1 (reversed for the bottom-correlated table)
        if is_top:
            normalized_val = (val - min_val) / val_range
        else:
            normalized_val = (max_val - val) / val_range
        # Keep alpha in [0.2, 1.0] so even the weakest row stays visibly tinted
        color_intensity = 0.2 + (normalized_val * 0.8)
        if is_top:
            return f'background-color: rgba(0, 255, 0, {color_intensity:.2f})'
        return f'background-color: rgba(255, 0, 0, {color_intensity:.2f})'

    return df.style.applymap(color_similarity, subset=['Cosine similarity'])
def get_feature_from_index(subject, index):
    return next((f for f in subject_data[subject]['feature_analysis'] if f['index'] == index), None)
def visualize_feature(subject, index):
    feature = get_feature_from_index(subject, index)
    if feature is None:
        # Match the arity of the successful return (six values)
        return "Invalid feature index", None, None, None, None, None

    output = f"# {feature['label']}\n\n"
    output += f"* Pearson correlation: {feature['pearson_correlation']:.4f}\n\n"
    output += f"* Density: {feature['density']:.4f}\n\n"

    # Top m abstracts
    top_m_abstracts = get_feature_activations(subject, index)
    df_top_abstracts = pd.DataFrame([
        {"Title": m[1].split('\n\n')[0], "Activation value": f"{m[2]:.4f}"}
        for m in top_m_abstracts
    ])

    # Activation value distribution
    topk_indices = subject_data[subject]['topk_indices']
    topk_values = subject_data[subject]['topk_values']

    activation_values = np.where(topk_indices == index, topk_values, 0).max(axis=1)
    fig2 = px.histogram(x=activation_values, nbins=50)
    fig2.update_layout(
        xaxis_title='Activation value',
        yaxis_title=None,
        yaxis_type='log',
        height=220,
    )
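    # The histogram shows, per abstract, this feature's maximum activation (0 when
    # it never enters the top k). Activations are typically heavy-tailed, with most
    # documents near zero and a long tail of strong activations, hence the log y-axis.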
    # Correlated features: cosine similarity between this feature's decoder direction
    # and every other decoder direction. The argsort indices must refer to the *full*
    # decoder matrix; deleting the feature's column first would shift every index
    # after it by one.
    decoder = subject_data[subject]['decoder']
    feature_vector = decoder[:, index]
    cosine_similarities = (decoder.T @ feature_vector) / (np.linalg.norm(decoder, axis=0) * np.linalg.norm(feature_vector))

    order = np.argsort(-cosine_similarities)
    order = order[order != index]  # exclude the feature itself (self-similarity is 1)

    topk = 5
    topk_indices_cosine = order[:topk]
    topk_values_cosine = cosine_similarities[topk_indices_cosine]

    df_top_correlated = pd.DataFrame({
        "Feature": [get_feature_from_index(subject, i)['label'] for i in topk_indices_cosine],
        "Cosine similarity": [f"{v:.4f}" for v in topk_values_cosine]
    })
    df_top_correlated_styled = style_dataframe(df_top_correlated, is_top=True)

    bottomk = 5
    bottomk_indices_cosine = order[-bottomk:][::-1]  # least similar first
    bottomk_values_cosine = cosine_similarities[bottomk_indices_cosine]

    df_bottom_correlated = pd.DataFrame({
        "Feature": [get_feature_from_index(subject, i)['label'] for i in bottomk_indices_cosine],
        "Cosine similarity": [f"{v:.4f}" for v in bottomk_values_cosine]
    })
    df_bottom_correlated_styled = style_dataframe(df_bottom_correlated, is_top=False)
    # Top co-occurring features
    co_occurrences = calculate_co_occurrences(subject, index)
    topk_indices_co_occurrence = np.argsort(-co_occurrences)[:topk]
    topk_values_co_occurrence = co_occurrences[topk_indices_co_occurrence]

    df_co_occurrences = pd.DataFrame({
        "Feature": [get_feature_from_index(subject, i)['label'] for i in topk_indices_co_occurrence],
        "Co-occurrences": topk_values_co_occurrence
    })

    return output, df_top_abstracts, df_top_correlated_styled, df_bottom_correlated_styled, df_co_occurrences, fig2
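# visualize_feature returns six values in the order consumed by the "Feature
# Visualisation" tab below: markdown summary, top-abstracts dataframe, top and
# bottom correlated-feature dataframes, co-occurrence dataframe, and the
# activation histogram figure.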
# Main interface
def create_interface():
    # CSS id wildcards like `#custom-slider-*` are not valid selectors; an
    # attribute prefix selector tints the manually-added feature sliders instead.
    custom_css = """
    [id^='custom-slider-'] {
        background-color: #ffe6e6;
    }
    """

    with gr.Blocks(css=custom_css) as demo:
        subject = gr.Dropdown(choices=['astroPH', 'csLG'], label="Select Subject", value='astroPH')

        with gr.Tabs():
            with gr.Tab("SAErch"):
                input_text = gr.Textbox(label="input")
                search_results_state = gr.State([])
                feature_values_state = gr.State([])
                feature_indices_state = gr.State([])
                manually_added_features_state = gr.State([])
                def update_search_results(feature_values, feature_indices, manually_added_features, current_subject):
                    ae = subject_data[current_subject]['ae']
                    abstract_embeddings = subject_data[current_subject]['abstract_embeddings']
                    abstract_texts = subject_data[current_subject]['abstract_texts']
                    df_metadata = subject_data[current_subject]['df_metadata']

                    # Combine manually added features with query-generated features
                    all_indices = []
                    all_values = []

                    # Manually added features first (value 0.0 if the query did not activate them)
                    for index in manually_added_features:
                        if index not in all_indices:
                            all_indices.append(index)
                            all_values.append(feature_values[feature_indices.index(index)] if index in feature_indices else 0.0)

                    # Then the remaining query-generated features
                    for index, value in zip(feature_indices, feature_values):
                        if index not in all_indices:
                            all_indices.append(index)
                            all_values.append(value)

                    # Reconstruct the (intervened) query embedding from the sparse code
                    topk_indices = torch.tensor(all_indices, dtype=torch.long).to(device)
                    topk_values = torch.tensor(all_values, dtype=torch.float32).to(device)

                    intervened_embedding = intervened_hidden_to_intervened_embedding(topk_indices, topk_values, ae)
                    intervened_embedding = intervened_embedding.cpu().numpy().flatten()

                    # Similarity search (inner product; equivalent to cosine for unit-norm embeddings)
                    sims = np.dot(abstract_embeddings, intervened_embedding)
                    topk_indices_search = np.argsort(sims)[::-1][:10]
                    doc_ids = abstract_texts['doc_ids']
                    topk_doc_ids = [doc_ids[i] for i in topk_indices_search]

                    # Prepare search results
                    search_results = []
                    for doc_id in topk_doc_ids:
                        metadata = df_metadata[df_metadata['arxiv_id'] == doc_id].iloc[0]
                        title = metadata['title'].replace('[', '').replace(']', '')
                        search_results.append([
                            title,
                            int(metadata['citation_count']),
                            int(metadata['year'])
                        ])

                    return search_results, all_values, all_indices
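                # The search flow: embed the query, encode it into k sparse SAE
                # features, let the user edit feature strengths with sliders (or
                # add features by label), then decode the edited code back into a
                # dense vector and re-rank abstracts against it.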
                # Dynamic per-query UI: the results table and per-feature sliders are
                # created inside this function, which Gradio only supports under
                # @gr.render (the sliders' `key=` arguments are a gr.render feature).
                # NOTE: the inputs list below is an assumption that mirrors the
                # function's parameters.
                @gr.render(inputs=[input_text, search_results_state, feature_values_state,
                                   feature_indices_state, manually_added_features_state, subject])
                def show_components(text, search_results, feature_values, feature_indices, manually_added_features, current_subject):
                    if len(text) == 0:
                        return gr.Markdown("## No Input Provided")

                    if not search_results or text != getattr(show_components, 'last_query', None):
                        show_components.last_query = text
                        query_embedding = get_embedding(text)
                        if query_embedding is None:
                            return gr.Markdown("## Embedding failed, check the OpenAI API key")

                        ae = subject_data[current_subject]['ae']
                        with torch.no_grad():
                            recons, z_dict = ae(torch.tensor(query_embedding).unsqueeze(0).to(device))
                            topk_indices = z_dict['topk_indices'][0].cpu().numpy()
                            topk_values = z_dict['topk_values'][0].cpu().numpy()

                        feature_values = topk_values.tolist()
                        feature_indices = topk_indices.tolist()
                        search_results, feature_values, feature_indices = update_search_results(feature_values, feature_indices, manually_added_features, current_subject)
                    with gr.Row():
                        with gr.Column(scale=2):
                            df = gr.Dataframe(
                                headers=["Title", "Citation Count", "Year"],
                                value=search_results,
                                label="Top 10 Search Results"
                            )

                            feature_search = gr.Textbox(label="Search Feature Labels")
                            feature_matches = gr.CheckboxGroup(label="Matching Features", choices=[])
                            add_button = gr.Button("Add Selected Features")

                            def search_feature_labels(search_text):
                                if not search_text:
                                    return gr.CheckboxGroup(choices=[])
                                matches = [f"{f['label']} ({f['index']})" for f in subject_data[current_subject]['feature_analysis'] if search_text.lower() in f['label'].lower()]
                                return gr.CheckboxGroup(choices=matches[:10])

                            feature_search.change(search_feature_labels, inputs=[feature_search], outputs=[feature_matches])

                            def on_add_features(selected_features, current_values, current_indices, manually_added_features):
                                if selected_features:
                                    # Parse the feature index out of "label (index)"
                                    new_indices = [int(f.split('(')[-1].strip(')')) for f in selected_features]
                                    # Append new indices, de-duplicating while preserving order
                                    manually_added_features = list(dict.fromkeys(manually_added_features + new_indices))
                                return gr.CheckboxGroup(value=[]), current_values, current_indices, manually_added_features

                            add_button.click(
                                on_add_features,
                                inputs=[feature_matches, feature_values_state, feature_indices_state, manually_added_features_state],
                                outputs=[feature_matches, feature_values_state, feature_indices_state, manually_added_features_state]
                            )
                        with gr.Column(scale=1):
                            update_button = gr.Button("Update Results")

                            sliders = []
                            for value, index in zip(feature_values, feature_indices):
                                feature = get_feature_from_index(current_subject, index)
                                label = f"{feature['label']} ({index})" if feature else f"Feature {index}"

                                # Prefix and tint sliders for manually added features
                                if index in manually_added_features:
                                    label = f"[Custom] {label}"
                                    slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=value, label=label, key=f"slider-{index}", elem_id=f"custom-slider-{index}")
                                else:
                                    slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=value, label=label, key=f"slider-{index}")
                                sliders.append(slider)

                            def on_slider_change(*values):
                                manually_added_features = values[-1]
                                slider_values = list(values[:-1])

                                # Recover each slider's feature index from its "(index)" label suffix
                                reconstructed_indices = [int(slider.label.split('(')[-1].split(')')[0]) for slider in sliders]

                                new_results, new_values, new_indices = update_search_results(slider_values, reconstructed_indices, manually_added_features, current_subject)
                                return new_results, new_values, new_indices, manually_added_features

                            update_button.click(
                                on_slider_change,
                                inputs=sliders + [manually_added_features_state],
                                outputs=[search_results_state, feature_values_state, feature_indices_state, manually_added_features_state]
                            )

                    return [df, feature_search, feature_matches, add_button, update_button] + sliders
| with gr.Tab("Feature Visualisation"): | |
| gr.Markdown("# Feature Visualiser") | |
| with gr.Row(): | |
| feature_search = gr.Textbox(label="Search Feature Labels") | |
| feature_matches = gr.CheckboxGroup(label="Matching Features", choices=[]) | |
| visualize_button = gr.Button("Visualize Feature") | |
| feature_info = gr.Markdown() | |
| abstracts_heading = gr.Markdown("## Top 5 Abstracts") | |
| top_abstracts = gr.Dataframe( | |
| headers=["Title", "Activation value"], | |
| interactive=False | |
| ) | |
| gr.Markdown("## Correlated Features") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Top 5 Correlated Features") | |
| top_correlated = gr.Dataframe( | |
| headers=["Feature", "Cosine similarity"], | |
| interactive=False | |
| ) | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Bottom 5 Correlated Features") | |
| bottom_correlated = gr.Dataframe( | |
| headers=["Feature", "Cosine similarity"], | |
| interactive=False | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("## Top 5 Co-occurring Features") | |
| co_occurring_features = gr.Dataframe( | |
| headers=["Feature", "Co-occurrences"], | |
| interactive=False | |
| ) | |
| with gr.Column(scale=1): | |
| gr.Markdown(f"## Activation Value Distribution") | |
| activation_dist = gr.Plot() | |
                def search_feature_labels(search_text, current_subject):
                    if not search_text:
                        return gr.CheckboxGroup(choices=[])
                    matches = [f"{f['label']} ({f['index']})" for f in subject_data[current_subject]['feature_analysis'] if search_text.lower() in f['label'].lower()]
                    return gr.CheckboxGroup(choices=matches[:10])

                feature_search.change(search_feature_labels, inputs=[feature_search, subject], outputs=[feature_matches])

                def on_visualize(selected_features, current_subject):
                    if not selected_features:
                        return "Please select a feature to visualize.", None, None, None, None, None, "", []

                    # Extract the feature index from the "label (index)" string
                    feature_index = int(selected_features[0].split('(')[-1].strip(')'))
                    feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist = visualize_feature(current_subject, feature_index)

                    # Return the visualization results and clear the search box and checkbox group
                    return feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, "", []

                visualize_button.click(
                    on_visualize,
                    inputs=[feature_matches, subject],
                    outputs=[feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, feature_search, feature_matches]
                )
        # Clear all states and components when the subject changes
        # (one value per output; the original listed feature_matches twice)
        def on_subject_change(new_subject):
            return [], [], [], [], "", "", [], None, None, None, None, None, None

        subject.change(
            on_subject_change,
            inputs=[subject],
            outputs=[search_results_state, feature_values_state, feature_indices_state, manually_added_features_state,
                     input_text, feature_search, feature_matches,
                     feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist]
        )
    return demo

# Launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
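# To run locally (a sketch, assuming the same env var names used above):
#   export openai_key=sk-...   # required for query embeddings
#   python app.py
# The data files are pulled from the charlieoneill/saerch-ai-data Hub repo on startup.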