import gradio as gr from dataclasses import dataclass import os from supabase import create_client, Client from supabase.client import ClientOptions from enum import Enum from datasets import get_dataset_infos from transformers import AutoConfig, GenerationConfig from huggingface_hub import whoami from typing import Optional, Union """ Still TODO: - validate the user is PRO - check the output dataset token is valid (hardcoded for now as a secret) - validate max model params """ class GenerationStatus(Enum): PENDING = "PENDING" RUNNING = "RUNNING" COMPLETED = "COMPLETED" FAILED = "FAILED" MAX_SAMPLES_PRO = 10000 # max number of samples for PRO/Enterprise users MAX_SAMPLES_FREE = 100 # max number of samples for free users MAX_TOKENS = 8192 MAX_MODEL_PARAMS = 20_000_000_000 # 20 billion parameters (for now) # Cache for model generation parameters MODEL_GEN_PARAMS_CACHE = {} @dataclass class GenerationRequest: id: str created_at: str status: GenerationStatus input_dataset_name: str input_dataset_config: str input_dataset_split: str output_dataset_name: str prompt_column: str model_name_or_path: str model_revision: str model_token: str | None system_prompt: str | None max_tokens: int temperature: float top_k: int top_p: float input_dataset_token: str | None output_dataset_token: str username: str email: str num_output_examples: int private: bool = False num_retries: int = 0 SUPPORTED_MODELS = [ "Qwen/Qwen3-4B-Instruct-2507", "Qwen/Qwen3-30B-A3B-Instruct-2507", "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "baidu/ERNIE-4.5-21B-A3B-Thinking", "LLM360/K2-Think", "openai/gpt-oss-20b", ] def fetch_model_generation_params(model_name: str) -> dict: """Fetch generation parameters and model config from the hub""" default_params = { "max_tokens": 1024, "temperature": 0.7, "top_k": 50, "top_p": 0.95, "max_position_embeddings": 2048, "recommended_max_tokens": 1024 } try: print(f"Attempting to fetch configs for: {model_name}") # Always try to load the model config first for max_position_embeddings model_config = None max_position_embeddings = default_params["max_position_embeddings"] try: output_dataset_token = os.getenv("OUTPUT_DATASET_TOKEN") model_config = AutoConfig.from_pretrained(model_name, force_download=False, token=output_dataset_token) max_position_embeddings = getattr(model_config, 'max_position_embeddings', default_params["max_position_embeddings"]) print(f"Loaded AutoConfig for {model_name}, max_position_embeddings: {max_position_embeddings}") except Exception as e: print(f"Failed to load AutoConfig for {model_name}: {e}") # Calculate recommended max tokens (conservative estimate) # Leave some room for the prompt, so use ~75% of max_position_embeddings recommended_max_tokens = min(int(max_position_embeddings * 0.75), MAX_TOKENS) recommended_max_tokens = max(256, recommended_max_tokens) # Ensure minimum # Try to load the generation config gen_config = None try: gen_config = GenerationConfig.from_pretrained(model_name, force_download=False, token=output_dataset_token) print(f"Successfully loaded generation config for {model_name}") except Exception as e: print(f"Failed to load GenerationConfig for {model_name}: {e}") # Extract parameters from generation config or use model-specific defaults if gen_config: params = { "max_tokens": getattr(gen_config, 'max_new_tokens', None) or getattr(gen_config, 'max_length', recommended_max_tokens), "temperature": getattr(gen_config, 'temperature', default_params["temperature"]), "top_k": getattr(gen_config, 'top_k', default_params["top_k"]), "top_p": getattr(gen_config, 'top_p', default_params["top_p"]), "max_position_embeddings": max_position_embeddings, "recommended_max_tokens": recommended_max_tokens } else: # Use model-specific defaults based on model name if "qwen" in model_name.lower(): params = {"max_tokens": recommended_max_tokens, "temperature": 0.7, "top_k": 50, "top_p": 0.8, "max_position_embeddings": max_position_embeddings, "recommended_max_tokens": recommended_max_tokens} elif "llama" in model_name.lower(): params = {"max_tokens": recommended_max_tokens, "temperature": 0.6, "top_k": 40, "top_p": 0.9, "max_position_embeddings": max_position_embeddings, "recommended_max_tokens": recommended_max_tokens} elif "ernie" in model_name.lower(): params = {"max_tokens": min(recommended_max_tokens, 1024), "temperature": 0.7, "top_k": 50, "top_p": 0.95, "max_position_embeddings": max_position_embeddings, "recommended_max_tokens": recommended_max_tokens} else: params = dict(default_params) params["max_position_embeddings"] = max_position_embeddings params["recommended_max_tokens"] = recommended_max_tokens # Ensure parameters are within valid ranges params["max_tokens"] = max(256, min(params["max_tokens"], MAX_TOKENS, params["recommended_max_tokens"])) params["temperature"] = max(0.0, min(params["temperature"], 2.0)) params["top_k"] = max(5, min(params["top_k"], 100)) params["top_p"] = max(0.0, min(params["top_p"], 1.0)) print(f"Final params for {model_name}: {params}") return params except Exception as e: print(f"Could not fetch configs for {model_name}: {e}") return default_params def update_generation_params(model_name: str): """Update generation parameters based on selected model""" global MODEL_GEN_PARAMS_CACHE print(f"Updating generation parameters for model: {model_name}") print(f"Cache is empty: {len(MODEL_GEN_PARAMS_CACHE) == 0}") print(f"Current cache keys: {list(MODEL_GEN_PARAMS_CACHE.keys())}") # If cache is empty, try to populate it now if len(MODEL_GEN_PARAMS_CACHE) == 0: print("Cache is empty, attempting to populate now...") cache_all_model_params() if model_name in MODEL_GEN_PARAMS_CACHE: params = MODEL_GEN_PARAMS_CACHE[model_name] print(f"Found cached params for {model_name}: {params}") # Set the max_tokens slider maximum to the model's recommended max max_tokens_limit = min(params.get("recommended_max_tokens", MAX_TOKENS), MAX_TOKENS) return ( gr.update(value=params["max_tokens"], maximum=max_tokens_limit), # max_tokens with dynamic maximum gr.update(value=params["temperature"]), # temperature gr.update(value=params["top_k"]), # top_k gr.update(value=params["top_p"]) # top_p ) else: # Fallback to defaults if model not in cache print(f"Model {model_name} not found in cache, using defaults") return ( gr.update(value=1024, maximum=MAX_TOKENS), # max_tokens gr.update(value=0.7), # temperature gr.update(value=50), # top_k gr.update(value=0.95) # top_p ) def cache_all_model_params(): """Cache generation parameters for all supported models at startup""" global MODEL_GEN_PARAMS_CACHE print(f"Starting to cache parameters for {len(SUPPORTED_MODELS)} models...") print(f"Supported models: {SUPPORTED_MODELS}") for model_name in SUPPORTED_MODELS: try: print(f"Processing model: {model_name}") params = fetch_model_generation_params(model_name) MODEL_GEN_PARAMS_CACHE[model_name] = params print(f"Successfully cached params for {model_name}: {params}") except Exception as e: print(f"Exception while caching params for {model_name}: {e}") # Use default parameters if caching fails default_params = { "max_tokens": 1024, "temperature": 0.7, "top_k": 50, "top_p": 0.95, "max_position_embeddings": 2048, "recommended_max_tokens": 1024 } MODEL_GEN_PARAMS_CACHE[model_name] = default_params print(f"Using default params for {model_name}: {default_params}") print(f"Caching complete. Final cache contents:") for model, params in MODEL_GEN_PARAMS_CACHE.items(): print(f" {model}: {params}") print(f"Cache size: {len(MODEL_GEN_PARAMS_CACHE)} models") def verify_pro_status(token: Optional[Union[gr.OAuthToken, str]]) -> bool: """Verifies if the user is a Hugging Face PRO user or part of an enterprise org.""" if not token: return False if isinstance(token, gr.OAuthToken): token_str = token.token elif isinstance(token, str): token_str = token else: return False try: user_info = whoami(token=token_str) return ( user_info.get("isPro", False) or any(org.get("isEnterprise", False) for org in user_info.get("orgs", [])) ) except Exception as e: print(f"Could not verify user's PRO/Enterprise status: {e}") return False def validate_request(request: GenerationRequest, oauth_token: Optional[Union[gr.OAuthToken, str]] = None) -> GenerationRequest: # checks that the request is valid # - input dataset exists and can be accessed with the provided token try: input_dataset_info = get_dataset_infos(request.input_dataset_name, token=request.input_dataset_token)[request.input_dataset_config] except Exception as e: raise Exception(f"Dataset {request.input_dataset_name} does not exist or cannot be accessed with the provided token.") # check that the input dataset split exists if request.input_dataset_split not in input_dataset_info.splits: raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(input_dataset_info.splits.keys())}") # if num_output_examples is 0, set it to the number of examples in the input dataset split if request.num_output_examples == 0: request.num_output_examples = input_dataset_info.splits[request.input_dataset_split].num_examples else: if request.num_output_examples > input_dataset_info.splits[request.input_dataset_split].num_examples: raise Exception(f"Requested number of output examples {request.num_output_examples} exceeds the number of examples in the input dataset split {input_dataset_info.splits[request.input_dataset_split].num_examples}.") request.input_dataset_split = f"{request.input_dataset_split}[:{request.num_output_examples}]" # Check user tier and apply appropriate limits # Anonymous users (oauth_token is None) are treated as free tier is_pro = verify_pro_status(oauth_token) if oauth_token else False max_samples = MAX_SAMPLES_PRO if is_pro else MAX_SAMPLES_FREE if request.num_output_examples > max_samples: if oauth_token is None: user_tier = "non-signed-in" else: user_tier = "PRO/Enterprise" if is_pro else "free" raise Exception(f"Requested number of output examples {request.num_output_examples} exceeds the max limit of {max_samples} for {user_tier} users.") # check the prompt column exists in the dataset if request.prompt_column not in input_dataset_info.features: raise Exception(f"Prompt column {request.prompt_column} does not exist in dataset {request.input_dataset_name}. Available columns: {list(input_dataset_info.features.keys())}") # This is currently not supported, the output dataset will be created under the org 'synthetic-data-universe' # check output_dataset name is valid if request.output_dataset_name.count("/") != 1: raise Exception("Output dataset name must be in the format 'dataset_name', e.g., 'my-dataset'. The dataset will be created under the org 'synthetic-data-universe/my-dataset'.") # check the output dataset is valid and accessible with the provided token try: output_dataset_info = get_dataset_infos(request.output_dataset_name, token=request.output_dataset_token) raise Exception(f"Output dataset {request.output_dataset_name} already exists. Please choose a different name.") except Exception: pass # dataset does not exist, which is expected # check the output dataset name doesn't already exist in the database try: url = os.getenv("SUPABASE_URL") key = os.getenv("SUPABASE_KEY") if url and key: supabase = create_client( url, key, options=ClientOptions( postgrest_client_timeout=10, storage_client_timeout=10, schema="public", ) ) existing_request = supabase.table("gen-requests").select("id").eq("output_dataset_name", request.output_dataset_name).execute() if existing_request.data: raise Exception(f"Output dataset {request.output_dataset_name} is already being generated or has been requested. Please choose a different name.") except Exception as e: # If it's our custom exception about dataset already existing, re-raise it if "already being generated" in str(e): raise e # Otherwise, ignore database connection errors and continue pass # check the models exists try: model_config = AutoConfig.from_pretrained(request.model_name_or_path, revision=request.model_revision, force_download=True, token=False ) except Exception as e: print(e) raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed. The model may be private or gated, which is not supported at this time.") # check the model max position embeddings is greater than the requested max tokens and less than MAX_TOKENS if model_config.max_position_embeddings < request.max_tokens: raise Exception(f"Model {request.model_name_or_path} max position embeddings {model_config.max_position_embeddings} is less than the requested max tokens {request.max_tokens}.") if request.max_tokens > MAX_TOKENS: raise Exception(f"Requested max tokens {request.max_tokens} exceeds the limit of {MAX_TOKENS}.") # check sampling parameters are valid if request.temperature < 0.0 or request.temperature > 2.0: raise Exception("Temperature must be between 0.0 and 2.0") if request.top_k < 1 or request.top_k > 100: raise Exception("Top K must be between 1 and 100") if request.top_p < 0.0 or request.top_p > 1.0: raise Exception("Top P must be between 0.0 and 1.0") return request def load_dataset_info(dataset_name, model_name, oauth_token=None, dataset_token=None): """Load dataset information and return choices for dropdowns""" if not dataset_name.strip(): return ( gr.update(choices=[], value=None), # config gr.update(choices=[], value=None), # split gr.update(choices=[], value=None), # prompt_column gr.update(value="", interactive=True), # output_dataset_name gr.update(interactive=False), # num_output_samples "Please enter a dataset name first." ) try: # Get dataset info dataset_infos = get_dataset_infos(dataset_name) if not dataset_infos: raise Exception("No configs found for this dataset") # Get available configs config_choices = list(dataset_infos.keys()) default_config = config_choices[0] if config_choices else None # Get splits and features for the default config if default_config: config_info = dataset_infos[default_config] split_choices = list(config_info.splits.keys()) default_split = split_choices[0] if split_choices else None # Get column choices (features) column_choices = list(config_info.features.keys()) default_column = None # Try to find a likely prompt column for col in column_choices: if any(keyword in col.lower() for keyword in ['prompt', 'text', 'question', 'input']): default_column = col break if not default_column and column_choices: default_column = column_choices[0] # Get sample count for the default split dataset_sample_count = config_info.splits[default_split].num_examples if default_split else 0 else: split_choices = [] column_choices = [] default_split = None default_column = None dataset_sample_count = 0 # Determine user limits is_pro = verify_pro_status(oauth_token) if oauth_token else False user_max_samples = MAX_SAMPLES_PRO if is_pro else MAX_SAMPLES_FREE # Set slider maximum to the minimum of dataset samples and user limit slider_max = min(dataset_sample_count, user_max_samples) if dataset_sample_count > 0 else user_max_samples # Get username from OAuth token username = "anonymous" if oauth_token: try: if isinstance(oauth_token, gr.OAuthToken): token_str = oauth_token.token elif isinstance(oauth_token, str): token_str = oauth_token else: token_str = None if token_str: user_info = whoami(token=token_str) username = user_info.get("name", "anonymous") except Exception: username = "anonymous" # Generate a suggested output dataset name: username-model-dataset dataset_base_name = dataset_name.split('/')[-1] if '/' in dataset_name else dataset_name # Extract model short name (e.g., "Qwen/Qwen3-4B-Instruct-2507" -> "qwen3-4b") model_short_name = model_name.split('/')[-1].lower() # Remove common suffixes and simplify model_short_name = model_short_name.replace('-instruct', '').replace('-2507', '').replace('_', '-') # Take first part if it's still long if len(model_short_name) > 15: parts = model_short_name.split('-') model_short_name = '-'.join(parts[:2]) if len(parts) > 1 else parts[0][:15] # Build the output name: username-model-dataset suggested_output_name = f"{username}-{model_short_name}-{dataset_base_name}" # Limit to 86 characters if len(suggested_output_name) > 86: # Truncate dataset name to fit within limit available_for_dataset = 86 - len(username) - len(model_short_name) - 2 # -2 for the hyphens if available_for_dataset > 0: dataset_base_name = dataset_base_name[:available_for_dataset] suggested_output_name = f"{username}-{model_short_name}-{dataset_base_name}" else: suggested_output_name = f"{username}-{model_short_name}" status_msg = f"✅ Dataset info loaded successfully! Found {len(config_choices)} config(s), {len(split_choices)} split(s), and {len(column_choices)} column(s)." if dataset_sample_count > 0: status_msg += f" Dataset has {dataset_sample_count:,} samples." if dataset_sample_count > user_max_samples: user_tier = "PRO/Enterprise" if is_pro else "free tier" status_msg += f" Limited to {user_max_samples:,} samples for {user_tier} users." return ( gr.update(choices=config_choices, value=default_config, interactive=True), # config gr.update(choices=split_choices, value=default_split, interactive=True), # split gr.update(choices=column_choices, value=default_column, interactive=True), # prompt_column gr.update(value=suggested_output_name, interactive=True), # output_dataset_name gr.update(interactive=True, maximum=slider_max, value=0), # num_output_samples status_msg ) except Exception as e: return ( gr.update(choices=[], value=None, interactive=False), # config gr.update(choices=[], value=None, interactive=False), # split gr.update(choices=[], value=None, interactive=False), # prompt_column gr.update(value="", interactive=False), # output_dataset_name gr.update(interactive=False), # num_output_samples f"❌ Error loading dataset info: {str(e)}" ) def add_request_to_db(request: GenerationRequest): url: str = os.getenv("SUPABASE_URL") key: str = os.getenv("SUPABASE_KEY") try: supabase: Client = create_client( url, key, options=ClientOptions( postgrest_client_timeout=10, storage_client_timeout=10, schema="public", ) ) data = { "status": request.status.value, "input_dataset_name": request.input_dataset_name, "input_dataset_config": request.input_dataset_config, "input_dataset_split": request.input_dataset_split, "output_dataset_name": request.output_dataset_name, "prompt_column": request.prompt_column, "model_name_or_path": request.model_name_or_path, "model_revision": request.model_revision, "model_token": request.model_token, "system_prompt": request.system_prompt, "max_tokens": request.max_tokens, "temperature": request.temperature, "top_k": request.top_k, "top_p": request.top_p, "input_dataset_token": request.input_dataset_token, "output_dataset_token": request.output_dataset_token, "username": request.username, "email": request.email, "num_output_examples": request.num_output_examples, "private": request.private, } supabase.table("gen-requests").insert(data).execute() except Exception as e: raise Exception("Failed to add request to database") def get_generation_stats_safe(): """Safely fetch generation request statistics with proper error handling""" try: url = os.getenv("SUPABASE_URL") key = os.getenv("SUPABASE_KEY") if not url or not key: raise Exception("Missing SUPABASE_URL or SUPABASE_KEY environment variables") supabase = create_client( url, key, options=ClientOptions( postgrest_client_timeout=10, storage_client_timeout=10, schema="public", ) ) # Fetch data excluding sensitive token fields response = supabase.table("gen-requests").select( "id, created_at, status, input_dataset_name, input_dataset_config, " "input_dataset_split, output_dataset_name, prompt_column, " "model_name_or_path, model_revision, max_tokens, temperature, " "top_k, top_p, username, num_output_examples, private" ).order("created_at", desc=True).limit(50).execute() return {"status": "success", "data": response.data} except Exception as e: return {"status": "error", "message": str(e), "data": []} # Old commented code removed - replaced with DatabaseManager and get_generation_stats_safe() def main(): # Cache model generation parameters at startup print("Caching model generation parameters...") cache_all_model_params() print("Model parameter caching complete.") with gr.Blocks(title="Synthetic Data Generation") as demo: gr.HTML("