import os
import json
import csv
import random

import pandas as pd

def validate_dataset(file_path, options):
    """
    Validates that a dataset file can be processed with the given options.

    Args:
        file_path: Path to the dataset file
        options: Dictionary of processing options

    Returns:
        Tuple of (is_valid, message)
    """
    if not os.path.exists(file_path):
        return False, f"File not found: {file_path}"

    file_format = options.get("format", "").lower()

    try:
        if file_format == "csv":
            # Validate CSV format
            separator = options.get("csv_separator", ",")
            prompt_col = options.get("csv_prompt_col", "prompt")
            completion_col = options.get("csv_completion_col", "completion")

            df = pd.read_csv(file_path, sep=separator)
            if prompt_col not in df.columns:
                return False, f"Prompt column '{prompt_col}' not found in CSV file"
            if completion_col not in df.columns:
                return False, f"Completion column '{completion_col}' not found in CSV file"

            # Check for empty values
            if df[prompt_col].isnull().any():
                return False, "CSV file contains empty prompt values"
            if df[completion_col].isnull().any():
                return False, "CSV file contains empty completion values"

        elif file_format == "jsonl":
            # Validate JSONL format
            prompt_key = options.get("jsonl_prompt_key", "prompt")
            completion_key = options.get("jsonl_completion_key", "completion")

            with open(file_path, 'r', encoding='utf-8') as f:
                line_count = 0
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    data = json.loads(line)
                    line_count += 1
                    if prompt_key not in data:
                        return False, f"Prompt key '{prompt_key}' not found in JSONL at line {line_count}"
                    if completion_key not in data:
                        return False, f"Completion key '{completion_key}' not found in JSONL at line {line_count}"
                    if not data[prompt_key] or not isinstance(data[prompt_key], str):
                        return False, f"Invalid prompt value at line {line_count}"
                    if not data[completion_key] or not isinstance(data[completion_key], str):
                        return False, f"Invalid completion value at line {line_count}"

            if line_count == 0:
                return False, "JSONL file is empty"

        elif file_format == "plain text":
            # Validate plain text format
            separator = options.get("text_separator", "###")
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            parts = content.split(separator)
            # The file should contain prompt/completion pairs, each followed by the
            # separator (including a trailing one), so a valid file splits into an
            # odd number of parts, and at least three of them.
            if len(parts) < 3:
                return False, f"Text file doesn't contain enough sections separated by '{separator}'"
            if len(parts) % 2 == 0:
                return False, f"Text file has an invalid number of sections separated by '{separator}'"
        else:
            return False, f"Unsupported format: {file_format}"

        return True, "Dataset is valid"

    except Exception as e:
        return False, f"Error validating dataset: {str(e)}"

def process_dataset(file_path, options):
    """
    Processes a dataset file according to the given options.

    Args:
        file_path: Path to the dataset file
        options: Dictionary of processing options

    Returns:
        Tuple of (processed_data, stats, preview)
    """
    file_format = options.get("format", "").lower()

    if file_format == "csv":
        return _process_csv(file_path, options)
    elif file_format == "jsonl":
        return _process_jsonl(file_path, options)
    elif file_format == "plain text":
        return _process_text(file_path, options)
    else:
        raise ValueError(f"Unsupported format: {file_format}")

def _process_csv(file_path, options):
    """Process a CSV dataset file."""
    separator = options.get("csv_separator", ",")
    prompt_col = options.get("csv_prompt_col", "prompt")
    completion_col = options.get("csv_completion_col", "completion")

    df = pd.read_csv(file_path, sep=separator)

    # Extract prompts and completions
    data = []
    for _, row in df.iterrows():
        data.append({
            "prompt": str(row[prompt_col]),
            "completion": str(row[completion_col])
        })
    # Generate statistics (guard the averages so a header-only file doesn't divide by zero)
    num_examples = len(data)
    stats = {
        "num_examples": num_examples,
        "avg_prompt_length": sum(len(item["prompt"]) for item in data) / max(num_examples, 1),
        "avg_completion_length": sum(len(item["completion"]) for item in data) / max(num_examples, 1),
        "format": "csv"
    }
    # Create a preview DataFrame (showing first 5 rows)
    preview = df[[prompt_col, completion_col]].head(5)

    return data, stats, preview

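# Expected CSV layout (illustrative values): a header row containing the configured
# prompt/completion columns, defaulting to "prompt" and "completion":
#
#   prompt,completion
#   "Translate to French: Hello","Bonjour"
#   "Translate to French: Goodbye","Au revoir"
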
def _process_jsonl(file_path, options):
    """Process a JSONL dataset file."""
    prompt_key = options.get("jsonl_prompt_key", "prompt")
    completion_key = options.get("jsonl_completion_key", "completion")

    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            item = json.loads(line)
            data.append({
                "prompt": item[prompt_key],
                "completion": item[completion_key]
            })
    # Generate statistics (guard the averages so an empty file doesn't divide by zero)
    num_examples = len(data)
    stats = {
        "num_examples": num_examples,
        "avg_prompt_length": sum(len(item["prompt"]) for item in data) / max(num_examples, 1),
        "avg_completion_length": sum(len(item["completion"]) for item in data) / max(num_examples, 1),
        "format": "jsonl"
    }
    # Create a preview DataFrame
    preview_data = []
    for item in data[:5]:
        preview_data.append({
            "prompt": item["prompt"],
            "completion": item["completion"]
        })
    preview = pd.DataFrame(preview_data)

    return data, stats, preview

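# Expected JSONL layout (illustrative values): one JSON object per line, using the
# keys configured via "jsonl_prompt_key" / "jsonl_completion_key" (defaults shown):
#
#   {"prompt": "Translate to French: Hello", "completion": "Bonjour"}
#   {"prompt": "Translate to French: Goodbye", "completion": "Au revoir"}
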
def _process_text(file_path, options):
    """Process a plain text dataset file."""
    separator = options.get("text_separator", "###")

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    parts = content.split(separator)

    data = []
    for i in range(0, len(parts) - 1, 2):
        prompt = parts[i].strip()
        completion = parts[i + 1].strip()
        if prompt and completion:
            data.append({
                "prompt": prompt,
                "completion": completion
            })
    # Generate statistics (guard the averages so an empty file doesn't divide by zero)
    num_examples = len(data)
    stats = {
        "num_examples": num_examples,
        "avg_prompt_length": sum(len(item["prompt"]) for item in data) / max(num_examples, 1),
        "avg_completion_length": sum(len(item["completion"]) for item in data) / max(num_examples, 1),
        "format": "text"
    }
    # Create a preview DataFrame
    preview_data = []
    for item in data[:5]:
        preview_data.append({
            "prompt": item["prompt"],
            "completion": item["completion"]
        })
    preview = pd.DataFrame(preview_data)

    return data, stats, preview

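# Expected plain-text layout (illustrative values): prompts and completions
# alternate, each followed by the separator (default "###"), including a trailing
# one, so splitting on the separator yields an odd number of parts:
#
#   Translate to French: Hello
#   ###
#   Bonjour
#   ###
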
def format_for_training(dataset, tokenizer, max_length=512):
    """
    Formats a processed dataset for training with Gemma.

    Args:
        dataset: List of prompt/completion pairs
        tokenizer: Tokenizer for the model
        max_length: Maximum sequence length

    Returns:
        Dictionary of training data
    """
    input_ids = []
    labels = []
    attention_mask = []

    for item in dataset:
        prompt = item["prompt"]
        completion = item["completion"]

        # Format as the model expects: prompt and completion separated/terminated by EOS
        full_text = f"{prompt}{tokenizer.eos_token}{completion}{tokenizer.eos_token}"

        # Tokenize the full sequence
        encoded = tokenizer(full_text, max_length=max_length, padding="max_length", truncation=True)
        input_ids.append(encoded["input_ids"])
        attention_mask.append(encoded["attention_mask"])

        # Measure the prompt prefix with the same special-token settings as the full
        # sequence (so any BOS token the tokenizer prepends is counted), then mask the
        # prompt tokens with -100 so they're ignored in the loss calculation
        prompt_encoded = tokenizer(f"{prompt}{tokenizer.eos_token}", max_length=max_length, truncation=True)
        prompt_length = len(prompt_encoded["input_ids"])

        # Create label list: -100 for prompt tokens, actual token IDs for the completion
        label = [-100] * prompt_length + encoded["input_ids"][prompt_length:]

        # Also mask padding positions so they don't contribute to the loss
        label = [
            token_id if mask == 1 else -100
            for token_id, mask in zip(label, encoded["attention_mask"])
        ]

        # Pad or truncate to max_length
        if len(label) < max_length:
            label = label + [-100] * (max_length - len(label))
        else:
            label = label[:max_length]

        labels.append(label)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

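# Illustrative follow-up (assumes the `datasets` library is installed, which is not
# required by this module): the dict returned above is column-oriented, so it can
# be wrapped directly for use with a Hugging Face Trainer.
#
#   from datasets import Dataset
#   train_features = format_for_training(train_examples, tokenizer, max_length=512)
#   train_ds = Dataset.from_dict(train_features)
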
def create_train_val_split(dataset, val_size=0.1, seed=42):
    """
    Splits a dataset into training and validation sets.

    Args:
        dataset: List of examples
        val_size: Fraction of examples to use for validation
        seed: Random seed for reproducibility

    Returns:
        Tuple of (train_dataset, val_dataset)
    """
    # Shuffle a copy with a local RNG so the caller's list and the global random
    # state are left untouched
    shuffled = list(dataset)
    random.Random(seed).shuffle(shuffled)

    val_count = max(1, int(len(shuffled) * val_size))
    val_dataset = shuffled[:val_count]
    train_dataset = shuffled[val_count:]

    return train_dataset, val_dataset

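# End-to-end sketch (hypothetical file path and checkpoint name; shown for
# orientation only, mirroring how the functions above fit together):
#
#   from transformers import AutoTokenizer
#
#   options = {"format": "jsonl"}
#   is_valid, message = validate_dataset("data/train.jsonl", options)
#   if is_valid:
#       data, stats, preview = process_dataset("data/train.jsonl", options)
#       train_set, val_set = create_train_val_split(data, val_size=0.1, seed=42)
#       tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
#       train_features = format_for_training(train_set, tokenizer, max_length=512)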