# Hugging Face Space: Shedify — app.py
# NOTE(review): the Space page showed "Runtime error" at capture time; the
# pipe-table wrapping below is page-scrape residue, not part of the program.
# --- Standard library imports ---
import json
import os
import subprocess
import threading
import time
import uuid
from pathlib import Path

# --- Third-party imports ---
import gradio as gr
from huggingface_hub import InferenceClient, HfFolder

"""
Shedify app - Using fine-tuned Llama 3.3 49B for document assistance
"""

# Model settings
DEFAULT_MODEL = "Borislav18/Shedify"  # Your Hugging Face username/model name
LOCAL_MODEL = os.environ.get("LOCAL_MODEL", None)  # Set this if testing locally

# Hugging Face access token (read from the environment; may be None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# App title and description shown in the UI
title = "Shedify - Document Assistant powered by Llama 3.3"
description = """
This app uses a fine-tuned version of Llama 3.3 49B model trained on your documents.
Ask questions about the documents, generate insights, or request summaries!
"""

# Inference client bound to the fine-tuned model on the Hub
client = InferenceClient(
    DEFAULT_MODEL,
    token=HF_TOKEN,
)
# Training status tracking
class TrainingState:
    """Tracks the status of a fine-tuning run and persists it to disk.

    The state (status, progress, message, session id) is mirrored to
    ``training_state.json`` in the working directory so it survives
    process restarts.
    """

    def __init__(self):
        self.status = "idle"   # one of: idle, running, success, failed
        self.progress = 0.0    # fraction complete, 0.0 to 1.0
        self.message = ""      # human-readable status message
        self.id = str(uuid.uuid4())[:8]  # unique id for this session
        # Reload any previously persisted state, if present
        self.state_file = Path("training_state.json")
        self.load_state()

    def load_state(self):
        """Load state from file if it exists"""
        if not self.state_file.exists():
            return
        try:
            with open(self.state_file, "r") as f:
                state = json.load(f)
            self.status = state.get("status", "idle")
            self.progress = state.get("progress", 0.0)
            self.message = state.get("message", "")
            self.id = state.get("id", self.id)
        except Exception as e:
            # Corrupt/unreadable state file: keep defaults, log and continue
            print(f"Error loading state: {e}")

    def save_state(self):
        """Save current state to file"""
        payload = {
            "status": self.status,
            "progress": self.progress,
            "message": self.message,
            "id": self.id,
        }
        try:
            with open(self.state_file, "w") as f:
                json.dump(payload, f)
        except Exception as e:
            # Persistence is best-effort; never crash the app over it
            print(f"Error saving state: {e}")

    def update(self, status=None, progress=None, message=None):
        """Update any provided fields, persist, and return the new state tuple."""
        if status is not None:
            self.status = status
        if progress is not None:
            self.progress = progress
        if message is not None:
            self.message = message
        self.save_state()
        return self.status, self.progress, self.message


