# theorem-search/src/streamlit_app.py
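# Streamlit demo: semantic theorem search over arXiv and the Stacks Project,
# with metadata filtering and citation counts.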
import streamlit as st
import json
import numpy as np
from sentence_transformers import SentenceTransformer, util
import os
import boto3
import psycopg2
from psycopg2.extensions import connection
import torch
import re
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from collections import defaultdict
from dotenv import load_dotenv
from latex_clean import clean_latex_for_display
# Config
load_dotenv()
def get_rds_connection() -> connection:
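    """
    Opens a PostgreSQL connection using credentials from AWS Secrets Manager.
    Host and database name from the environment take precedence over the
    values stored in the secret.
    """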
region = os.getenv("AWS_REGION")
secret_arn = os.getenv("RDS_SECRET_ARN")
host = os.getenv("RDS_HOST")
dbname = os.getenv("RDS_DB_NAME")
sm = boto3.client("secretsmanager", region_name=region)
secret_value = sm.get_secret_value(SecretId=secret_arn)
secret_dict = json.loads(secret_value["SecretString"])
conn = psycopg2.connect(
host=host or secret_dict.get("host"),
port=int(secret_dict.get("port", 5432)),
dbname=dbname or secret_dict.get("dbname"),
user=secret_dict["username"],
password=secret_dict["password"],
sslmode="require",
)
return conn
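# Category tags offered per source (arXiv math subject classes; Stacks Project chapters).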
AVAILABLE_TAGS = {
"arXiv": [
"math.AC", "math.AG", "math.AP", "math.AT", "math.CA", "math.CO",
"math.CT", "math.CV", "math.DG", "math.DS", "math.FA", "math.GM",
"math.GN", "math.GR", "math.GT", "math.HO", "math.IT", "math.KT",
"math.LO", "math.MG", "math.MP", "math.NA", "math.NT", "math.OA",
"math.OC", "math.PR", "math.QA", "math.RA", "math.RT", "math.SG",
"math.SP", "math.ST", "Statistics Theory"
],
"Stacks Project": [
"Sets", "Schemes", "Algebraic Stacks", "Étale Cohomology"
]
}
ALLOWED_TYPES = [
"theorem", "lemma", "proposition", "corollary", "definition", "remark", "assumption"
]
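# Extracts the arXiv id from abs/pdf URLs, e.g. ".../abs/2101.01234v2" -> "2101.01234";
# legacy ids such as "math/9901001" match as well, and version suffixes are dropped.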
ARXIV_ID_RE = re.compile(
r'arxiv\.org/(?:abs|pdf)/((?:\d{4}\.\d{4,5}|[a-z\-]+/\d{7}))(?:v\d+)?',
re.IGNORECASE
)
# Load the Embedding Model
@st.cache_resource
def load_model():
"""
Loads the specialized math embedding model from Hugging Face.
"""
try:
model = SentenceTransformer('math-similarity/Bert-MLM_arXiv-MP-class_zbMath')
return model
except Exception as e:
st.error(f"Error loading the embedding model: {e}")
return None
# Load Data from RDS
@st.cache_data
def load_papers_from_rds():
"""
Loads theorem data from the RDS database and prepares it for embedding.
Returns a list of theorem dictionaries with all necessary fields.
"""
try:
conn = get_rds_connection()
cur = conn.cursor()
# Fetch all papers with their theorems and embeddings
cur.execute("""
SELECT
tm.paper_id,
tm.title,
tm.authors,
tm.link,
tm.last_updated,
tm.summary,
tm.journal_ref,
tm.primary_category,
tm.categories,
tm.global_notations,
tm.global_definitions,
tm.global_assumptions,
te.theorem_name,
te.theorem_slogan,
te.theorem_body,
te.embedding
FROM theorem_metadata tm
JOIN theorem_embedding te ON tm.paper_id = te.paper_id
ORDER BY tm.paper_id, te.theorem_name;
""")
rows = cur.fetchall()
cur.close()
conn.close()
all_theorems_data = []
for row in rows:
(paper_id, title, authors, link, last_updated, summary,
journal_ref, primary_category, categories,
global_notations, global_definitions, global_assumptions,
theorem_name, theorem_slogan, theorem_body, embedding) = row
# Build global context
global_context_parts = []
if global_notations:
global_context_parts.append(f"**Global Notations:**\n{global_notations}")
if global_definitions:
global_context_parts.append(f"**Global Definitions:**\n{global_definitions}")
if global_assumptions:
global_context_parts.append(f"**Global Assumptions:**\n{global_assumptions}")
global_context = "\n\n".join(global_context_parts)
# Convert embedding to a numpy float array
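            # (pgvector / JSON columns may come back as strings or lists,
            # depending on the driver, so all three cases are handled)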
if isinstance(embedding, str):
embedding = json.loads(embedding)
if isinstance(embedding, list):
embedding = np.array(embedding, dtype=np.float32)
elif isinstance(embedding, np.ndarray):
embedding = embedding.astype(np.float32)
# Determine source from url
link_str = link or ""
if link_str.startswith("http://arxiv.org") or link_str.startswith("https://arxiv.org"):
source = "arXiv"
else:
source = "Stacks Project"
            # Infer the result type from the theorem name (defaults to "theorem")
            lower_name = (theorem_name or "").lower()
            inferred_type = next((t for t in ALLOWED_TYPES if t in lower_name), "theorem")
            all_theorems_data.append({
                "paper_id": paper_id,
                "authors": authors,
                "paper_title": title,
                "paper_url": link,
                "year": last_updated.year if last_updated else None,
"primary_category": primary_category,
"source": source,
"type": inferred_type,
"journal_published": bool(journal_ref),
"citations": None,
"theorem_name": theorem_name,
"theorem_slogan": theorem_slogan,
"theorem_body": theorem_body,
"global_context": global_context,
"stored_embedding": embedding,
})
return all_theorems_data
except Exception as e:
st.error(f"Error loading data from RDS: {e}")
return []
@st.cache_data(ttl=60*60*24) # cache for 24 hours
def fetch_citations(paper_url: str, title: str) -> int | None:
"""
Returns citation count if found, else None.
Tries the following sources in order:
1) OpenAlex by arXiv id
2) Semantic Scholar by arXiv id
3) Semantic Scholar by title
"""
arx_id = None
if paper_url:
m = ARXIV_ID_RE.search(paper_url)
if m:
arx_id = m.group(1)
# OpenAlex by arXiv id
if arx_id:
try:
r = requests.get(f"https://api.openalex.org/works/arXiv:{arx_id}", timeout=10)
if r.ok:
data = r.json()
c = data.get("cited_by_count")
if isinstance(c, int):
return c
except Exception:
pass
# Semantic Scholar by arXiv id
if arx_id:
try:
r = requests.get(
f"https://api.semanticscholar.org/graph/v1/paper/arXiv:{arx_id}",
params={"fields": "citationCount"},
timeout=10
)
if r.ok:
j = r.json()
c = j.get("citationCount")
if isinstance(c, int):
return c
except Exception:
pass
# Fallback: Semantic Scholar by title
if title:
try:
r = requests.get(
"https://api.semanticscholar.org/graph/v1/paper/search",
params={"query": title, "limit": 1, "fields": "title,citationCount"},
timeout=10
)
if r.ok:
j = r.json()
if j.get("data"):
c = j["data"][0].get("citationCount")
if isinstance(c, int):
return c
except Exception:
pass
return None
def add_citations(candidates: list[dict], max_workers: int = 6) -> None:
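    """
    Fills in missing citation counts on `candidates` in place, fetching them
    concurrently for arXiv entries; individual fetch failures are ignored.
    """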
# Select targets with missing citations
targets = [
it for it in candidates
if it.get("source") == "arXiv" and (it.get("citations") in (None, 0))
]
if not targets:
return
with ThreadPoolExecutor(max_workers=max_workers) as exe:
fut2item = {
exe.submit(fetch_citations, it.get("paper_url"), it.get("paper_title")): it
for it in targets
}
for fut in as_completed(fut2item):
it = fut2item[fut]
try:
c = fut.result()
if c is not None:
it["citations"] = c
except Exception:
pass
# --- Search and Display ---
def search_and_display_with_filters(query, model, theorems_data, embeddings_db, filters):
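    """
    Ranks every theorem against `query` by cosine similarity, applies the
    sidebar filters to a pool of top candidates, and renders up to
    filters['top_k'] matches as expandable result cards.
    """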
if not query:
st.info("Please enter a search query.")
return
if not filters['sources']:
st.warning("Please select at least one source.")
return
query_embedding = model.encode(query, convert_to_tensor=True)
cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]
    # Rank a larger pool of candidates to filter from, keeping each candidate's
    # index so its similarity score can be looked up directly
    top_k_pool = min(200, len(theorems_data))
    top_indices = torch.topk(cosine_scores, k=top_k_pool, sorted=True).indices
    pool = [(int(i.item()), theorems_data[int(i.item())]) for i in top_indices]
    add_citations([item for _, item in pool])
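    # Citation counts are fetched for the whole candidate pool up front so the
    # citation-range filter below can see a value for every arXiv candidate.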
results = []
low, high = filters['citation_range']
# Filter results
    for idx, item in pool:
type_match = (not filters['types']) or (item.get('type','').lower() in filters['types'])
tag_match = (not filters['tags']) or (item.get('primary_category') in filters['tags'])
author_match = (not filters['authors']) or any(a in (item.get('authors') or []) for a in filters['authors'])
source_match = item.get('source') in filters['sources']
# Citations & year & journal only meaningful for arXiv
cit = item.get('citations')
if cit is None:
if not filters['include_unknown_citations']:
continue
citation_match = True
else:
citation_match = (low <= int(cit) <= high)
year_match = True
if filters['year_range'] and item.get('source') == 'arXiv':
y = item.get('year') or 0
yr0, yr1 = filters['year_range']
year_match = (yr0 <= y <= yr1)
journal_match = True
if item.get('source') == 'arXiv':
status = filters['journal_status']
jp = bool(item.get('journal_published'))
if status == "Journal Article":
journal_match = jp
elif status == "Preprint Only":
journal_match = not jp
        if all([type_match, tag_match, author_match, source_match, citation_match, year_match, journal_match]):
            results.append({"info": item, "similarity": float(cosine_scores[idx].item())})
if len(results) >= filters['top_k']:
break
st.subheader(f"Found {len(results)} Matching Results")
if not results:
st.warning("No results found for the current filters.")
return
for i, r in enumerate(results):
info = r["info"]
expander_title = f"**Result {i+1} | Similarity: {r['similarity']:.4f} | Type: {info.get('type','').title()}**"
with st.expander(expander_title):
st.markdown(f"**Paper:** *{info.get('paper_title','Unknown')}*")
st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
st.markdown(f"**Source:** {info.get('source')} ([Link]({info.get('paper_url')}))")
cit = info.get("citations")
cit_str = "Unknown" if cit is None else str(cit)
st.markdown(
f"**Math Tag:** `{info.get('primary_category')}` | "
f"**Citations:** {cit_str} | "
f"**Year:** {info.get('year', 'N/A')}"
)
st.markdown("---")
if info.get("theorem_slogan"):
st.markdown(f"**Slogan:** {info['theorem_slogan']}\n")
if info.get("global_context"):
cleaned_ctx = clean_latex_for_display(info["global_context"])
                st.markdown("> " + cleaned_ctx.replace("\n", "\n> "))
            cleaned_content = clean_latex_for_display(info.get('theorem_body') or "")
st.markdown("**Theorem Body:**")
st.markdown(cleaned_content)
# --- Main App Interface ---
st.set_page_config(page_title="Theorem Search Demo", layout="wide")
st.title("📚 Semantic Theorem Search")
st.write("This demo uses a specialized mathematical language model to find theorems semantically similar to your query.")
model = load_model()
theorems_data = load_papers_from_rds()
if model and theorems_data:
with st.spinner("Preparing embeddings from database..."):
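        # Stack the precomputed per-theorem embeddings into a single matrix;
        # only the query is encoded at search time.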
corpus_embeddings = np.array([item['stored_embedding'] for item in theorems_data])
st.success(f"Successfully loaded {len(theorems_data)} theorems from arXiv and the Stacks Project. Ready to search!")
# --- Sidebar filters ---
with st.sidebar:
st.header("Search Filters")
all_sources = ['arXiv', 'Stacks Project']
selected_sources = st.multiselect(
"Filter by Source(s):",
all_sources,
            default=all_sources[:1],
help="Select one or more sources to reveal more filters."
)
        selected_authors, selected_types, selected_tags = [], [], []
        year_range, journal_status = None, "All"
        citation_range = (0, 1000)
        include_unknown_citations = True  # default in case the arXiv controls are not shown
        top_k_results = 5
if selected_sources:
st.write("---")
selected_types = st.multiselect("Filter by Type:", ALLOWED_TYPES)
            all_authors = sorted({a for it in theorems_data for a in (it.get('authors') or [])})
selected_authors = st.multiselect("Filter by Author(s):", all_authors)
            # Tags come from the union of categories per selected source
tags_per_source = defaultdict(set)
for it in theorems_data:
tags_per_source[it['source']].add(it.get('primary_category'))
union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
selected_tags = st.multiselect("Filter by Math Tag/Category:", union_tags)
if 'arXiv' in selected_sources:
year_range = st.slider("Filter by Year (for arXiv):", 1991, 2025, (1991, 2025))
journal_status = st.radio("Publication Status (for arXiv):", ["All", "Journal Article", "Preprint Only"], horizontal=True)
citation_range = st.slider("Filter by Citations:", 0, 1000, (0, 1000))
include_unknown_citations = st.checkbox(
"Include entries with unknown citation counts",
value=True,
help="If unchecked, results with unknown citation counts are excluded."
)
top_k_results = st.slider("Number of results to display:", 1, 20, 5)
filters = {
"authors": selected_authors,
"types": [t.lower() for t in selected_types],
"tags": selected_tags,
"sources": selected_sources,
"year_range": year_range,
"journal_status": journal_status,
"citation_range": citation_range,
"include_unknown_citations": include_unknown_citations,
"top_k": top_k_results
}
user_query = st.text_input("Enter your query:", "")
if st.button("Search") or user_query:
search_and_display_with_filters(user_query, model, theorems_data, corpus_embeddings, filters)
else:
st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")