from typing import Any, Dict

import torch
from diffusers import QwenImagePipeline
from PIL import Image


class EndpointHandler:
    def __init__(self, model_dir: str, **kwargs: Any):
        """
        Initialize the Qwen-Image pipeline and load the LoRA adapter
        shipped alongside the endpoint (pytorch_lora_weights.safetensors).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        model_id = "Qwen/Qwen-Image"
        self.pipe = QwenImagePipeline.from_pretrained(
            model_id,
            torch_dtype=self.torch_dtype,
        ).to(self.device)

        # Load the LoRA adapter from the endpoint's model directory.
        self.pipe.load_lora_weights(
            model_dir,
            weight_name="pytorch_lora_weights.safetensors",
            adapter_name="default",
        )
        self.pipe.set_adapters(["default"], adapter_weights=[1.0])

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """
        Generate an image from a text prompt.

        Args:
            data: dict, either a flat parameter dict or one wrapped in an
                "inputs" key (the Inference Endpoints payload format), with:
                - prompt: str
                - negative_prompt: str (optional)
                - width: int (optional, default 1024)
                - height: int (optional, default 1024)
                - steps: int (optional, default 50)
                - seed: int (optional, -1 for a random seed)
                - lora_denoising_steps: int (optional, default 10); the
                  denoising step after which the LoRA is disabled

        Returns:
            The generated image as a PIL.Image.Image.
        """
        # Accept both {"inputs": {...}} payloads and a flat parameter dict.
        data = data.get("inputs", data)
        prompt = data.get("prompt", "")
        negative_prompt = data.get("negative_prompt", "")
        width = int(data.get("width", 1024))
        height = int(data.get("height", 1024))
        steps = int(data.get("steps", 50))
        lora_denoising_steps = int(data.get("lora_denoising_steps", 10))
        seed = int(data.get("seed", -1))

        if not prompt:
            raise ValueError("Prompt cannot be empty")

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            # Disable the LoRA once the specified denoising step is reached,
            # so the remaining steps run with the base model only.
            if i == lora_denoising_steps:
                pipe.disable_lora()
            return callback_kwargs

        # Re-enable the LoRA at the start of each request, since the
        # callback may have disabled it during a previous generation.
        self.pipe.enable_lora()

        # Draw a random seed when the caller passes -1.
        if seed == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()

        # Run inference.
        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            width=width,
            height=height,
            generator=torch.Generator(device=self.device).manual_seed(seed),
            callback_on_step_end=callback_on_step_end,
        )

        return result.images[0]


if __name__ == "__main__":
    # Example run; "." is assumed to be the directory containing
    # pytorch_lora_weights.safetensors.
    handler = EndpointHandler(model_dir=".")
    image = handler({
        "prompt": "A futuristic Indian city skyline at sunset, vibrant colors",
        "width": 1024,
        "height": 1024,
        "steps": 30,
        "seed": -1,
        "lora_denoising_steps": 10,
    })
    image.show()