# Module-level singleton shared by the UI and the training thread
training_state = TrainingState()
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat-completion reply from the fine-tuned model.

    Args:
        message: The user's latest message.
        history: Prior turns as ``(user, assistant)`` pairs.
        system_message: System prompt prepended to the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated response text after each streamed token.
    """
    messages = [{"role": "system", "content": system_message}]

    # Format history to match the chat-completion message format
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    response = ""
    # BUG FIX: the original loop variable was also named `message`, shadowing
    # the parameter; renamed to `chunk` for clarity and safety.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        # BUG FIX: delta.content is None on role/finish chunks; the original
        # `response += token` raised TypeError on those chunks.
        if token:
            response += token
            yield response
def run_training_process(pdf_dir, output_name, progress_callback):
    """Run the PDF processing and fine-tuning process.

    Args:
        pdf_dir: Directory containing the source PDFs.
        output_name: Name for the fine-tuned model pushed to the Hub.
        progress_callback: Callable ``(status, progress, message)`` used to
            report progress back to the UI/state tracker.

    Returns:
        True on success, False on any failure (failure details are reported
        through ``progress_callback``).
    """
    try:
        # Make sure the output directory for processed PDFs exists
        os.makedirs("processed_data", exist_ok=True)
        progress_callback("running", 0.05, "Processing PDFs...")

        # Step 1: extract training data from the PDFs
        pdf_process = subprocess.run(
            [
                "python", "pdf_processor.py",
                "--pdf_dir", pdf_dir,
                "--output_dir", "processed_data",
            ],
            capture_output=True,
            text=True,
        )
        if pdf_process.returncode != 0:
            progress_callback("failed", 0.0, f"PDF processing failed: {pdf_process.stderr}")
            return False

        progress_callback("running", 0.3, "PDFs processed. Starting fine-tuning...")

        # Step 2: resolve a Hugging Face token (env var first, then cached login)
        hf_token = HF_TOKEN or HfFolder.get_token()
        if not hf_token:
            progress_callback("failed", 0.0, "No Hugging Face token found. Please set the HF_TOKEN environment variable.")
            return False

        # Step 3: fine-tune and push the resulting model to the Hub
        finetune_process = subprocess.run(
            [
                "python", "finetune_llama3.py",
                "--dataset_path", "processed_data/training_data",
                "--hub_model_id", f"Borislav18/{output_name}",
                "--epochs", "1",  # Starting with 1 epoch for quicker feedback
                "--gradient_accumulation_steps", "4",
            ],
            env={**os.environ, "HF_TOKEN": hf_token},
            capture_output=True,
            text=True,
        )
        if finetune_process.returncode != 0:
            progress_callback("failed", 0.0, f"Fine-tuning failed: {finetune_process.stderr}")
            return False

        progress_callback("success", 1.0, f"Training complete! Model pushed to Hugging Face as Borislav18/{output_name}")
        return True
    except Exception as e:
        # Catch-all boundary: surface the error to the UI instead of crashing
        progress_callback("failed", 0.0, f"Training process failed with error: {str(e)}")
        return False
def training_thread(pdf_dir, output_name):
    """Background worker: run the full training pipeline, reporting progress."""
    def report(status, progress, message):
        # Forward progress into the shared, persisted training state
        training_state.update(status, progress, message)

    # Immediate feedback so the UI shows activity before the first real step
    report("running", 0.01, "Starting training process...")
    run_training_process(pdf_dir, output_name, report)
def start_training(pdf_dir, output_name):
    """Launch training in a daemon thread.

    Returns a ``(message, progress, status)`` tuple for the UI outputs.
    """
    # Guard clause: both inputs are required
    if not (pdf_dir and output_name):
        return "Please provide both a PDF directory and output model name", 0.0, "idle"

    # Refuse to start a second run while one is already in flight
    if training_state.status == "running":
        return (
            f"Training already in progress: {training_state.message}",
            training_state.progress,
            training_state.status,
        )

    worker = threading.Thread(
        target=training_thread,
        args=(pdf_dir, output_name),
        daemon=True,
    )
    worker.start()
    return "Training started...", 0.0, "running"
def get_training_status():
    """Return ``(message, progress, status)`` from the shared training state."""
    state = training_state
    return state.message, state.progress, state.status
# Create the main application UI
with gr.Blocks(title="Shedify - Document Assistant") as demo:
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
        with gr.Column(scale=1):
            # Training controls
            with gr.Group(visible=True):
                gr.Markdown("## Train New Model")
                pdf_dir = gr.Textbox(label="PDF Directory", placeholder="Path to directory containing PDFs")
                output_name = gr.Textbox(label="Model Name", placeholder="Name for your fine-tuned model", value="Shedify-v1")
                train_btn = gr.Button("Start Training")
                training_message = gr.Textbox(label="Training Status", interactive=False)
                training_progress = gr.Slider(
                    minimum=0, maximum=1, value=0,
                    label="Progress", interactive=False
                )
                # Hidden component used only to trigger UI updates on change
                training_status = gr.Textbox(visible=False)

    # Chat interface backed by the fine-tuned model
    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            gr.Textbox(
                value="You are an AI assistant trained on specific documents. Answer questions based only on information from these documents. If you don't know the answer from the documents, say so clearly.",
                label="System message"
            ),
            gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
        examples=[
            ["Summarize the key points from all documents you were trained on."],
            ["What are the main themes discussed in the documents?"],
            ["Extract the most important concepts mentioned in the documents."],
            ["Explain the relationship between the different topics in the documents."],
            ["What recommendations or conclusions can be drawn from the documents?"],
        ]
    )

    # Start training when the button is clicked
    train_btn.click(
        fn=start_training,
        inputs=[pdf_dir, output_name],
        outputs=[training_message, training_progress, training_status]
    )

    # BUG FIX: the original called demo.add_event_handler(...), which is not a
    # Gradio API and raised at startup (the Space's "Runtime error"). The
    # supported way to poll is `every=` on the load event, which re-runs the
    # function on an interval (seconds).
    demo.load(
        get_training_status,
        outputs=[training_message, training_progress, training_status],
        every=5.0,
    )

    def update_ui(message, progress, status):
        """Colorize the status message and disable the button while running."""
        is_running = status == "running"
        color = {
            "idle": "gray",
            "running": "blue",
            "success": "green",
            "failed": "red"
        }.get(status, "gray")
        message_with_color = f"<span style='color: {color}'>{message}</span>"
        # BUG FIX: component .update() methods were removed in Gradio 4;
        # gr.update(...) works in both Gradio 3.x and 4.x.
        return message_with_color, progress, gr.update(interactive=not is_running)

    training_status.change(
        fn=update_ui,
        inputs=[training_message, training_progress, training_status],
        outputs=[training_message, training_progress, train_btn]
    )

if __name__ == "__main__":
    demo.launch()