# Tasmay-Tib's picture
# update handler.py
# fa3113f
from typing import Dict, List, Any
from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
import torch
import base64
import io
from PIL import Image
class EndpointHandler:
    """Inference handler that runs the Qwen-Image pipeline with a LoRA adapter.

    The LoRA is enabled at the start of every request and switched off by a
    step-end callback once the configured denoising step is reached.
    """

    def __init__(self, model_dir: str, **kwargs: Any):
        """Initialize the inference pipeline and load LoRA adapters.

        Args:
            model_dir: Location passed to ``load_lora_weights``; it must hold
                ``pytorch_lora_weights.safetensors``.
            **kwargs: Accepted for caller compatibility; not used here.
        """
        has_cuda = torch.cuda.is_available()
        self.device = "cuda" if has_cuda else "cpu"
        self.torch_dtype = torch.bfloat16 if has_cuda else torch.float32

        # The base model is loaded by id; only the LoRA comes from model_dir.
        model_id = "Qwen/Qwen-Image"
        self.pipe = QwenImagePipeline.from_pretrained(
            model_id,
            torch_dtype=self.torch_dtype,
        ).to(self.device)

        # Load LoRA
        self.pipe.load_lora_weights(
            model_dir,
            weight_name="pytorch_lora_weights.safetensors",
            adapter_name="default",
        )
        self.pipe.set_adapters(["default"], [1])

    def __call__(self, data: Dict[str, Any]) -> "Image.Image":
        """Generate one image for the given request.

        Args:
            data: Either ``{"inputs": {...}}``, ``{"inputs": "<prompt>"}`` or
                the bare parameter dict itself. Recognised parameters:
                - prompt: str (required, non-empty)
                - negative_prompt: str (optional, default "")
                - width: int (default 1024)
                - height: int (default 1024)
                - steps: int (default 50)
                - seed: int (default -1, meaning "pick a random seed")
                - lora_denoising_steps: int (default 10)

        Returns:
            The first image of the pipeline result (``result.images[0]``).

        Raises:
            ValueError: If the payload is neither a dict nor a string, or if
                the prompt is missing or empty.
        """
        # Fall back to the top-level dict when there is no "inputs" wrapper;
        # defaulting to "" here would make every later .get() call fail.
        payload = data.get("inputs", data)
        if isinstance(payload, str):
            payload = {"prompt": payload}
        if not isinstance(payload, dict):
            raise ValueError(
                "'inputs' must be a dict of generation parameters or a prompt string"
            )

        # Default to an empty prompt so that a missing prompt is rejected below.
        prompt = payload.get("prompt", "")
        if not prompt:
            raise ValueError("Prompt cannot be empty")

        negative_prompt = payload.get("negative_prompt", "")
        width = int(payload.get("width", 1024))
        height = int(payload.get("height", 1024))
        steps = int(payload.get("steps", 50))
        lora_denoising_steps = int(payload.get("lora_denoising_steps", 10))
        seed = int(payload.get("seed", -1))

        # Handle random seed
        if seed == -1:
            seed = int(torch.randint(0, 2**32 - 1, (1,)).item())

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            # Disable LoRA once the pipeline reports step index
            # `lora_denoising_steps`.
            # NOTE(review): if `i` is zero-based, the LoRA stays active for
            # lora_denoising_steps + 1 steps - confirm this is intended.
            if i == lora_denoising_steps:
                pipe.disable_lora()
            # Hand the (unmodified) kwargs back to the pipeline.
            return callback_kwargs

        # A previous request may have left the LoRA disabled, so re-enable it.
        self.pipe.enable_lora()

        # Run inference
        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            width=width,
            height=height,
            generator=torch.Generator(device=self.device).manual_seed(seed),
            callback_on_step_end=callback_on_step_end,
        )
        return result.images[0]
if __name__ == "__main__":
    # Example run. EndpointHandler requires `model_dir`; "." assumes the LoRA
    # file pytorch_lora_weights.safetensors sits in the current directory.
    handler = EndpointHandler(model_dir=".")
    # Parameters are wrapped in "inputs", the key the handler reads them from.
    result = handler({
        "inputs": {
            "prompt": "A futuristic Indian city skyline at sunset, vibrant colors",
            "width": 1024,
            "height": 1024,
            "steps": 30,
            "seed": -1,
            "lora_denoising_steps": 10,
        },
    })
    # The handler returns the first generated image; display it.
    result.show()