from typing import Any, Dict

import base64
import io

import torch
from diffusers import QwenImagePipeline
from PIL import Image


class EndpointHandler:

    def __init__(self, model_dir: str, **kwargs: Any):
        """
        Initialize the inference pipeline and load LoRA adapters.
        """
        # Prefer GPU with bfloat16; fall back to CPU with float32.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        model_id = "Qwen/Qwen-Image"

        self.pipe = QwenImagePipeline.from_pretrained(
            model_id,
            torch_dtype=self.torch_dtype,
        ).to(self.device)

        # Load the fine-tuned LoRA weights from the endpoint's model
        # directory and activate the adapter at full strength.
        self.pipe.load_lora_weights(
            model_dir,
            weight_name="pytorch_lora_weights.safetensors",
            adapter_name="default",
        )
        self.pipe.set_adapters(["default"], adapter_weights=[1.0])
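        # Note: on memory-constrained GPUs, `self.pipe.enable_model_cpu_offload()`
        # (a generic diffusers utility, assumed to apply to this pipeline too)
        # could replace the `.to(self.device)` call above to trade speed for
        # lower VRAM usage.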

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """
        Perform image generation using the model.

        Args:
            data: dict with keys:
                - prompt: str
                - negative_prompt: str (optional)
                - width: int
                - height: int
                - steps: int
                - seed: int (-1 selects a random seed)
                - lora_denoising_steps: int (step index at which the LoRA
                  adapter is switched off)

        Returns:
            The generated image as a PIL Image.
        """
        # Inference Endpoints wrap the payload under an "inputs" key;
        # fall back to the raw dict so direct calls also work.
        data = data.get("inputs", data)
        prompt = data.get("prompt", "")
        negative_prompt = data.get("negative_prompt", "")
        width = int(data.get("width", 1024))
        height = int(data.get("height", 1024))
        steps = int(data.get("steps", 50))
        lora_denoising_steps = int(data.get("lora_denoising_steps", 10))
        seed = int(data.get("seed", -1))

        if not prompt:
            raise ValueError("Prompt cannot be empty")

        # After `lora_denoising_steps` steps, switch the LoRA adapter off so
        # the base model handles the remaining denoising steps on its own.
        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == lora_denoising_steps:
                pipe.disable_lora()
            return callback_kwargs

        # Re-enable the LoRA at the start of each call, since the callback
        # may have disabled it during a previous request.
        self.pipe.enable_lora()

        if seed == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            width=width,
            height=height,
            generator=torch.Generator(device=self.device).manual_seed(seed),
            callback_on_step_end=callback_on_step_end,
        )

        return result.images[0]
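

# Illustrative helper (a minimal sketch): `base64` and `io` are imported
# above, and a base64 string is what a JSON endpoint response would
# typically carry, so this shows one way to serialize the returned PIL
# image. The function name is hypothetical.
def pil_image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64-encoded PNG string."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")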


if __name__ == "__main__":
    # Quick local smoke test; model_dir should point at the directory that
    # holds pytorch_lora_weights.safetensors ("." is just a placeholder).
    handler = EndpointHandler(model_dir=".")
    result = handler({
        "prompt": "A futuristic Indian city skyline at sunset, vibrant colors",
        "width": 1024,
        "height": 1024,
        "steps": 30,
        "seed": -1,
        "lora_denoising_steps": 10,
    })

    result.show()
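# When deployed as an endpoint, the same fields arrive nested under "inputs",
# which __call__ also accepts, e.g.:
#   handler({"inputs": {"prompt": "...", "steps": 30}})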