Spaces:
Running
Running
Commit
·
d6eab4f
1
Parent(s):
3187d23
feature families
Browse files- .gitignore +3 -1
- app.py +382 -27
.gitignore
CHANGED
|
@@ -1 +1,3 @@
|
|
| 1 |
-
data/
|
|
|
|
|
|
|
|
|
| 1 |
+
data/
|
| 2 |
+
__pycache__
|
| 3 |
+
__pycache__/
|
app.py
CHANGED
|
@@ -1,3 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import json
|
|
@@ -11,6 +140,10 @@ import plotly.express as px
|
|
| 11 |
from collections import Counter
|
| 12 |
from huggingface_hub import hf_hub_download
|
| 13 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
import os
|
| 16 |
print(os.getenv('MODEL_REPO_ID'))
|
|
@@ -44,7 +177,15 @@ def download_all_files():
|
|
| 44 |
# "csLG_clean_families_64_9216.json",
|
| 45 |
# "astroPH_clean_families_64_9216.json",
|
| 46 |
"astroPH_family_analysis_64_9216.json",
|
| 47 |
-
"csLG_family_analysis_64_9216.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
]
|
| 49 |
|
| 50 |
for file in files_to_download:
|
|
@@ -74,9 +215,13 @@ def load_subject_data(subject):
|
|
| 74 |
feature_analysis_path = f"data/{subject}_feature_analysis_results_{k}.json"
|
| 75 |
metadata_path = f'data/{subject}_paper_metadata.csv'
|
| 76 |
topk_indices_path = f"data/{subject}_topk_indices_{k}_{n_dirs}_int32.npy"
|
|
|
|
| 77 |
topk_values_path = f"data/{subject}_topk_values_{k}_{n_dirs}_float16.npy"
|
| 78 |
families_path = f"data/{subject}_clean_families_{k}_{n_dirs}.json"
|
| 79 |
family_analysis_path = f"data/{subject}_family_analysis_{k}_{n_dirs}.json"
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
abstract_embeddings = np.load(embeddings_path).astype(np.float32) # Load float16 and convert to float32
|
| 82 |
with open(texts_path, 'r') as f:
|
|
@@ -86,6 +231,7 @@ def load_subject_data(subject):
|
|
| 86 |
df_metadata = pd.read_csv(metadata_path)
|
| 87 |
topk_indices = np.load(topk_indices_path) # Already in int32, no conversion needed
|
| 88 |
topk_values = np.load(topk_values_path).astype(np.float32)
|
|
|
|
| 89 |
|
| 90 |
model_filename = f"{subject}_64_9216.pth"
|
| 91 |
model_path = os.path.join("data", model_filename)
|
|
@@ -109,6 +255,9 @@ def load_subject_data(subject):
|
|
| 109 |
'df_metadata': df_metadata,
|
| 110 |
'topk_indices': topk_indices,
|
| 111 |
'topk_values': topk_values,
|
|
|
|
|
|
|
|
|
|
| 112 |
'ae': ae,
|
| 113 |
'decoder': decoder,
|
| 114 |
# 'feature_families': feature_families,
|
|
@@ -163,13 +312,15 @@ def get_feature_activations(subject, feature_index, m=5, min_length=100):
|
|
| 163 |
|
| 164 |
def calculate_co_occurrences(subject, target_index, n_features=9216):
|
| 165 |
topk_indices = subject_data[subject]['topk_indices']
|
|
|
|
| 166 |
|
| 167 |
mask = np.any(topk_indices == target_index, axis=1)
|
| 168 |
co_occurring_indices = topk_indices[mask].flatten()
|
| 169 |
co_occurrences = Counter(co_occurring_indices)
|
| 170 |
del co_occurrences[target_index]
|
| 171 |
-
result = np.zeros(n_features, dtype=
|
| 172 |
result[list(co_occurrences.keys())] = list(co_occurrences.values())
|
|
|
|
| 173 |
return result
|
| 174 |
|
| 175 |
def style_dataframe(df: pd.DataFrame, is_top: bool) -> pd.DataFrame:
|
|
@@ -291,10 +442,175 @@ def visualize_feature(subject, index):
|
|
| 291 |
"Co-occurrences": topk_values_co_occurrence
|
| 292 |
})
|
| 293 |
df_co_occurrences_styled = df_co_occurrences.style.format({
|
| 294 |
-
"Co-occurrences": "{:.
|
| 295 |
})
|
| 296 |
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
# Modify the main interface function
|
| 300 |
def create_interface():
|
|
@@ -453,7 +769,10 @@ def create_interface():
|
|
| 453 |
def search_feature_labels(search_text):
|
| 454 |
if not search_text:
|
| 455 |
return gr.CheckboxGroup(choices=[])
|
| 456 |
-
matches = [f
|
|
|
|
|
|
|
|
|
|
| 457 |
return gr.CheckboxGroup(choices=matches[:10])
|
| 458 |
|
| 459 |
feature_search.change(search_feature_labels, inputs=[feature_search], outputs=[feature_matches])
|
|
@@ -536,24 +855,24 @@ def create_interface():
|
|
| 536 |
wrap=True
|
| 537 |
)
|
| 538 |
|
| 539 |
-
gr.Markdown("##
|
| 540 |
with gr.Row():
|
| 541 |
with gr.Column(scale=1):
|
| 542 |
-
gr.Markdown("###
|
| 543 |
-
|
| 544 |
-
headers=["Feature", "Cosine
|
| 545 |
interactive=False
|
| 546 |
)
|
| 547 |
with gr.Column(scale=1):
|
| 548 |
-
gr.Markdown("###
|
| 549 |
-
|
| 550 |
-
headers=["Feature", "Cosine
|
| 551 |
interactive=False
|
| 552 |
)
|
| 553 |
-
|
| 554 |
with gr.Row():
|
| 555 |
with gr.Column(scale=1):
|
| 556 |
-
gr.Markdown("## Top
|
| 557 |
co_occurring_features = gr.Dataframe(
|
| 558 |
headers=["Feature", "Co-occurrences"],
|
| 559 |
interactive=False
|
|
@@ -562,10 +881,31 @@ def create_interface():
|
|
| 562 |
gr.Markdown(f"## Activation Value Distribution")
|
| 563 |
activation_dist = gr.Plot()
|
| 564 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
def search_feature_labels(search_text, current_subject):
|
| 566 |
if not search_text:
|
| 567 |
return gr.CheckboxGroup(choices=[])
|
| 568 |
-
matches = [f
|
|
|
|
|
|
|
|
|
|
| 569 |
return gr.CheckboxGroup(choices=matches[:10])
|
| 570 |
|
| 571 |
feature_search.change(search_feature_labels, inputs=[feature_search, subject], outputs=[feature_matches])
|
|
@@ -576,15 +916,15 @@ def create_interface():
|
|
| 576 |
|
| 577 |
# Extract the feature index from the selected feature string
|
| 578 |
feature_index = int(selected_features[0].split('(')[-1].strip(')'))
|
| 579 |
-
feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist = visualize_feature(current_subject, feature_index)
|
| 580 |
|
| 581 |
# Return the visualization results along with empty values for search box and checkbox
|
| 582 |
-
return feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, "", []
|
| 583 |
|
| 584 |
visualize_button.click(
|
| 585 |
on_visualize,
|
| 586 |
inputs=[feature_matches, subject],
|
| 587 |
-
outputs=[feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, feature_search, feature_matches]
|
| 588 |
)
|
| 589 |
|
| 590 |
with gr.Tab("Feature Families"):
|
|
@@ -595,19 +935,26 @@ def create_interface():
|
|
| 595 |
family_matches = gr.CheckboxGroup(label="Matching Feature Families", choices=[])
|
| 596 |
visualize_family_button = gr.Button("Visualize Feature Family")
|
| 597 |
|
| 598 |
-
|
| 599 |
family_dataframe = gr.Dataframe(
|
| 600 |
-
headers=["Feature", "F1 Score", "Pearson
|
| 601 |
-
datatype=["markdown", "number", "number"],
|
| 602 |
label="Family and Child Features"
|
| 603 |
)
|
| 604 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
|
| 606 |
def search_feature_families(search_text, current_subject):
|
| 607 |
family_analysis = subject_data[current_subject]['family_analysis']
|
| 608 |
if not search_text:
|
| 609 |
return gr.CheckboxGroup(choices=[])
|
| 610 |
-
matches = [family
|
|
|
|
|
|
|
|
|
|
| 611 |
return gr.CheckboxGroup(choices=matches[:10]) # Limit to top 10 matches
|
| 612 |
|
| 613 |
def visualize_feature_family(selected_families, current_subject):
|
|
@@ -627,16 +974,20 @@ def create_interface():
|
|
| 627 |
df_data = [
|
| 628 |
{
|
| 629 |
"Feature": f"## {family_data['superfeature']}",
|
|
|
|
| 630 |
"F1 Score": round(family_data['family_f1'], 2),
|
| 631 |
-
"Pearson
|
| 632 |
},
|
| 633 |
]
|
| 634 |
|
| 635 |
-
|
|
|
|
|
|
|
| 636 |
df_data.append({
|
| 637 |
"Feature": name,
|
|
|
|
| 638 |
"F1 Score": round(f1, 2),
|
| 639 |
-
"Pearson
|
| 640 |
})
|
| 641 |
|
| 642 |
df = pd.DataFrame(df_data)
|
|
@@ -645,13 +996,17 @@ def create_interface():
|
|
| 645 |
output += "## Super Reasoning\n"
|
| 646 |
output += f"{family_data['super_reasoning']}\n\n"
|
| 647 |
|
| 648 |
-
|
|
|
|
|
|
|
|
|
|
| 649 |
|
| 650 |
family_search.change(search_feature_families, inputs=[family_search, subject], outputs=[family_matches])
|
| 651 |
visualize_family_button.click(
|
| 652 |
visualize_feature_family,
|
| 653 |
inputs=[family_matches, subject],
|
| 654 |
-
outputs=[family_info, family_dataframe, family_search, family_matches]
|
|
|
|
| 655 |
)
|
| 656 |
|
| 657 |
|
|
|
|
| 1 |
+
# import gradio as gr
|
| 2 |
+
# import numpy as np
|
| 3 |
+
# import json
|
| 4 |
+
# import pandas as pd
|
| 5 |
+
# from openai import OpenAI
|
| 6 |
+
# import yaml
|
| 7 |
+
# from typing import Optional, List, Dict, Tuple, Any
|
| 8 |
+
# from topk_sae import FastAutoencoder
|
| 9 |
+
# import torch
|
| 10 |
+
# import plotly.express as px
|
| 11 |
+
# from collections import Counter
|
| 12 |
+
# from huggingface_hub import hf_hub_download
|
| 13 |
+
# import os
|
| 14 |
+
# import networkx as nx
|
| 15 |
+
# import plotly.graph_objs as go
|
| 16 |
+
# from ast import literal_eval as make_tuple
|
| 17 |
+
# import random
|
| 18 |
+
|
| 19 |
+
# import os
|
| 20 |
+
# print(os.getenv('MODEL_REPO_ID'))
|
| 21 |
+
|
| 22 |
+
# # Constants
|
| 23 |
+
# EMBEDDING_MODEL = "text-embedding-3-small"
|
| 24 |
+
# d_model = 1536
|
| 25 |
+
# n_dirs = d_model * 6
|
| 26 |
+
# k = 64
|
| 27 |
+
# auxk = 128
|
| 28 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 29 |
+
# torch.set_grad_enabled(False)
|
| 30 |
+
|
| 31 |
+
# # Function to download all necessary files
|
| 32 |
+
# def download_all_files():
|
| 33 |
+
# files_to_download = [
|
| 34 |
+
# "astroPH_paper_metadata.csv",
|
| 35 |
+
# "csLG_feature_analysis_results_64.json",
|
| 36 |
+
# "astroPH_topk_indices_64_9216_int32.npy",
|
| 37 |
+
# "astroPH_64_9216.pth",
|
| 38 |
+
# "astroPH_topk_values_64_9216_float16.npy",
|
| 39 |
+
# "csLG_abstract_texts.json",
|
| 40 |
+
# "csLG_topk_values_64_9216_float16.npy",
|
| 41 |
+
# "csLG_abstract_embeddings_float16.npy",
|
| 42 |
+
# "csLG_paper_metadata.csv",
|
| 43 |
+
# "csLG_64_9216.pth",
|
| 44 |
+
# "astroPH_abstract_texts.json",
|
| 45 |
+
# "astroPH_feature_analysis_results_64.json",
|
| 46 |
+
# "csLG_topk_indices_64_9216_int32.npy",
|
| 47 |
+
# "astroPH_abstract_embeddings_float16.npy",
|
| 48 |
+
# # "csLG_clean_families_64_9216.json",
|
| 49 |
+
# # "astroPH_clean_families_64_9216.json",
|
| 50 |
+
# # "astroPH_family_analysis_64_9216.json",
|
| 51 |
+
# "csLG_family_analysis_64_9216.json"
|
| 52 |
+
# ]
|
| 53 |
+
|
| 54 |
+
# for file in files_to_download:
|
| 55 |
+
# local_path = os.path.join("data", file)
|
| 56 |
+
# os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
| 57 |
+
# hf_hub_download(repo_id="charlieoneill/saerch-ai-data", filename=file, local_dir="data")
|
| 58 |
+
# print(f"Downloaded {file}")
|
| 59 |
+
|
| 60 |
+
# # Load configuration and initialize OpenAI client
|
| 61 |
+
# download_all_files()
|
| 62 |
+
|
| 63 |
+
# # Load the API key from the environment variable
|
| 64 |
+
# api_key = os.getenv('openai_key')
|
| 65 |
+
|
| 66 |
+
# # Ensure the API key is set
|
| 67 |
+
# if not api_key:
|
| 68 |
+
# raise ValueError("The environment variable 'openai_key' is not set.")
|
| 69 |
+
|
| 70 |
+
# # Initialize the OpenAI client with the API key
|
| 71 |
+
# client = OpenAI(api_key=api_key)
|
| 72 |
+
|
| 73 |
+
# # Function to load data for a specific subject
|
| 74 |
+
# def load_subject_data(subject):
|
| 75 |
+
|
| 76 |
+
# embeddings_path = f"data/{subject}_abstract_embeddings_float16.npy"
|
| 77 |
+
# texts_path = f"data/{subject}_abstract_texts.json"
|
| 78 |
+
# feature_analysis_path = f"data/{subject}_feature_analysis_results_{k}.json"
|
| 79 |
+
# metadata_path = f'data/{subject}_paper_metadata.csv'
|
| 80 |
+
# topk_indices_path = f"data/{subject}_topk_indices_{k}_{n_dirs}_int32.npy"
|
| 81 |
+
# norms_path = f"data/{subject}_norms_{k}_{n_dirs}.npy"
|
| 82 |
+
# topk_values_path = f"data/{subject}_topk_values_{k}_{n_dirs}_float16.npy"
|
| 83 |
+
# families_path = f"data/{subject}_clean_families_{k}_{n_dirs}.json"
|
| 84 |
+
# family_analysis_path = f"data/{subject}_family_analysis_{k}_{n_dirs}.json"
|
| 85 |
+
# nns_32to64 = json.load(open(f"data/{subject}_nns_32to64.json"))
|
| 86 |
+
# nns_16to32 = json.load(open(f"data/{subject}_nns_16to32.json"))
|
| 87 |
+
# nns_16to64 = json.load(open(f"data/{subject}_nns_16to64.json"))
|
| 88 |
+
|
| 89 |
+
# abstract_embeddings = np.load(embeddings_path).astype(np.float32) # Load float16 and convert to float32
|
| 90 |
+
# with open(texts_path, 'r') as f:
|
| 91 |
+
# abstract_texts = json.load(f)
|
| 92 |
+
# with open(feature_analysis_path, 'r') as f:
|
| 93 |
+
# feature_analysis = json.load(f)
|
| 94 |
+
# df_metadata = pd.read_csv(metadata_path)
|
| 95 |
+
# topk_indices = np.load(topk_indices_path) # Already in int32, no conversion needed
|
| 96 |
+
# topk_values = np.load(topk_values_path).astype(np.float32)
|
| 97 |
+
# norms = np.load(norms_path).astype(np.float32)
|
| 98 |
+
|
| 99 |
+
# model_filename = f"{subject}_64_9216.pth"
|
| 100 |
+
# model_path = os.path.join("data", model_filename)
|
| 101 |
+
|
| 102 |
+
# ae = FastAutoencoder(n_dirs, d_model, k, auxk, multik=0).to(device)
|
| 103 |
+
# ae.load_state_dict(torch.load(model_path))
|
| 104 |
+
# ae.eval()
|
| 105 |
+
|
| 106 |
+
# weights = torch.load(model_path)
|
| 107 |
+
# decoder = weights['decoder.weight'].cpu().numpy()
|
| 108 |
+
# del weights
|
| 109 |
+
|
| 110 |
+
# with open(family_analysis_path, 'r') as f:
|
| 111 |
+
# family_analysis = json.load(f)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# return {
|
| 115 |
+
# 'abstract_embeddings': abstract_embeddings,
|
| 116 |
+
# 'abstract_texts': abstract_texts,
|
| 117 |
+
# 'feature_analysis': feature_analysis,
|
| 118 |
+
# 'df_metadata': df_metadata,
|
| 119 |
+
# 'topk_indices': topk_indices,
|
| 120 |
+
# 'topk_values': topk_values,
|
| 121 |
+
# 'norms': norms,
|
| 122 |
+
# 'nns_32to64': nns_32to64,
|
| 123 |
+
# 'nns_16to64': nns_16to64,
|
| 124 |
+
# 'ae': ae,
|
| 125 |
+
# 'decoder': decoder,
|
| 126 |
+
# # 'feature_families': feature_families,
|
| 127 |
+
# 'family_analysis': family_analysis
|
| 128 |
+
# }
|
| 129 |
+
|
| 130 |
import gradio as gr
|
| 131 |
import numpy as np
|
| 132 |
import json
|
|
|
|
| 140 |
from collections import Counter
|
| 141 |
from huggingface_hub import hf_hub_download
|
| 142 |
import os
|
| 143 |
+
import networkx as nx
|
| 144 |
+
import plotly.graph_objs as go
|
| 145 |
+
from ast import literal_eval as make_tuple
|
| 146 |
+
import random
|
| 147 |
|
| 148 |
import os
|
| 149 |
print(os.getenv('MODEL_REPO_ID'))
|
|
|
|
| 177 |
# "csLG_clean_families_64_9216.json",
|
| 178 |
# "astroPH_clean_families_64_9216.json",
|
| 179 |
"astroPH_family_analysis_64_9216.json",
|
| 180 |
+
"csLG_family_analysis_64_9216.json",
|
| 181 |
+
"csLG_nns_32to64.json",
|
| 182 |
+
"csLG_nns_16to32.json",
|
| 183 |
+
"csLG_nns_16to64.json",
|
| 184 |
+
"astroPH_nns_32to64.json",
|
| 185 |
+
"astroPH_nns_16to32.json",
|
| 186 |
+
"astroPH_nns_16to64.json",
|
| 187 |
+
"csLG_norms_64_9216_float16.npy",
|
| 188 |
+
"astroPH_norms_64_9216_float16.npy"
|
| 189 |
]
|
| 190 |
|
| 191 |
for file in files_to_download:
|
|
|
|
| 215 |
feature_analysis_path = f"data/{subject}_feature_analysis_results_{k}.json"
|
| 216 |
metadata_path = f'data/{subject}_paper_metadata.csv'
|
| 217 |
topk_indices_path = f"data/{subject}_topk_indices_{k}_{n_dirs}_int32.npy"
|
| 218 |
+
norms_path = f"data/{subject}_norms_{k}_{n_dirs}_float16.npy"
|
| 219 |
topk_values_path = f"data/{subject}_topk_values_{k}_{n_dirs}_float16.npy"
|
| 220 |
families_path = f"data/{subject}_clean_families_{k}_{n_dirs}.json"
|
| 221 |
family_analysis_path = f"data/{subject}_family_analysis_{k}_{n_dirs}.json"
|
| 222 |
+
nns_32to64 = json.load(open(f"data/{subject}_nns_32to64.json"))
|
| 223 |
+
nns_16to32 = json.load(open(f"data/{subject}_nns_16to32.json"))
|
| 224 |
+
nns_16to64 = json.load(open(f"data/{subject}_nns_16to64.json"))
|
| 225 |
|
| 226 |
abstract_embeddings = np.load(embeddings_path).astype(np.float32) # Load float16 and convert to float32
|
| 227 |
with open(texts_path, 'r') as f:
|
|
|
|
| 231 |
df_metadata = pd.read_csv(metadata_path)
|
| 232 |
topk_indices = np.load(topk_indices_path) # Already in int32, no conversion needed
|
| 233 |
topk_values = np.load(topk_values_path).astype(np.float32)
|
| 234 |
+
norms = np.load(norms_path).astype(np.float32)
|
| 235 |
|
| 236 |
model_filename = f"{subject}_64_9216.pth"
|
| 237 |
model_path = os.path.join("data", model_filename)
|
|
|
|
| 255 |
'df_metadata': df_metadata,
|
| 256 |
'topk_indices': topk_indices,
|
| 257 |
'topk_values': topk_values,
|
| 258 |
+
'norms': norms,
|
| 259 |
+
'nns_32to64': nns_32to64,
|
| 260 |
+
'nns_16to64': nns_16to64,
|
| 261 |
'ae': ae,
|
| 262 |
'decoder': decoder,
|
| 263 |
# 'feature_families': feature_families,
|
|
|
|
| 312 |
|
| 313 |
def calculate_co_occurrences(subject, target_index, n_features=9216):
|
| 314 |
topk_indices = subject_data[subject]['topk_indices']
|
| 315 |
+
norms = subject_data[subject]['norms']
|
| 316 |
|
| 317 |
mask = np.any(topk_indices == target_index, axis=1)
|
| 318 |
co_occurring_indices = topk_indices[mask].flatten()
|
| 319 |
co_occurrences = Counter(co_occurring_indices)
|
| 320 |
del co_occurrences[target_index]
|
| 321 |
+
result = np.zeros(n_features, dtype=np.float32)
|
| 322 |
result[list(co_occurrences.keys())] = list(co_occurrences.values())
|
| 323 |
+
result[list(co_occurrences.keys())] /= np.minimum(norms[list(co_occurrences.keys())], norms[target_index])
|
| 324 |
return result
|
| 325 |
|
| 326 |
def style_dataframe(df: pd.DataFrame, is_top: bool) -> pd.DataFrame:
|
|
|
|
| 442 |
"Co-occurrences": topk_values_co_occurrence
|
| 443 |
})
|
| 444 |
df_co_occurrences_styled = df_co_occurrences.style.format({
|
| 445 |
+
"Co-occurrences": "{:.2f}" # 2 decimal points
|
| 446 |
})
|
| 447 |
|
| 448 |
+
# Add new code for feature splitting
|
| 449 |
+
nns_16to64 = subject_data[subject]['nns_16to64']
|
| 450 |
+
nns_32to64 = subject_data[subject]['nns_32to64']
|
| 451 |
+
|
| 452 |
+
# Get nearest neighbors for 16 and 32
|
| 453 |
+
#nn_16 = nns_16to64[str(index)]
|
| 454 |
+
|
| 455 |
+
# this is really involved it's a lot easier the other direction
|
| 456 |
+
nn_16 = []
|
| 457 |
+
for key in nns_16to64.keys():
|
| 458 |
+
for match in nns_16to64[key]:
|
| 459 |
+
if index == match['feature'][0]:
|
| 460 |
+
nn_16.append([key, float(match['similarity'])])
|
| 461 |
+
|
| 462 |
+
#nn_32 = nns_32to64[str(index)]
|
| 463 |
+
nn_32 = []
|
| 464 |
+
for key in nns_32to64.keys():
|
| 465 |
+
for match in nns_32to64[key]:
|
| 466 |
+
if index == match['feature'][0]:
|
| 467 |
+
nn_32.append([key, float(match['similarity'])])
|
| 468 |
+
|
| 469 |
+
# Create dataframes for 16 and 32 nearest neighbors
|
| 470 |
+
try:
|
| 471 |
+
df_16 = pd.DataFrame(nn_16, columns=["Feature", "Cosine Similarity"])
|
| 472 |
+
df_16 = df_16.style.format({"Cosine Similarity": "{:.4f}"})
|
| 473 |
+
except:
|
| 474 |
+
df_16 = pd.DataFrame(["No Match"], columns=["Feature"])
|
| 475 |
+
|
| 476 |
+
try:
|
| 477 |
+
df_32 = pd.DataFrame(nn_32, columns=["Feature", "Cosine Similarity"])
|
| 478 |
+
df_32 = df_32.style.format({"Cosine Similarity": "{:.4f}"})
|
| 479 |
+
except:
|
| 480 |
+
df_32 = pd.DataFrame(["No Match"], columns=["Feature"])
|
| 481 |
+
|
| 482 |
+
return output, styled_top_abstracts, df_top_correlated_styled, df_bottom_correlated_styled, df_co_occurrences_styled, fig2, df_16, df_32
|
| 483 |
+
|
| 484 |
+
def create_interactive_directed_graph(family):
|
| 485 |
+
matrix = np.array(family['matrix'])
|
| 486 |
+
matrix[matrix < 0.07] = 0
|
| 487 |
+
densities = family['densities']
|
| 488 |
+
for i in range(len(densities)):
|
| 489 |
+
for j in range(len(densities)):
|
| 490 |
+
if densities[i] < densities[j]:
|
| 491 |
+
matrix[i][j] = 0
|
| 492 |
+
|
| 493 |
+
G = nx.from_numpy_array(matrix, create_using=nx.DiGraph())
|
| 494 |
+
|
| 495 |
+
num_nodes = len(family['feature_f1'])
|
| 496 |
+
all_f1s = family['feature_pearson'] + [family['family_pearson']]
|
| 497 |
+
node_info = {i: {"name": f"{family['feature_names'][i]}", "density": family['densities'][i], "pearson": all_f1s[i]} for i in range(num_nodes)}
|
| 498 |
+
nx.set_node_attributes(G, node_info)
|
| 499 |
+
|
| 500 |
+
# Create node trace
|
| 501 |
+
node_x = []
|
| 502 |
+
node_y = []
|
| 503 |
+
node_text = []
|
| 504 |
+
node_size = []
|
| 505 |
+
node_color = []
|
| 506 |
+
pos = nx.spring_layout(G, k = np.sqrt(1/num_nodes) * 3)
|
| 507 |
+
for node in G.nodes():
|
| 508 |
+
x, y = pos[node]
|
| 509 |
+
node_x.append(x)
|
| 510 |
+
node_y.append(y)
|
| 511 |
+
node_text.append(G.nodes[node]['name'] + "<br>log density: " + str(round(np.log10(G.nodes[node]['density'] + 1e-5), 3)))
|
| 512 |
+
node_size.append((np.log10(G.nodes[node]['density'] + 1e-5) + 6) * 10)
|
| 513 |
+
node_color.append(G.nodes[node]['pearson'])
|
| 514 |
+
|
| 515 |
+
node_trace = go.Scatter(
|
| 516 |
+
x=node_x, y=node_y,
|
| 517 |
+
mode='markers',
|
| 518 |
+
hoverinfo='text',
|
| 519 |
+
marker=dict(
|
| 520 |
+
showscale=True,
|
| 521 |
+
colorscale='purples',
|
| 522 |
+
size=node_size, # Set node marker size to node['f1']
|
| 523 |
+
color=node_color,
|
| 524 |
+
cmin = 0,
|
| 525 |
+
cmax = 1,
|
| 526 |
+
colorbar=dict(
|
| 527 |
+
thickness=15,
|
| 528 |
+
title='Pearson Correlation',
|
| 529 |
+
xanchor='left',
|
| 530 |
+
titleside='right',
|
| 531 |
+
),
|
| 532 |
+
line_width=2,
|
| 533 |
+
opacity = 1,),
|
| 534 |
+
opacity = 1)
|
| 535 |
+
|
| 536 |
+
node_trace.text = node_text
|
| 537 |
+
|
| 538 |
+
# Create edge trace
|
| 539 |
+
edge_traces = []
|
| 540 |
+
annotations = []
|
| 541 |
+
for edge in G.edges():
|
| 542 |
+
x0, y0 = pos[edge[0]]
|
| 543 |
+
x1, y1 = pos[edge[1]]
|
| 544 |
+
weight = matrix[edge[0], edge[1]]
|
| 545 |
+
|
| 546 |
+
# Calculate offset (adjust this value to move arrows further from or closer to nodes)
|
| 547 |
+
offset = 0.00
|
| 548 |
+
start_x = x0
|
| 549 |
+
start_y = y0
|
| 550 |
+
end_x = x1
|
| 551 |
+
end_y = y1
|
| 552 |
+
|
| 553 |
+
# # Calculate new start and end points
|
| 554 |
+
# if start_x > end_x:
|
| 555 |
+
# start_x = x0 - offset
|
| 556 |
+
# end_x = x0 + offset
|
| 557 |
+
# else:
|
| 558 |
+
# start_x = x0 + offset
|
| 559 |
+
# end_x = x1 - offset
|
| 560 |
+
# if start_y > end_y:
|
| 561 |
+
# start_y = y0 - offset
|
| 562 |
+
# end_y = y1 + offset
|
| 563 |
+
# else:
|
| 564 |
+
# start_y = y0 + offset
|
| 565 |
+
# end_y = y1 - offset
|
| 566 |
+
|
| 567 |
+
edge_trace = go.Scatter(
|
| 568 |
+
x=[start_x, end_x, None],
|
| 569 |
+
y=[start_y, end_y, None],
|
| 570 |
+
line=dict(width=weight * 20, color='#888'), # Multiply weight by 20 for better visibility
|
| 571 |
+
hovertext="weight: " + str(round(weight, 3)), # Set the hover text to the edge weight
|
| 572 |
+
mode='lines',
|
| 573 |
+
line_shape='spline',
|
| 574 |
+
opacity = 0.5,
|
| 575 |
+
)
|
| 576 |
+
edge_traces.append(edge_trace)
|
| 577 |
+
|
| 578 |
+
annotation = dict(
|
| 579 |
+
ax=start_x,
|
| 580 |
+
ay=start_y,
|
| 581 |
+
x=end_x,
|
| 582 |
+
y=end_y,
|
| 583 |
+
xref='x',
|
| 584 |
+
yref='y',
|
| 585 |
+
axref='x',
|
| 586 |
+
ayref='y',
|
| 587 |
+
showarrow=True,
|
| 588 |
+
arrowhead=4,
|
| 589 |
+
arrowsize=4, #max(min(weight * 3, 0.3), 2), # Reduced from 30 to 10
|
| 590 |
+
arrowwidth=1, # Reduced from 30 to 2
|
| 591 |
+
arrowcolor='#999',
|
| 592 |
+
opacity = 1,
|
| 593 |
+
)
|
| 594 |
+
annotations.append(annotation)
|
| 595 |
+
|
| 596 |
+
annotation_trace = go.Scatter(x=[], y=[], mode='markers', hoverinfo='none', marker=dict(opacity=0))
|
| 597 |
+
|
| 598 |
+
# Create the figure
|
| 599 |
+
fig = go.Figure(data=[annotation_trace, *edge_traces, node_trace],
|
| 600 |
+
layout=go.Layout(
|
| 601 |
+
showlegend=False,
|
| 602 |
+
hovermode='closest',
|
| 603 |
+
margin=dict(b=20,l=5,r=5,t=40),
|
| 604 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 605 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)),
|
| 606 |
+
)
|
| 607 |
+
fig.update_xaxes(showline=False, linewidth=0, gridcolor='white')
|
| 608 |
+
fig.update_yaxes(showline=False, linewidth=0, gridcolor='white')
|
| 609 |
+
fig.update_layout(
|
| 610 |
+
plot_bgcolor='white',
|
| 611 |
+
annotations=annotations,
|
| 612 |
+
)
|
| 613 |
+
return fig
|
| 614 |
|
| 615 |
# Modify the main interface function
|
| 616 |
def create_interface():
|
|
|
|
| 769 |
def search_feature_labels(search_text):
|
| 770 |
if not search_text:
|
| 771 |
return gr.CheckboxGroup(choices=[])
|
| 772 |
+
matches = [f for f in subject_data[current_subject]['feature_analysis'] if search_text.lower() in f['label'].lower()]
|
| 773 |
+
matches = sorted(matches, key=lambda x: x['pearson_correlation'], reverse=True)
|
| 774 |
+
matches = [f"{f['label']} ({f['index']})" for f in matches]
|
| 775 |
+
|
| 776 |
return gr.CheckboxGroup(choices=matches[:10])
|
| 777 |
|
| 778 |
feature_search.change(search_feature_labels, inputs=[feature_search], outputs=[feature_matches])
|
|
|
|
| 855 |
wrap=True
|
| 856 |
)
|
| 857 |
|
| 858 |
+
gr.Markdown("## Feature Splitting")
|
| 859 |
with gr.Row():
|
| 860 |
with gr.Column(scale=1):
|
| 861 |
+
gr.Markdown("### Best Match in SAE16")
|
| 862 |
+
nn_16_table = gr.Dataframe(
|
| 863 |
+
headers=["Feature", "Cosine Similarity"],
|
| 864 |
interactive=False
|
| 865 |
)
|
| 866 |
with gr.Column(scale=1):
|
| 867 |
+
gr.Markdown("### Best Match in SAE32")
|
| 868 |
+
nn_32_table = gr.Dataframe(
|
| 869 |
+
headers=["Feature", "Cosine Similarity"],
|
| 870 |
interactive=False
|
| 871 |
)
|
| 872 |
+
|
| 873 |
with gr.Row():
|
| 874 |
with gr.Column(scale=1):
|
| 875 |
+
gr.Markdown("## Top Co-occurring Features")
|
| 876 |
co_occurring_features = gr.Dataframe(
|
| 877 |
headers=["Feature", "Co-occurrences"],
|
| 878 |
interactive=False
|
|
|
|
| 881 |
gr.Markdown(f"## Activation Value Distribution")
|
| 882 |
activation_dist = gr.Plot()
|
| 883 |
|
| 884 |
+
gr.Markdown("## Correlated Features")
|
| 885 |
+
with gr.Row():
|
| 886 |
+
with gr.Column(scale=1):
|
| 887 |
+
gr.Markdown("### Top Correlated Features")
|
| 888 |
+
top_correlated = gr.Dataframe(
|
| 889 |
+
headers=["Feature", "Cosine similarity"],
|
| 890 |
+
interactive=False
|
| 891 |
+
)
|
| 892 |
+
with gr.Column(scale=1):
|
| 893 |
+
gr.Markdown("### Bottom Correlated Features")
|
| 894 |
+
bottom_correlated = gr.Dataframe(
|
| 895 |
+
headers=["Feature", "Cosine similarity"],
|
| 896 |
+
interactive=False
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
|
| 902 |
def search_feature_labels(search_text, current_subject):
|
| 903 |
if not search_text:
|
| 904 |
return gr.CheckboxGroup(choices=[])
|
| 905 |
+
matches = [f for f in subject_data[current_subject]['feature_analysis'] if search_text.lower() in f['label'].lower()]
|
| 906 |
+
matches = sorted(matches, key=lambda x: x['pearson_correlation'], reverse=True)
|
| 907 |
+
matches = [f"{f['label']} ({f['index']})" for f in matches]
|
| 908 |
+
|
| 909 |
return gr.CheckboxGroup(choices=matches[:10])
|
| 910 |
|
| 911 |
feature_search.change(search_feature_labels, inputs=[feature_search, subject], outputs=[feature_matches])
|
|
|
|
| 916 |
|
| 917 |
# Extract the feature index from the selected feature string
|
| 918 |
feature_index = int(selected_features[0].split('(')[-1].strip(')'))
|
| 919 |
+
feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, nn_16, nn_32 = visualize_feature(current_subject, feature_index)
|
| 920 |
|
| 921 |
# Return the visualization results along with empty values for search box and checkbox
|
| 922 |
+
return feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, "", [], nn_16, nn_32
|
| 923 |
|
| 924 |
visualize_button.click(
|
| 925 |
on_visualize,
|
| 926 |
inputs=[feature_matches, subject],
|
| 927 |
+
outputs=[feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, feature_search, feature_matches, nn_16_table, nn_32_table]
|
| 928 |
)
|
| 929 |
|
| 930 |
with gr.Tab("Feature Families"):
|
|
|
|
| 935 |
family_matches = gr.CheckboxGroup(label="Matching Feature Families", choices=[])
|
| 936 |
visualize_family_button = gr.Button("Visualize Feature Family")
|
| 937 |
|
| 938 |
+
|
| 939 |
family_dataframe = gr.Dataframe(
|
| 940 |
+
headers=["Feature", "Parent Co-Occurrence", "F1 Score", "Pearson"],
|
| 941 |
+
datatype=["markdown", "number", "number", "number"],
|
| 942 |
label="Family and Child Features"
|
| 943 |
)
|
| 944 |
|
| 945 |
+
gr.Markdown("# Family Graph")
|
| 946 |
+
graph_plot = gr.Plot(label="Directed Graph")
|
| 947 |
+
|
| 948 |
+
# family_info = gr.Markdown()
|
| 949 |
|
| 950 |
def search_feature_families(search_text, current_subject):
|
| 951 |
family_analysis = subject_data[current_subject]['family_analysis']
|
| 952 |
if not search_text:
|
| 953 |
return gr.CheckboxGroup(choices=[])
|
| 954 |
+
matches = [family for family in family_analysis if search_text.lower() in family['superfeature'].lower()]
|
| 955 |
+
matches = sorted(matches, key=lambda x: x['family_pearson'], reverse=True)
|
| 956 |
+
matches = [family['superfeature'] for family in matches]
|
| 957 |
+
matches = list(dict.fromkeys(matches))
|
| 958 |
return gr.CheckboxGroup(choices=matches[:10]) # Limit to top 10 matches
|
| 959 |
|
| 960 |
def visualize_feature_family(selected_families, current_subject):
|
|
|
|
| 974 |
df_data = [
|
| 975 |
{
|
| 976 |
"Feature": f"## {family_data['superfeature']}",
|
| 977 |
+
"Parent Co-Occurrence": 1,
|
| 978 |
"F1 Score": round(family_data['family_f1'], 2),
|
| 979 |
+
"Pearson": round(family_data['family_pearson'], 4)
|
| 980 |
},
|
| 981 |
]
|
| 982 |
|
| 983 |
+
coocs = np.array(family_data['matrix'])[:, -1]
|
| 984 |
+
# print(coocs)
|
| 985 |
+
for name, cooc, f1, pearson in zip(family_data['feature_names'], coocs, family_data['feature_f1'], family_data['feature_pearson']):
|
| 986 |
df_data.append({
|
| 987 |
"Feature": name,
|
| 988 |
+
"Parent Co-Occurrence": round(cooc, 2),
|
| 989 |
"F1 Score": round(f1, 2),
|
| 990 |
+
"Pearson": round(pearson, 4)
|
| 991 |
})
|
| 992 |
|
| 993 |
df = pd.DataFrame(df_data)
|
|
|
|
| 996 |
output += "## Super Reasoning\n"
|
| 997 |
output += f"{family_data['super_reasoning']}\n\n"
|
| 998 |
|
| 999 |
+
graph = create_interactive_directed_graph(family_data)
|
| 1000 |
+
|
| 1001 |
+
#return output, df, "", [], graph # Return empty string for search box and empty list for checkbox
|
| 1002 |
+
return df, "", [], graph
|
| 1003 |
|
| 1004 |
family_search.change(search_feature_families, inputs=[family_search, subject], outputs=[family_matches])
|
| 1005 |
visualize_family_button.click(
|
| 1006 |
visualize_feature_family,
|
| 1007 |
inputs=[family_matches, subject],
|
| 1008 |
+
#outputs=[family_info, family_dataframe, family_search, family_matches, graph_plot]
|
| 1009 |
+
outputs=[family_dataframe, family_search, family_matches, graph_plot]
|
| 1010 |
)
|
| 1011 |
|
| 1012 |
|