import streamlit as st
import json
import numpy as np
from sentence_transformers import SentenceTransformer, util
import os
import boto3
import psycopg2
from psycopg2.extensions import connection
import torch
import re
import requests
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from dotenv import load_dotenv
from latex_clean import clean_latex_for_display

# Config
load_dotenv()
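
# Expected environment variables (loaded from .env via python-dotenv):
#   AWS_REGION     - region for the Secrets Manager client
#   RDS_SECRET_ARN - ARN of the secret holding the database credentials
#   RDS_HOST       - database host (falls back to the secret's "host" key)
#   RDS_DB_NAME    - database name (falls back to the secret's "dbname" key)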
def get_rds_connection() -> connection:
    region = os.getenv("AWS_REGION")
    secret_arn = os.getenv("RDS_SECRET_ARN")
    host = os.getenv("RDS_HOST")
    dbname = os.getenv("RDS_DB_NAME")
    sm = boto3.client("secretsmanager", region_name=region)
    secret_value = sm.get_secret_value(SecretId=secret_arn)
    secret_dict = json.loads(secret_value["SecretString"])
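    # The SecretString is assumed to be JSON of the form
    #   {"username": ..., "password": ..., "host": ..., "port": 5432, "dbname": ...}
    # where "username" and "password" are required and the rest serve as
    # fallbacks for the corresponding environment variables.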
    conn = psycopg2.connect(
        host=host or secret_dict.get("host"),
        port=int(secret_dict.get("port", 5432)),
        dbname=dbname or secret_dict.get("dbname"),
        user=secret_dict["username"],
        password=secret_dict["password"],
        sslmode="require",
    )
    return conn

AVAILABLE_TAGS = {
    "arXiv": [
        "math.AC", "math.AG", "math.AP", "math.AT", "math.CA", "math.CO",
        "math.CT", "math.CV", "math.DG", "math.DS", "math.FA", "math.GM",
        "math.GN", "math.GR", "math.GT", "math.HO", "math.IT", "math.KT",
        "math.LO", "math.MG", "math.MP", "math.NA", "math.NT", "math.OA",
        "math.OC", "math.PR", "math.QA", "math.RA", "math.RT", "math.SG",
        "math.SP", "math.ST", "Statistics Theory"
    ],
    "Stacks Project": [
        "Sets", "Schemes", "Algebraic Stacks", "Étale Cohomology"
    ]
}
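# NOTE: AVAILABLE_TAGS is kept as a reference list of known categories; the
# sidebar below derives its tag options from the loaded data instead.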

ALLOWED_TYPES = [
    "theorem", "lemma", "proposition", "corollary", "definition", "remark", "assumption"
]

ARXIV_ID_RE = re.compile(
    r'arxiv\.org/(?:abs|pdf)/((?:\d{4}\.\d{4,5}|[a-z\-]+/\d{7}))(?:v\d+)?',
    re.IGNORECASE
)
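# Captures both new-style ids (e.g. "arxiv.org/abs/2101.00001") and legacy ids
# (e.g. "arxiv.org/abs/math/0211159"); a trailing version suffix like "v2" is
# excluded from the capture. The example ids here are purely illustrative.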

# Load the Embedding Model
@st.cache_resource
def load_model():
    """
    Loads the specialized math embedding model from Hugging Face.
    Cached with st.cache_resource so the model is loaded once,
    not on every Streamlit rerun.
    """
    try:
        model = SentenceTransformer('math-similarity/Bert-MLM_arXiv-MP-class_zbMath')
        return model
    except Exception as e:
        st.error(f"Error loading the embedding model: {e}")
        return None

# Load Data from RDS
def load_papers_from_rds():
    """
    Loads theorem data from the RDS database and prepares it for embedding.
    Returns a list of theorem dictionaries with all necessary fields.
    """
    try:
        conn = get_rds_connection()
        cur = conn.cursor()
        # Fetch all papers with their theorems and embeddings
        cur.execute("""
            SELECT
                tm.paper_id,
                tm.title,
                tm.authors,
                tm.link,
                tm.last_updated,
                tm.summary,
                tm.journal_ref,
                tm.primary_category,
                tm.categories,
                tm.global_notations,
                tm.global_definitions,
                tm.global_assumptions,
                te.theorem_name,
                te.theorem_slogan,
                te.theorem_body,
                te.embedding
            FROM theorem_metadata tm
            JOIN theorem_embedding te ON tm.paper_id = te.paper_id
            ORDER BY tm.paper_id, te.theorem_name;
        """)
        rows = cur.fetchall()
        cur.close()
        conn.close()

        # Defined once, outside the row loop; reuses ALLOWED_TYPES rather than
        # duplicating the list inline.
        def infer_type(name: str) -> str:
            if not name:
                return "theorem"
            lower = name.lower()
            for t in ALLOWED_TYPES:
                if t in lower:
                    return t
            return "theorem"

        all_theorems_data = []
        for row in rows:
            (paper_id, title, authors, link, last_updated, summary,
             journal_ref, primary_category, categories,
             global_notations, global_definitions, global_assumptions,
             theorem_name, theorem_slogan, theorem_body, embedding) = row
            # Build global context
            global_context_parts = []
            if global_notations:
                global_context_parts.append(f"**Global Notations:**\n{global_notations}")
            if global_definitions:
                global_context_parts.append(f"**Global Definitions:**\n{global_definitions}")
            if global_assumptions:
                global_context_parts.append(f"**Global Assumptions:**\n{global_assumptions}")
            global_context = "\n\n".join(global_context_parts)
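
            # Depending on the column type and driver, the stored embedding may
            # arrive as a JSON string, a Python list, or a numpy array; all
            # three cases are normalized to float32 below.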
            # Convert embedding to a numpy float array
            if isinstance(embedding, str):
                embedding = json.loads(embedding)
            if isinstance(embedding, list):
                embedding = np.array(embedding, dtype=np.float32)
            elif isinstance(embedding, np.ndarray):
                embedding = embedding.astype(np.float32)
            # Determine source from URL
            link_str = link or ""
            if link_str.startswith("http://arxiv.org") or link_str.startswith("https://arxiv.org"):
                source = "arXiv"
            else:
                source = "Stacks Project"
            # Determine type from name
            inferred_type = infer_type(theorem_name or "")
            all_theorems_data.append({
                "paper_id": paper_id,
                "authors": authors,
                "paper_title": title,
                "paper_url": link,
                "year": last_updated.year if last_updated else None,
                "primary_category": primary_category,
                "source": source,
                "type": inferred_type,
                "journal_published": bool(journal_ref),
                "citations": None,  # filled in lazily by add_citations()
                "theorem_name": theorem_name,
                "theorem_slogan": theorem_slogan,
                "theorem_body": theorem_body,
                "global_context": global_context,
                "stored_embedding": embedding,
            })
        return all_theorems_data
    except Exception as e:
        st.error(f"Error loading data from RDS: {e}")
        return []

# Cache citation lookups for 24 hours so repeated queries do not re-hit the APIs.
@st.cache_data(ttl=60 * 60 * 24, show_spinner=False)
def fetch_citations(paper_url: str, title: str) -> int | None:
| """ | |
| Returns citation count if found, else None. | |
| Tries the following sources in order: | |
| 1) OpenAlex by arXiv id | |
| 2) Semantic Scholar by arXiv id | |
| 3) Semantic Scholar by title | |
| """ | |
| arx_id = None | |
| if paper_url: | |
| m = ARXIV_ID_RE.search(paper_url) | |
| if m: | |
| arx_id = m.group(1) | |
| # OpenAlex by arXiv id | |
| if arx_id: | |
| try: | |
| r = requests.get(f"https://api.openalex.org/works/arXiv:{arx_id}", timeout=10) | |
| if r.ok: | |
| data = r.json() | |
| c = data.get("cited_by_count") | |
| if isinstance(c, int): | |
| return c | |
| except Exception: | |
| pass | |
| # Semantic Scholar by arXiv id | |
| if arx_id: | |
| try: | |
| r = requests.get( | |
| f"https://api.semanticscholar.org/graph/v1/paper/arXiv:{arx_id}", | |
| params={"fields": "citationCount"}, | |
| timeout=10 | |
| ) | |
| if r.ok: | |
| j = r.json() | |
| c = j.get("citationCount") | |
| if isinstance(c, int): | |
| return c | |
| except Exception: | |
| pass | |
| # Fallback: Semantic Scholar by title | |
| if title: | |
| try: | |
| r = requests.get( | |
| "https://api.semanticscholar.org/graph/v1/paper/search", | |
| params={"query": title, "limit": 1, "fields": "title,citationCount"}, | |
| timeout=10 | |
| ) | |
| if r.ok: | |
| j = r.json() | |
| if j.get("data"): | |
| c = j["data"][0].get("citationCount") | |
| if isinstance(c, int): | |
| return c | |
| except Exception: | |
| pass | |
| return None | |

def add_citations(candidates: list[dict], max_workers: int = 6) -> None:
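    """
    Fill in missing citation counts for arXiv candidates, mutating the dicts
    in place. Failed lookups leave "citations" as None, which the
    "include unknown citation counts" filter handles downstream.
    """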
    # Select targets with missing citations
    targets = [
        it for it in candidates
        if it.get("source") == "arXiv" and (it.get("citations") in (None, 0))
    ]
    if not targets:
        return
    with ThreadPoolExecutor(max_workers=max_workers) as exe:
        fut2item = {
            exe.submit(fetch_citations, it.get("paper_url"), it.get("paper_title")): it
            for it in targets
        }
        for fut in as_completed(fut2item):
            it = fut2item[fut]
            try:
                c = fut.result()
                if c is not None:
                    it["citations"] = c
            except Exception:
                pass

# --- Search and Display ---
def search_and_display_with_filters(query, model, theorems_data, embeddings_db, filters):
    if not query:
        st.info("Please enter a search query.")
        return
    if not filters['sources']:
        st.warning("Please select at least one source.")
        return
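
    # Retrieve-then-filter: rank the whole corpus by cosine similarity against
    # the query embedding, take an oversized candidate pool, then apply the
    # metadata filters and truncate to the requested top_k.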
    query_embedding = model.encode(query, convert_to_tensor=True)
    cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]

    # Get a larger pool to filter from, keeping each item's corpus index so its
    # similarity score can be looked up directly (list.index() on dicts is
    # O(n) per result and can match the wrong entry if two rows are identical).
    top_k_pool = min(200, len(theorems_data))
    top_indices = torch.topk(cosine_scores, k=top_k_pool, sorted=True).indices
    pool = [(int(i.item()), theorems_data[int(i.item())]) for i in top_indices]
    add_citations([item for _, item in pool])
    results = []
    low, high = filters['citation_range']

    # Filter results
    for idx, item in pool:
        type_match = (not filters['types']) or (item.get('type', '').lower() in filters['types'])
        tag_match = (not filters['tags']) or (item.get('primary_category') in filters['tags'])
        author_match = (not filters['authors']) or any(a in (item.get('authors') or []) for a in filters['authors'])
        source_match = item.get('source') in filters['sources']

        # Citations, year, and journal status are only meaningful for arXiv
        # entries, so those filters are applied to arXiv items alone.
        citation_match = True
        if item.get('source') == 'arXiv':
            cit = item.get('citations')
            if cit is None:
                if not filters['include_unknown_citations']:
                    continue
            else:
                citation_match = (low <= int(cit) <= high)

        year_match = True
        if filters['year_range'] and item.get('source') == 'arXiv':
            y = item.get('year') or 0
            yr0, yr1 = filters['year_range']
            year_match = (yr0 <= y <= yr1)

        journal_match = True
        if item.get('source') == 'arXiv':
            status = filters['journal_status']
            jp = bool(item.get('journal_published'))
            if status == "Journal Article":
                journal_match = jp
            elif status == "Preprint Only":
                journal_match = not jp

        if all([type_match, tag_match, author_match, source_match, citation_match, year_match, journal_match]):
            results.append({"info": item, "similarity": float(cosine_scores[idx].item())})
            if len(results) >= filters['top_k']:
                break

    st.subheader(f"Found {len(results)} Matching Results")
    if not results:
        st.warning("No results found for the current filters.")
        return

    for i, r in enumerate(results):
        info = r["info"]
        expander_title = f"**Result {i+1} | Similarity: {r['similarity']:.4f} | Type: {info.get('type','').title()}**"
        with st.expander(expander_title):
            st.markdown(f"**Paper:** *{info.get('paper_title','Unknown')}*")
            st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
            st.markdown(f"**Source:** {info.get('source')} ([Link]({info.get('paper_url')}))")
            cit = info.get("citations")
            cit_str = "Unknown" if cit is None else str(cit)
            st.markdown(
                f"**Math Tag:** `{info.get('primary_category')}` | "
                f"**Citations:** {cit_str} | "
                f"**Year:** {info.get('year', 'N/A')}"
            )
            st.markdown("---")
            if info.get("theorem_slogan"):
                st.markdown(f"**Slogan:** {info['theorem_slogan']}\n")
            if info.get("global_context"):
                cleaned_ctx = clean_latex_for_display(info["global_context"])
                st.markdown("> " + cleaned_ctx.replace("\n", "\n> "))
            cleaned_content = clean_latex_for_display(info['theorem_body'])
            st.markdown("**Theorem Body:**")
            st.markdown(cleaned_content)

# --- Main App Interface ---
st.set_page_config(page_title="Theorem Search Demo", layout="wide")
st.title("📚 Semantic Theorem Search")
st.write("This demo uses a specialized mathematical language model to find theorems semantically similar to your query.")

model = load_model()
theorems_data = load_papers_from_rds()

if model and theorems_data:
    with st.spinner("Preparing embeddings from database..."):
        corpus_embeddings = np.array([item['stored_embedding'] for item in theorems_data])
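        # Embeddings were computed offline and stored in RDS, so nothing is
        # encoded at startup; util.cos_sim consumes this numpy matrix directly.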
    st.success(f"Successfully loaded {len(theorems_data)} theorems from arXiv and the Stacks Project. Ready to search!")

    # --- Sidebar filters ---
    with st.sidebar:
        st.header("Search Filters")
        all_sources = ['arXiv', 'Stacks Project']
        selected_sources = st.multiselect(
            "Filter by Source(s):",
            all_sources,
            default=all_sources[:1] if all_sources else [],
            help="Select one or more sources to reveal more filters."
        )

        # Defaults, so the filters dict below is well-defined even when no
        # source is selected (include_unknown_citations previously had no
        # default and raised a NameError in that case).
        selected_authors, selected_types, selected_tags = [], [], []
        year_range, journal_status = None, "All"
        citation_range = (0, 1000)
        include_unknown_citations = True
        top_k_results = 5
        if selected_sources:
            st.write("---")
            selected_types = st.multiselect("Filter by Type:", ALLOWED_TYPES)
            all_authors = sorted(set(a for it in theorems_data for a in (it.get('authors') or [])))
            selected_authors = st.multiselect("Filter by Author(s):", all_authors)

            # Tags come from the union of categories across the selected sources
            tags_per_source = defaultdict(set)
            for it in theorems_data:
                tags_per_source[it['source']].add(it.get('primary_category'))
            union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
            selected_tags = st.multiselect("Filter by Math Tag/Category:", union_tags)

            if 'arXiv' in selected_sources:
                year_range = st.slider("Filter by Year (for arXiv):", 1991, 2025, (1991, 2025))
                journal_status = st.radio("Publication Status (for arXiv):", ["All", "Journal Article", "Preprint Only"], horizontal=True)

            citation_range = st.slider("Filter by Citations:", 0, 1000, (0, 1000))
            include_unknown_citations = st.checkbox(
                "Include entries with unknown citation counts",
                value=True,
                help="If unchecked, results with unknown citation counts are excluded."
            )
            top_k_results = st.slider("Number of results to display:", 1, 20, 5)

    filters = {
        "authors": selected_authors,
        "types": [t.lower() for t in selected_types],
        "tags": selected_tags,
        "sources": selected_sources,
        "year_range": year_range,
        "journal_status": journal_status,
        "citation_range": citation_range,
        "include_unknown_citations": include_unknown_citations,
        "top_k": top_k_results
    }

    user_query = st.text_input("Enter your query:", "")
    if st.button("Search") or user_query:
        search_and_display_with_filters(user_query, model, theorems_data, corpus_embeddings, filters)
else:
    st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")