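"""Extract arXiv IDs and Hugging Face Paper links for the models, datasets, and
papers in a Hugging Face Collection, and expose the lookup as a Gradio app."""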

import os
import platform
import re
from collections import defaultdict

import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from cachetools import TTLCache, cached
from cytoolz import groupby
from huggingface_hub import CollectionItem, get_collection, list_datasets, list_models
from tqdm.auto import tqdm

# Use the faster hf_transfer backend for Hub downloads.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Limit listings during local development on macOS because of slow internet.
is_macos = platform.system() == "Darwin"
LIMIT = 1000 if is_macos else None
CACHE_TIME = 60 * 15  # 15 minutes


@cached(cache=TTLCache(maxsize=10, ttl=CACHE_TIME))
def get_models():
    print("getting models...")
    return list(tqdm(list_models(full=True, limit=LIMIT)))


@cached(cache=TTLCache(maxsize=10, ttl=CACHE_TIME))
def get_datasets():
    print("getting datasets...")
    return list(tqdm(list_datasets(full=True, limit=LIMIT)))


get_models()  # warm up the cache
get_datasets()  # warm up the cache


def check_for_arxiv_id(model):
    """Return any 'arxiv' tags on a Hub repo, or an empty list if there are none."""
    return [tag for tag in model.tags if "arxiv" in tag] if model.tags else []


def extract_arxiv_id(input_string: str) -> str | None:
    """Pull the numeric arXiv ID out of a tag like 'arxiv:<id>'."""
    pattern = re.compile(r"\barxiv:(\d+\.\d+)\b")
    match = pattern.search(input_string)
    return match[1] if match else None
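
# A quick illustration (hypothetical tag): extract_arxiv_id("arxiv:2106.09685")
# returns "2106.09685", while a tag without the "arxiv:" prefix returns None.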


def create_model_to_arxiv_id_dict():
    """Map each Hub model ID to the arXiv IDs found in its tags."""
    models = get_models()
    model_to_arxiv_id = {}
    for model in models:
        if arxiv_papers := check_for_arxiv_id(model):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            model_to_arxiv_id[model.modelId] = clean_arxiv_ids
    return model_to_arxiv_id


def create_dataset_to_arxiv_id_dict():
    """Map each Hub dataset ID to the arXiv IDs found in its tags."""
    datasets = get_datasets()
    dataset_to_arxiv_id = {}
    for dataset in datasets:
        if arxiv_papers := check_for_arxiv_id(dataset):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            dataset_to_arxiv_id[dataset.id] = clean_arxiv_ids
    return dataset_to_arxiv_id


def get_collection_type(collection_item: CollectionItem):
    """Pluralize an item's type ('model' -> 'models') for use as a group key."""
    try:
        return f"{collection_item.item_type}s"
    except AttributeError:
        return None


def group_collection_items(collection_slug: str):
    """Fetch a collection and group its items by type (models, datasets, papers)."""
    collection = get_collection(collection_slug)
    return groupby(get_collection_type, collection.items)


def get_papers_for_collection(collection_slug: str):
    """Return the arXiv IDs and Hugging Face Paper links for a collection's items."""
    dataset_to_arxiv_id = create_dataset_to_arxiv_id_dict()
    model_to_arxiv_id = create_model_to_arxiv_id_dict()
    collection = group_collection_items(collection_slug)
    collection_datasets = collection.get("datasets")
    collection_models = collection.get("models")
    papers = collection.get("papers")
    dataset_papers = defaultdict(dict)
    model_papers = defaultdict(dict)
    collection_papers = defaultdict(dict)
    if collection_datasets is not None:
        for dataset in collection_datasets:
            if arxiv_ids := dataset_to_arxiv_id.get(dataset.item_id):
                dataset_papers[dataset.item_id] = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
    if collection_models is not None:
        for model in collection_models:
            if arxiv_ids := model_to_arxiv_id.get(model.item_id):
                model_papers[model.item_id] = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
    if papers is not None:
        for paper in papers:
            collection_papers[paper.item_id] = {
                "arxiv_ids": [paper.item_id],
                "hub_paper_links": [f"https://huggingface.co/papers/{paper.item_id}"],
            }
    # Collapse empty sections to None so absences are explicit in the JSON output.
    if not dataset_papers:
        dataset_papers = None
    if not model_papers:
        model_papers = None
    if not collection_papers:
        collection_papers = None
    return {
        "dataset papers": dataset_papers,
        "model papers": model_papers,
        "papers": collection_papers,
    }
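
# The returned JSON has the shape below (illustrative placeholders, not real output):
# {
#     "dataset papers": {"<dataset-id>": {"arxiv_ids": [...], "hub_paper_links": [...]}},
#     "model papers": {"<model-id>": {"arxiv_ids": [...], "hub_paper_links": [...]}},
#     "papers": {"<arxiv-id>": {"arxiv_ids": [...], "hub_paper_links": [...]}},
# }
# Sections with no matching papers come back as None rather than empty dicts.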


# Re-run the listing functions in the background so the TTL cache stays warm.
scheduler = BackgroundScheduler()
scheduler.add_job(get_datasets, "interval", minutes=15)
scheduler.add_job(get_models, "interval", minutes=15)
scheduler.start()


placeholder_url = "HF-IA-archiving/models-to-archive-65006a7fdadb8c628f33aac9"
slug_input = gr.Textbox(
    placeholder=placeholder_url, interactive=True, label="Collection slug", max_lines=1
)
description = (
    "Enter a Collection slug to get the arXiv IDs and Hugging Face Paper links for"
    " papers associated with models and datasets in the collection. If the collection"
    " includes papers, the arXiv IDs and Hugging Face Paper links will be returned for"
    " those papers as well."
)
examples = [
    placeholder_url,
    "davanstrien/historic-language-modeling-64f99e243188ade79d7ad74b",
]

gr.Interface(
    get_papers_for_collection,
    slug_input,
    "json",
    title="Extract linked papers from a Hugging Face Collection",
    description=description,
    examples=examples,
    cache_examples=True,
).queue(concurrency_count=4).launch()
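
# To try the Space locally (a sketch, assuming this file is saved as app.py and
# the dependencies above are installed):
#   python app.py
# Gradio serves the interface on http://127.0.0.1:7860 by default.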