# app.py - AI Background Changer API
from fastapi import FastAPI, File, UploadFile, BackgroundTasks, HTTPException
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import numpy as np
import os
import uuid
import cv2
import torch
from PIL import Image, ImageDraw
import io
import json
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from diffusers import StableDiffusionXLInpaintPipeline
# ======================================================
# CONFIGURATION
# ======================================================
app = FastAPI(title="AI Background Changer API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ======================================================
# DIRECTORIES
# ======================================================
UPLOAD_DIR = "uploads"
RESULT_DIR = "results"
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

MODEL_DIR = "models"
SAM_CHECKPOINT = os.path.join(MODEL_DIR, "sam_vit_b_01ec64.pth")

# ======================================================
# GLOBAL STATE (in-memory session storage)
# ======================================================
sessions = {}
class SessionData:
    def __init__(self):
        self.original_image = None   # RGB numpy array
        self.current_mask = None     # boolean mask from SAM
        self.rgba_image = None       # original image with the mask as alpha channel
        self.image_set = False
        self.box_points = []         # up to two [x, y] corners of a box prompt
        self.positive_points = []    # [x, y] foreground clicks
        self.negative_points = []    # [x, y] background clicks
# ======================================================
# MODELS
# ======================================================
sam_predictor = None
inpaint_pipeline = None

def load_sam_model():
    """Lazily load SAM (ViT-B) and cache the predictor."""
    global sam_predictor
    if sam_predictor is None:
        if not os.path.exists(SAM_CHECKPOINT):
            raise FileNotFoundError(f"SAM model not found at {SAM_CHECKPOINT}")
        sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        sam.to(device=device)
        sam_predictor = SamPredictor(sam)
        print(f"✅ SAM loaded on {device}")
    return sam_predictor
def load_inpaint_pipeline():
    """Lazily load the SDXL inpainting pipeline and cache it."""
    global inpaint_pipeline
    if inpaint_pipeline is None:
        print("🎨 Loading Stable Diffusion XL Inpainting...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 weights require a GPU; fall back to fp32 on CPU
        dtype = torch.float16 if device == "cuda" else torch.float32
        variant = "fp16" if device == "cuda" else None
        try:
            inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
                "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                torch_dtype=dtype,
                variant=variant,
                use_safetensors=True
            )
            inpaint_pipeline = inpaint_pipeline.to(device)
            print(f"✅ Inpainting model loaded on {device}")
        except Exception as e:
            print("❌ Failed to load SDXL Inpainting:", str(e))
            inpaint_pipeline = None
    return inpaint_pipeline
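
# Optional warm-up (a sketch, not part of the original flow): loading SAM at
# startup avoids the first-request delay. SDXL is left lazy because it is far
# heavier and may not fit in memory on smaller machines.
@app.on_event("startup")
async def warm_models():
    try:
        load_sam_model()
    except FileNotFoundError as e:
        # A missing checkpoint should not prevent the API from starting
        print(f"⚠️ SAM warm-up skipped: {e}")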
# ======================================================
# PYDANTIC MODELS
# ======================================================
class ClickRequest(BaseModel):
    session_id: str
    x: int
    y: int
    mode: str  # "box", "positive", or "negative"

class MaskRequest(BaseModel):
    session_id: str

class BackgroundRequest(BaseModel):
    session_id: str
    background_prompt: str
    negative_prompt: Optional[str] = "low quality, blurry"
    guidance_scale: Optional[float] = 7.0
    num_steps: Optional[int] = 30
    seed: Optional[int] = 42
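
# Example request bodies (illustrative values):
#   POST /click              {"session_id": "<id>", "x": 120, "y": 340, "mode": "positive"}
#   POST /generate_mask      {"session_id": "<id>"}
#   POST /generate_background
#       {"session_id": "<id>", "background_prompt": "a sunny beach at sunset",
#        "guidance_scale": 7.0, "num_steps": 30, "seed": 42}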
# ======================================================
# HELPER FUNCTIONS
# ======================================================
async def cleanup_files(paths: list):
    """Delete temporary result files after the response has been sent."""
    for path in paths:
        if os.path.exists(path):
            os.remove(path)
            print(f"🧹 File deleted: {path}")

def get_session(session_id: str) -> SessionData:
    if session_id not in sessions:
        sessions[session_id] = SessionData()
    return sessions[session_id]

def pil_to_bytes(img: Image.Image, fmt: str = "PNG") -> bytes:
    buf = io.BytesIO()
    img.save(buf, format=fmt)
    buf.seek(0)
    return buf.getvalue()
# ======================================================
# ENDPOINTS
# ======================================================
@app.get("/")
def home():
    return {
        "message": "✅ AI Background Changer API is running!",
        "endpoints": {
            "upload": "POST /upload - Upload image",
            "click": "POST /click - Add prompt point",
            "generate_mask": "POST /generate_mask - Generate segmentation mask",
            "auto_segment": "POST /auto_segment - Auto segment all objects",
            "auto_smart_select": "POST /auto_smart_select - Auto select foreground",
            "reset": "POST /reset - Reset session prompts",
            "generate_background": "POST /generate_background - Change background",
            "get_preview": "GET /preview/{session_id} - Get current preview"
        }
    }
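
# Typical client flow (illustrative; host and port depend on your deployment):
#   curl -F "file=@photo.jpg" http://localhost:8000/upload
#   curl -X POST http://localhost:8000/click \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": "<id>", "x": 200, "y": 150, "mode": "positive"}'
#   curl -X POST http://localhost:8000/generate_mask \
#        -H "Content-Type: application/json" -d '{"session_id": "<id>"}'
#   curl -X POST http://localhost:8000/generate_background \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": "<id>", "background_prompt": "a mountain lake"}' \
#        -o result.png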
@app.post("/upload")
async def upload_image(file: UploadFile = File(...)):
    """Upload image and initialize SAM"""
    try:
        # Create session
        session_id = str(uuid.uuid4())
        session = get_session(session_id)

        # Read and decode image
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Store in session
        session.original_image = image.copy()
        session.image_set = False

        # Load SAM and compute the image embedding
        predictor = load_sam_model()
        predictor.set_image(image)
        session.image_set = True

        print(f"✅ Image uploaded: {image.shape}, session: {session_id}")
        return {
            "session_id": session_id,
            "width": int(image.shape[1]),
            "height": int(image.shape[0]),
            "message": "Image uploaded successfully"
        }
    except HTTPException:
        # Re-raise intentional HTTP errors instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/click")
async def add_click(request: ClickRequest):
    """Add click point (box, positive, or negative)"""
    try:
        session = get_session(request.session_id)
        if not session.image_set:
            raise HTTPException(status_code=400, detail="Upload image first")

        if request.mode == "box":
            # A box is defined by two corner clicks; a third click starts a new box
            if len(session.box_points) < 2:
                session.box_points.append([request.x, request.y])
            else:
                session.box_points = [[request.x, request.y]]
        elif request.mode == "positive":
            session.positive_points.append([request.x, request.y])
        elif request.mode == "negative":
            session.negative_points.append([request.x, request.y])

        # Create preview with points
        img_pil = Image.fromarray(session.original_image.copy())
        draw = ImageDraw.Draw(img_pil)

        # Draw box (blue)
        if len(session.box_points) > 0:
            if len(session.box_points) == 1:
                x, y = session.box_points[0]
                draw.ellipse([x-8, y-8, x+8, y+8], fill=(0, 150, 255), outline=(255, 255, 255), width=3)
            else:
                x1, y1 = session.box_points[0]
                x2, y2 = session.box_points[1]
                # PIL requires ordered corners (left-top before right-bottom)
                box_xy = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]
                draw.rectangle(box_xy, outline=(0, 150, 255), width=4)
                draw.ellipse([x1-8, y1-8, x1+8, y1+8], fill=(0, 150, 255), outline=(255, 255, 255), width=3)
                draw.ellipse([x2-8, y2-8, x2+8, y2+8], fill=(0, 150, 255), outline=(255, 255, 255), width=3)

        # Draw positive points (green)
        for px, py in session.positive_points:
            draw.ellipse([px-10, py-10, px+10, py+10], fill=(0, 255, 0), outline=(255, 255, 255), width=3)

        # Draw negative points (red)
        for nx, ny in session.negative_points:
            draw.ellipse([nx-10, ny-10, nx+10, ny+10], fill=(255, 0, 0), outline=(255, 255, 255), width=3)

        img_bytes = pil_to_bytes(img_pil)
        return StreamingResponse(io.BytesIO(img_bytes), media_type="image/png")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate_mask")
async def generate_mask(request: MaskRequest):
    """Generate segmentation mask from prompts"""
    try:
        session = get_session(request.session_id)
        if not session.image_set:
            raise HTTPException(status_code=400, detail="Upload image first")

        has_box = len(session.box_points) == 2
        has_points = len(session.positive_points) > 0
        if not has_box and not has_points:
            raise HTTPException(status_code=400, detail="Add prompts first (box or points)")

        predictor = load_sam_model()
        # Re-set this session's image: the predictor is a shared global, so
        # another session's upload may have replaced its embedding
        predictor.set_image(session.original_image)

        box = None
        points = None
        labels = None

        # Box prompt, normalized to [x_min, y_min, x_max, y_max]
        if has_box:
            x1, y1 = session.box_points[0]
            x2, y2 = session.box_points[1]
            box = np.array([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])

        # Point prompts: label 1 = foreground, 0 = background
        if has_points or session.negative_points:
            point_list = []
            label_list = []
            for px, py in session.positive_points:
                point_list.append([px, py])
                label_list.append(1)
            for nx, ny in session.negative_points:
                point_list.append([nx, ny])
                label_list.append(0)
            points = np.array(point_list)
            labels = np.array(label_list)

        # Predict and keep the highest-scoring candidate mask
        masks, scores, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            box=box,
            multimask_output=True
        )
        best_idx = np.argmax(scores)
        mask = masks[best_idx]
        score = scores[best_idx]
        session.current_mask = mask

        # Create RGBA: the mask becomes the alpha channel
        image = session.original_image
        rgba_image = np.dstack((image, (mask * 255).astype(np.uint8)))
        session.rgba_image = rgba_image
        segmented_pil = Image.fromarray(rgba_image)

        # Statistics
        mask_area = mask.sum()
        total_area = image.shape[0] * image.shape[1]
        area_percentage = (mask_area / total_area) * 100

        # Return segmented image
        img_bytes = pil_to_bytes(segmented_pil)
        return StreamingResponse(
            io.BytesIO(img_bytes),
            media_type="image/png",
            headers={
                "X-Score": str(float(score)),
                "X-Area-Percentage": str(float(area_percentage)),
                "X-Method": "box" if has_box else "points"
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/auto_segment")
async def auto_segment_all(request: MaskRequest):
    """Auto segment and show all objects"""
    try:
        session = get_session(request.session_id)
        if not session.image_set:
            raise HTTPException(status_code=400, detail="Upload image first")

        predictor = load_sam_model()
        sam_model = predictor.model
        mask_generator = SamAutomaticMaskGenerator(
            model=sam_model,
            points_per_side=32,
            pred_iou_thresh=0.86,
            stability_score_thresh=0.92,
            min_mask_region_area=100
        )
        masks = mask_generator.generate(session.original_image)
        masks = sorted(masks, key=lambda x: x['area'], reverse=True)

        # Create visualization: color the nine largest objects
        image = session.original_image.copy()
        h, w = image.shape[:2]
        overlay = np.zeros((h, w, 3), dtype=np.uint8)
        colors = [
            [255, 0, 0], [0, 255, 0], [0, 0, 255],
            [255, 255, 0], [255, 0, 255], [0, 255, 255],
            [128, 0, 0], [0, 128, 0], [0, 0, 128]
        ]
        objects_info = []
        for i, mask_data in enumerate(masks[:9]):
            mask = mask_data['segmentation']
            color = colors[i % len(colors)]
            overlay[mask] = color
            area_pct = (mask_data['area'] / (h * w)) * 100
            objects_info.append({
                "index": i + 1,
                "area_percentage": float(area_pct),
                "color": color
            })

        # Blend the colored overlay onto the original image
        result = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)
        result_pil = Image.fromarray(result)
        img_bytes = pil_to_bytes(result_pil)
        return StreamingResponse(
            io.BytesIO(img_bytes),
            media_type="image/png",
            headers={
                "X-Objects-Found": str(len(masks)),
                "X-Objects-Info": json.dumps(objects_info)
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/auto_smart_select")
async def auto_smart_select(request: MaskRequest):
    """Automatically select the best foreground object"""
    try:
        session = get_session(request.session_id)
        if not session.image_set:
            raise HTTPException(status_code=400, detail="Upload image first")

        predictor = load_sam_model()
        sam_model = predictor.model
        mask_generator = SamAutomaticMaskGenerator(
            model=sam_model,
            points_per_side=32,
            pred_iou_thresh=0.88,
            stability_score_thresh=0.93,
            min_mask_region_area=500
        )
        masks = mask_generator.generate(session.original_image)
        if len(masks) == 0:
            raise HTTPException(status_code=400, detail="No objects detected")

        # Smart selection: prefer reasonably sized, stable masks near the center
        image = session.original_image
        h, w = image.shape[:2]
        center_x, center_y = w // 2, h // 2

        def score_mask(mask_data):
            mask = mask_data['segmentation']
            area = mask_data['area']
            stability = mask_data['stability_score']
            y_coords, x_coords = np.where(mask)
            if len(x_coords) == 0:
                return 0

            # How close the mask centroid is to the image center (1 = dead center)
            mask_center_x = x_coords.mean()
            mask_center_y = y_coords.mean()
            dist_from_center = np.sqrt((mask_center_x - center_x)**2 + (mask_center_y - center_y)**2)
            max_dist = np.sqrt(center_x**2 + center_y**2)
            center_score = 1 - (dist_from_center / max_dist)

            # Penalize near-full-frame masks (likely background) and tiny ones
            area_ratio = area / (h * w)
            if area_ratio > 0.8:
                size_score = 0.1
            elif area_ratio < 0.02:
                size_score = 0.3
            else:
                size_score = min(area_ratio * 5, 1.0)

            return center_score * 0.4 + size_score * 0.4 + stability * 0.2

        scored_masks = [(score_mask(m), m) for m in masks]
        scored_masks.sort(reverse=True, key=lambda x: x[0])
        best_score, best_mask_data = scored_masks[0]
        best_mask = best_mask_data['segmentation']
        session.current_mask = best_mask

        rgba_image = np.dstack((image, (best_mask * 255).astype(np.uint8)))
        session.rgba_image = rgba_image
        segmented_pil = Image.fromarray(rgba_image)

        mask_area = best_mask.sum()
        area_percentage = (mask_area / (h * w)) * 100

        img_bytes = pil_to_bytes(segmented_pil)
        return StreamingResponse(
            io.BytesIO(img_bytes),
            media_type="image/png",
            headers={
                "X-Selection-Score": str(float(best_score)),
                "X-Stability": str(float(best_mask_data['stability_score'])),
                "X-Area-Percentage": str(float(area_percentage)),
                "X-Total-Objects": str(len(masks))
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/reset")
async def reset_prompts(request: MaskRequest):
    """Reset all prompts but keep the image"""
    try:
        session = get_session(request.session_id)
        session.box_points = []
        session.positive_points = []
        session.negative_points = []
        session.current_mask = None
        return {"message": "Session reset successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate_background")
async def generate_background(
    background_tasks: BackgroundTasks,
    request: BackgroundRequest
):
    """Generate new background using Stable Diffusion XL Inpainting"""
    try:
        session = get_session(request.session_id)
        if session.rgba_image is None:
            raise HTTPException(status_code=400, detail="Generate mask first")

        print(f"🎨 Generating background: {request.background_prompt}")

        # Load inpainting model
        pipeline = load_inpaint_pipeline()
        if pipeline is None:
            raise HTTPException(status_code=503, detail="Inpainting model unavailable")

        # Prepare images: inpaint where the subject is NOT (alpha below 128)
        rgba_image = Image.fromarray(session.rgba_image)
        alpha = np.array(rgba_image.split()[-1])
        mask = np.where(alpha < 128, 255, 0).astype(np.uint8)
        mask = Image.fromarray(mask).convert("L")
        rgb_image = rgba_image.convert("RGB")

        # SDXL requires dimensions divisible by 8; snap down if needed
        width = rgb_image.width // 8 * 8
        height = rgb_image.height // 8 * 8

        # Generate
        generator = torch.manual_seed(int(request.seed))
        result = pipeline(
            image=rgb_image,
            mask_image=mask,
            prompt=request.background_prompt,
            negative_prompt=request.negative_prompt,
            guidance_scale=request.guidance_scale,
            generator=generator,
            num_inference_steps=int(request.num_steps),
            width=width,
            height=height
        ).images[0]

        # Save result
        result_id = str(uuid.uuid4())
        result_path = os.path.join(RESULT_DIR, f"{result_id}.png")
        result.save(result_path)

        # Schedule cleanup after the response is sent
        background_tasks.add_task(cleanup_files, [result_path])

        return FileResponse(
            result_path,
            media_type="image/png",
            headers={
                "X-Prompt": request.background_prompt,
                "X-Result-ID": result_id
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
@app.get("/preview/{session_id}")
async def get_preview(session_id: str):
    """Get current image preview"""
    try:
        session = get_session(session_id)
        if session.original_image is None:
            raise HTTPException(status_code=400, detail="No image uploaded")
        img_pil = Image.fromarray(session.original_image)
        img_bytes = pil_to_bytes(img_pil)
        return StreamingResponse(io.BytesIO(img_bytes), media_type="image/png")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "sam_loaded": sam_predictor is not None,
        "inpaint_loaded": inpaint_pipeline is not None,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }
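
# Local entry point (a minimal sketch; host and port are assumptions, adjust
# for your deployment — a Hugging Face Space would typically use port 7860):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)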