# gemma-fine-tuning/data_processing.py
import os
import json
import random

import pandas as pd

def validate_dataset(file_path, options):
"""
Validates that a dataset file can be processed with the given options.
Args:
file_path: Path to the dataset file
options: Dictionary of processing options
Returns:
Tuple of (is_valid, message)
"""
if not os.path.exists(file_path):
return False, f"File not found: {file_path}"
file_format = options.get("format", "").lower()
try:
if file_format == "csv":
# Validate CSV format
separator = options.get("csv_separator", ",")
prompt_col = options.get("csv_prompt_col", "prompt")
completion_col = options.get("csv_completion_col", "completion")
df = pd.read_csv(file_path, sep=separator)
if prompt_col not in df.columns:
return False, f"Prompt column '{prompt_col}' not found in CSV file"
if completion_col not in df.columns:
return False, f"Completion column '{completion_col}' not found in CSV file"
# Check for empty values
if df[prompt_col].isnull().any():
return False, "CSV file contains empty prompt values"
if df[completion_col].isnull().any():
return False, "CSV file contains empty completion values"
elif file_format == "jsonl":
# Validate JSONL format
prompt_key = options.get("jsonl_prompt_key", "prompt")
completion_key = options.get("jsonl_completion_key", "completion")
            with open(file_path, 'r', encoding='utf-8') as f:
                example_count = 0
                for line_number, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        return False, f"Invalid JSON at line {line_number}"
                    example_count += 1
                    if prompt_key not in data:
                        return False, f"Prompt key '{prompt_key}' not found in JSONL at line {line_number}"
                    if completion_key not in data:
                        return False, f"Completion key '{completion_key}' not found in JSONL at line {line_number}"
                    if not data[prompt_key] or not isinstance(data[prompt_key], str):
                        return False, f"Invalid prompt value at line {line_number}"
                    if not data[completion_key] or not isinstance(data[completion_key], str):
                        return False, f"Invalid completion value at line {line_number}"
                if example_count == 0:
                    return False, "JSONL file is empty"
        elif file_format == "plain text":
            # Validate plain text format
            separator = options.get("text_separator", "###")
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            parts = content.split(separator)
            if len(parts) < 2:  # Need at least one prompt/completion pair
                return False, f"Text file doesn't contain enough sections separated by '{separator}'"
            # Sections are consumed in pairs (prompt, completion, prompt, completion, ...),
            # so a well-formed file has an even number of sections
            if len(parts) % 2 != 0:
                return False, f"Text file has an unpaired section; sections separated by '{separator}' must come in prompt/completion pairs"
else:
return False, f"Unsupported format: {file_format}"
return True, "Dataset is valid"
except Exception as e:
return False, f"Error validating dataset: {str(e)}"
def process_dataset(file_path, options):
"""
Processes a dataset file according to the given options.
Args:
file_path: Path to the dataset file
options: Dictionary of processing options
Returns:
Tuple of (processed_data, stats, preview)
"""
file_format = options.get("format", "").lower()
if file_format == "csv":
return _process_csv(file_path, options)
elif file_format == "jsonl":
return _process_jsonl(file_path, options)
elif file_format == "plain text":
return _process_text(file_path, options)
else:
raise ValueError(f"Unsupported format: {file_format}")
def _process_csv(file_path, options):
"""Process a CSV dataset file."""
separator = options.get("csv_separator", ",")
prompt_col = options.get("csv_prompt_col", "prompt")
completion_col = options.get("csv_completion_col", "completion")
df = pd.read_csv(file_path, sep=separator)
# Extract prompts and completions
data = []
for _, row in df.iterrows():
data.append({
"prompt": str(row[prompt_col]),
"completion": str(row[completion_col])
})
    # Generate statistics (guard against an empty dataset to avoid division by zero)
    num_examples = len(data)
    stats = {
        "num_examples": num_examples,
        "avg_prompt_length": sum(len(item["prompt"]) for item in data) / num_examples if num_examples else 0,
        "avg_completion_length": sum(len(item["completion"]) for item in data) / num_examples if num_examples else 0,
        "format": "csv"
    }
# Create a preview DataFrame (showing first 5 rows)
preview = df[[prompt_col, completion_col]].head(5)
return data, stats, preview
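
# Expected CSV layout with the default options (illustrative content only):
#
#     prompt,completion
#     "What is the capital of France?","Paris."
#     "What is 2 + 2?","4"
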
def _process_jsonl(file_path, options):
"""Process a JSONL dataset file."""
prompt_key = options.get("jsonl_prompt_key", "prompt")
completion_key = options.get("jsonl_completion_key", "completion")
data = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line:
continue
item = json.loads(line)
data.append({
"prompt": item[prompt_key],
"completion": item[completion_key]
})
    # Generate statistics (guard against an empty dataset to avoid division by zero)
    num_examples = len(data)
    stats = {
        "num_examples": num_examples,
        "avg_prompt_length": sum(len(item["prompt"]) for item in data) / num_examples if num_examples else 0,
        "avg_completion_length": sum(len(item["completion"]) for item in data) / num_examples if num_examples else 0,
        "format": "jsonl"
    }
    # Create a preview DataFrame (first 5 examples)
    preview = pd.DataFrame(data[:5])
return data, stats, preview
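
# Each JSONL line is a single JSON object holding one example (illustrative
# content only):
#
#     {"prompt": "What is the capital of France?", "completion": "Paris."}
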
def _process_text(file_path, options):
"""Process a plain text dataset file."""
separator = options.get("text_separator", "###")
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
parts = content.split(separator)
data = []
for i in range(0, len(parts) - 1, 2):
prompt = parts[i].strip()
completion = parts[i + 1].strip()
if prompt and completion:
data.append({
"prompt": prompt,
"completion": completion
})
    # Generate statistics (guard against an empty dataset to avoid division by zero)
    num_examples = len(data)
    stats = {
        "num_examples": num_examples,
        "avg_prompt_length": sum(len(item["prompt"]) for item in data) / num_examples if num_examples else 0,
        "avg_completion_length": sum(len(item["completion"]) for item in data) / num_examples if num_examples else 0,
        "format": "text"
    }
    # Create a preview DataFrame (first 5 examples)
    preview = pd.DataFrame(data[:5])
return data, stats, preview
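
# Expected plain-text layout with the default "###" separator: sections are
# consumed in (prompt, completion) pairs, so a file like the following yields
# two examples (illustrative content only):
#
#     What is the capital of France?
#     ###
#     Paris.
#     ###
#     What is 2 + 2?
#     ###
#     4
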
def format_for_training(dataset, tokenizer, max_length=512):
"""
Formats a processed dataset for training with Gemma.
Args:
dataset: List of prompt/completion pairs
tokenizer: Tokenizer for the model
max_length: Maximum sequence length
Returns:
Dictionary of training data
"""
input_ids = []
labels = []
attention_mask = []
for item in dataset:
prompt = item["prompt"]
completion = item["completion"]
        # Format as the model expects: prompt and completion separated and
        # terminated by the EOS token
        full_text = f"{prompt}{tokenizer.eos_token}{completion}{tokenizer.eos_token}"
        # Tokenize the full sequence
        encoded = tokenizer(full_text, max_length=max_length, padding="max_length", truncation=True)
        input_ids.append(encoded["input_ids"])
        attention_mask.append(encoded["attention_mask"])
        # Measure the prompt prefix with the same special-token settings as the
        # full sequence (Gemma's tokenizer prepends BOS by default), so the two
        # tokenizations line up and the mask lands on the right positions
        prompt_encoded = tokenizer(f"{prompt}{tokenizer.eos_token}")
        prompt_length = len(prompt_encoded["input_ids"])
        # Labels: -100 for prompt tokens (ignored in the loss), actual token IDs
        # for the completion
        label = [-100] * prompt_length + encoded["input_ids"][prompt_length:]
        # Also mask padding positions so the model isn't trained to emit pad tokens
        label = [tok if mask == 1 else -100 for tok, mask in zip(label, encoded["attention_mask"])]
        # Pad or truncate to max_length (defensive; lengths should already match)
        if len(label) < max_length:
            label = label + [-100] * (max_length - len(label))
        else:
            label = label[:max_length]
        labels.append(label)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
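
# A sketch of how the output could be fed to a trainer (assumes the Hugging Face
# `datasets` library is installed; `train_data` and `tokenizer` are placeholders):
#
#     from datasets import Dataset
#
#     batch = format_for_training(train_data, tokenizer, max_length=512)
#     train_dataset = Dataset.from_dict(batch)
#     # train_dataset can then be passed to transformers.Trainer
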
def create_train_val_split(dataset, val_size=0.1, seed=42):
"""
Splits a dataset into training and validation sets.
Args:
dataset: List of examples
val_size: Fraction of examples to use for validation
seed: Random seed for reproducibility
Returns:
Tuple of (train_dataset, val_dataset)
"""
    # Shuffle a copy with a local RNG so the caller's list and the global
    # random state are left untouched
    rng = random.Random(seed)
    shuffled = list(dataset)
    rng.shuffle(shuffled)
    val_count = max(1, int(len(shuffled) * val_size))
    val_dataset = shuffled[:val_count]
    train_dataset = shuffled[val_count:]
return train_dataset, val_dataset
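

if __name__ == "__main__":
    # End-to-end sketch: validate, process, and split a dataset. The path and
    # options here are placeholders for illustration, not files that ship with
    # this repo.
    demo_path = "data/train.jsonl"
    demo_options = {"format": "jsonl"}
    ok, message = validate_dataset(demo_path, demo_options)
    if not ok:
        raise SystemExit(message)
    data, stats, preview = process_dataset(demo_path, demo_options)
    print(f"Loaded {stats['num_examples']} examples")
    print(preview)
    train_data, val_data = create_train_val_split(data, val_size=0.1, seed=42)
    print(f"Train: {len(train_data)}, Val: {len(val_data)}")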