# app.py - AI Background Changer API
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks, HTTPException
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List
import numpy as np
import os
import uuid
import cv2
import torch
from PIL import Image, ImageDraw
import io
import json

from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from diffusers import StableDiffusionXLInpaintPipeline

# ======================================================
# CONFIGURATION
# ======================================================
app = FastAPI(title="AI Background Changer API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ======================================================
# DIRECTORIES
# ======================================================
UPLOAD_DIR = "uploads"
RESULT_DIR = "results"
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

MODEL_DIR = "models"
SAM_CHECKPOINT = os.path.join(MODEL_DIR, "sam_vit_b_01ec64.pth")

# ======================================================
# GLOBAL STATE (In-memory session storage)
# ======================================================
sessions = {}

class SessionData:
    def __init__(self):
        self.original_image = None
        self.current_mask = None
        self.rgba_image = None
        self.image_set = False
        self.box_points = []
        self.positive_points = []
        self.negative_points = []

# ======================================================
# MODELS
# ======================================================
sam_predictor = None
inpaint_pipeline = None

def load_sam_model():
    global sam_predictor
    if sam_predictor is None:
        if not os.path.exists(SAM_CHECKPOINT):
            raise FileNotFoundError(f"SAM model not found at {SAM_CHECKPOINT}")
        sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        sam.to(device=device)
        sam_predictor = SamPredictor(sam)
        print(f"✅ SAM loaded on {device}")
    return sam_predictor

def load_inpaint_pipeline():
    global inpaint_pipeline
    if inpaint_pipeline is None:
        print("🎨 Loading Stable Diffusion XL Inpainting...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 weights cannot be loaded on CPU; fall back to fp32 there
        dtype = torch.float16 if device == "cuda" else torch.float32
        variant = "fp16" if device == "cuda" else None
        try:
            inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
                "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                torch_dtype=dtype,
                variant=variant,
                use_safetensors=True
            )
            inpaint_pipeline = inpaint_pipeline.to(device)
            print(f"✅ Inpainting model loaded on {device}")
        except Exception as e:
            print("❌ Failed to load SDXL Inpainting:", str(e))
            inpaint_pipeline = None
    return inpaint_pipeline
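
# Optional warm-up, a minimal sketch assuming FastAPI's on_event hook: preload
# SAM at startup so the first /upload doesn't pay the model-load cost. SDXL
# stays lazy on purpose, since loading it eagerly can exhaust low-memory hosts.
@app.on_event("startup")
async def warm_models():
    try:
        load_sam_model()
    except FileNotFoundError as e:
        print(f"❌ SAM not preloaded: {e}")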

# ======================================================
# PYDANTIC MODELS
# ======================================================
class ClickRequest(BaseModel):
    session_id: str
    x: int
    y: int
    mode: str  # "box", "positive", "negative"

class MaskRequest(BaseModel):
    session_id: str

class BackgroundRequest(BaseModel):
    session_id: str
    background_prompt: str
    negative_prompt: Optional[str] = "low quality, blurry"
    guidance_scale: Optional[float] = 7.0
    num_steps: Optional[int] = 30
    seed: Optional[int] = 42

# ======================================================
# HELPER FUNCTIONS
# ======================================================
async def cleanup_files(paths: list):
    for path in paths:
        if os.path.exists(path):
            os.remove(path)
            print(f"🧹 File deleted: {path}")

def get_session(session_id: str) -> SessionData:
    if session_id not in sessions:
        sessions[session_id] = SessionData()
    return sessions[session_id]

def pil_to_bytes(img: Image.Image, format="PNG") -> bytes:
    buf = io.BytesIO()
    img.save(buf, format=format)
    buf.seek(0)
    return buf.getvalue()

# ======================================================
# ENDPOINTS
# ======================================================
@app.get("/")
def home():
    return {
        "message": "✅ AI Background Changer API is running!",
        "endpoints": {
            "upload": "POST /upload - Upload image",
            "click": "POST /click - Add prompt point",
            "generate_mask": "POST /generate_mask - Generate segmentation mask",
            "auto_segment": "POST /auto_segment - Auto segment all objects",
            "auto_smart_select": "POST /auto_smart_select - Auto select foreground",
            "reset": "POST /reset - Reset session prompts",
            "generate_background": "POST /generate_background - Change background",
            "get_preview": "GET /preview/{session_id} - Get current preview"
        }
    }

@app.post("/upload")
async def upload_image(file: UploadFile = File(...)):
    """Upload image and initialize SAM"""
    try:
        # Create session
        session_id = str(uuid.uuid4())
        session = get_session(session_id)

        # Read image
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Store in session
        session.original_image = image.copy()
        session.image_set = False

        # Load SAM and embed the image. Note: the predictor is a shared global,
        # so its stored embedding belongs to whichever session uploaded last.
        predictor = load_sam_model()
        predictor.set_image(image)
        session.image_set = True

        print(f"✅ Image uploaded: {image.shape}, session: {session_id}")
        return {
            "session_id": session_id,
            "width": int(image.shape[1]),
            "height": int(image.shape[0]),
            "message": "Image uploaded successfully"
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
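
# Example upload (assumes the server is reachable on localhost:8000):
#   curl -F "file=@photo.jpg" http://localhost:8000/upload
# The response carries the session_id that every other endpoint requires, e.g.
#   {"session_id": "...", "width": 1024, "height": 768, "message": "..."}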

@app.post("/click")
async def add_click(request: ClickRequest):
    """Add click point (box, positive, or negative)"""
    try:
        session = get_session(request.session_id)
        if not session.image_set:
            raise HTTPException(status_code=400, detail="Upload image first")

        if request.mode == "box":
            if len(session.box_points) < 2:
                session.box_points.append([request.x, request.y])
            else:
                # A third box click starts a new box
                session.box_points = [[request.x, request.y]]
        elif request.mode == "positive":
            session.positive_points.append([request.x, request.y])
        elif request.mode == "negative":
            session.negative_points.append([request.x, request.y])
        else:
            raise HTTPException(status_code=400,
                                detail="mode must be 'box', 'positive', or 'negative'")

        # Create preview with points
        img_pil = Image.fromarray(session.original_image.copy())
        draw = ImageDraw.Draw(img_pil)

        # Draw box (blue)
        if len(session.box_points) > 0:
            if len(session.box_points) == 1:
                x, y = session.box_points[0]
                draw.ellipse([x-8, y-8, x+8, y+8], fill=(0, 150, 255),
                             outline=(255, 255, 255), width=3)
            else:
                x1, y1 = session.box_points[0]
                x2, y2 = session.box_points[1]
                draw.rectangle([x1, y1, x2, y2], outline=(0, 150, 255), width=4)
                draw.ellipse([x1-8, y1-8, x1+8, y1+8], fill=(0, 150, 255),
                             outline=(255, 255, 255), width=3)
                draw.ellipse([x2-8, y2-8, x2+8, y2+8], fill=(0, 150, 255),
                             outline=(255, 255, 255), width=3)

        # Draw positive points (green)
        for px, py in session.positive_points:
            draw.ellipse([px-10, py-10, px+10, py+10], fill=(0, 255, 0),
                         outline=(255, 255, 255), width=3)

        # Draw negative points (red)
        for nx, ny in session.negative_points:
            draw.ellipse([nx-10, ny-10, nx+10, ny+10], fill=(255, 0, 0),
                         outline=(255, 255, 255), width=3)

        img_bytes = pil_to_bytes(img_pil)
        return StreamingResponse(io.BytesIO(img_bytes), media_type="image/png")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/generate_mask")
async def generate_mask(request: MaskRequest):
    """Generate segmentation mask from prompts"""
    try:
        session = get_session(request.session_id)
        if not session.image_set:
            raise HTTPException(status_code=400, detail="Upload image first")

        has_box = len(session.box_points) == 2
        has_points = len(session.positive_points) > 0
        if not has_box and not has_points:
            raise HTTPException(status_code=400, detail="Add prompts first (box or points)")

        predictor = load_sam_model()
        # Re-embed this session's image: the predictor is a shared global, so
        # another session's upload may have replaced the stored embedding
        predictor.set_image(session.original_image)

        box = None
        points = None
        labels = None

        # Box prompt: SAM expects [x1, y1, x2, y2] with x1 <= x2 and y1 <= y2
        if has_box:
            x1, y1 = session.box_points[0]
            x2, y2 = session.box_points[1]
            box = np.array([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])

        # Point prompts: label 1 = foreground, label 0 = background
        if has_points or session.negative_points:
            point_list = []
            label_list = []
            for px, py in session.positive_points:
                point_list.append([px, py])
                label_list.append(1)
            for nx, ny in session.negative_points:
                point_list.append([nx, ny])
                label_list.append(0)
            points = np.array(point_list)
            labels = np.array(label_list)

        # Predict and keep the highest-scoring of the candidate masks
        masks, scores, logits = predictor.predict(
            point_coords=points,
            point_labels=labels,
            box=box,
            multimask_output=True
        )
        best_idx = np.argmax(scores)
        mask = masks[best_idx]
        score = scores[best_idx]
        session.current_mask = mask

        # Create RGBA cutout: original RGB plus the mask as the alpha channel
        image = session.original_image
        rgba_image = np.dstack((image, (mask * 255).astype(np.uint8)))
        session.rgba_image = rgba_image
        segmented_pil = Image.fromarray(rgba_image)

        # Statistics
        mask_area = mask.sum()
        total_area = image.shape[0] * image.shape[1]
        area_percentage = (mask_area / total_area) * 100

        # Return segmented image
        img_bytes = pil_to_bytes(segmented_pil)
        return StreamingResponse(
            io.BytesIO(img_bytes),
            media_type="image/png",
            headers={
                "X-Score": str(float(score)),
                "X-Area-Percentage": str(float(area_percentage)),
                "X-Method": "box" if has_box else "points"
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
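
# Sketch of the interactive flow in Python (assumes the `requests` package and
# a local server; `sid` is the session_id returned by /upload):
#   requests.post(f"{URL}/click", json={"session_id": sid, "x": 200, "y": 150,
#                                       "mode": "positive"})
#   r = requests.post(f"{URL}/generate_mask", json={"session_id": sid})
#   open("cutout.png", "wb").write(r.content)   # RGBA cutout
#   print(r.headers["X-Score"])                 # SAM's confidence for the mask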

@app.post("/auto_segment")
async def auto_segment_all(request: MaskRequest):
    """Auto segment and show all objects"""
    try:
        session = get_session(request.session_id)
        if not session.image_set:
            raise HTTPException(status_code=400, detail="Upload image first")

        predictor = load_sam_model()
        sam_model = predictor.model
        mask_generator = SamAutomaticMaskGenerator(
            model=sam_model,
            points_per_side=32,
            pred_iou_thresh=0.86,
            stability_score_thresh=0.92,
            min_mask_region_area=100
        )
        masks = mask_generator.generate(session.original_image)
        masks = sorted(masks, key=lambda x: x['area'], reverse=True)

        # Create visualization: tint up to the nine largest objects
        image = session.original_image.copy()
        h, w = image.shape[:2]
        overlay = np.zeros((h, w, 3), dtype=np.uint8)
        colors = [
            [255, 0, 0], [0, 255, 0], [0, 0, 255],
            [255, 255, 0], [255, 0, 255], [0, 255, 255],
            [128, 0, 0], [0, 128, 0], [0, 0, 128]
        ]

        objects_info = []
        for i, mask_data in enumerate(masks[:9]):
            mask = mask_data['segmentation']
            color = colors[i % len(colors)]
            overlay[mask] = color
            area_pct = (mask_data['area'] / (h * w)) * 100
            objects_info.append({
                "index": i + 1,
                "area_percentage": float(area_pct),
                "color": color
            })

        result = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)
        result_pil = Image.fromarray(result)
        img_bytes = pil_to_bytes(result_pil)
        return StreamingResponse(
            io.BytesIO(img_bytes),
            media_type="image/png",
            headers={
                "X-Objects-Found": str(len(masks)),
                "X-Objects-Info": json.dumps(objects_info)
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/auto_smart_select")
async def auto_smart_select(request: MaskRequest):
    """Automatically select the best foreground object"""
    try:
        session = get_session(request.session_id)
        if not session.image_set:
            raise HTTPException(status_code=400, detail="Upload image first")

        predictor = load_sam_model()
        sam_model = predictor.model
        mask_generator = SamAutomaticMaskGenerator(
            model=sam_model,
            points_per_side=32,
            pred_iou_thresh=0.88,
            stability_score_thresh=0.93,
            min_mask_region_area=500
        )
        masks = mask_generator.generate(session.original_image)
        if len(masks) == 0:
            raise HTTPException(status_code=400, detail="No objects detected")

        # Smart selection: prefer reasonably sized, centered, stable masks
        image = session.original_image
        h, w = image.shape[:2]
        center_x, center_y = w // 2, h // 2

        def score_mask(mask_data):
            mask = mask_data['segmentation']
            area = mask_data['area']
            stability = mask_data['stability_score']
            y_coords, x_coords = np.where(mask)
            if len(x_coords) == 0:
                return 0
            mask_center_x = x_coords.mean()
            mask_center_y = y_coords.mean()
            # Closer to the image center => higher score
            dist_from_center = np.sqrt((mask_center_x - center_x)**2 +
                                       (mask_center_y - center_y)**2)
            max_dist = np.sqrt(center_x**2 + center_y**2)
            center_score = 1 - (dist_from_center / max_dist)
            # Penalize near-full-frame masks (likely background) and tiny specks
            area_ratio = area / (h * w)
            if area_ratio > 0.8:
                size_score = 0.1
            elif area_ratio < 0.02:
                size_score = 0.3
            else:
                size_score = min(area_ratio * 5, 1.0)
            total_score = (center_score * 0.4 + size_score * 0.4 + stability * 0.2)
            return total_score

        scored_masks = [(score_mask(m), m) for m in masks]
        scored_masks.sort(reverse=True, key=lambda x: x[0])
        best_score, best_mask_data = scored_masks[0]
        best_mask = best_mask_data['segmentation']
        session.current_mask = best_mask

        rgba_image = np.dstack((image, (best_mask * 255).astype(np.uint8)))
        session.rgba_image = rgba_image
        segmented_pil = Image.fromarray(rgba_image)

        mask_area = best_mask.sum()
        area_percentage = (mask_area / (h * w)) * 100

        img_bytes = pil_to_bytes(segmented_pil)
        return StreamingResponse(
            io.BytesIO(img_bytes),
            media_type="image/png",
            headers={
                "X-Selection-Score": str(float(best_score)),
                "X-Stability": str(float(best_mask_data['stability_score'])),
                "X-Area-Percentage": str(float(area_percentage)),
                "X-Total-Objects": str(len(masks))
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
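
# Worked example of the scoring heuristic above (the weights are heuristic,
# not calibrated): a perfectly centered mask covering 30% of the frame scores
#   0.4 * 1.0 (center) + 0.4 * min(0.3 * 5, 1.0) (size) + 0.2 * stability
# = 0.8 + 0.2 * stability, so stability mostly breaks ties between good masks.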

@app.post("/reset")
async def reset_prompts(request: MaskRequest):
    """Reset all prompts but keep the image"""
    try:
        session = get_session(request.session_id)
        session.box_points = []
        session.positive_points = []
        session.negative_points = []
        session.current_mask = None
        return {"message": "Session reset successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/generate_background")
async def generate_background(
    background_tasks: BackgroundTasks,
    request: BackgroundRequest
):
    """Generate new background using Stable Diffusion XL Inpainting"""
    try:
        session = get_session(request.session_id)
        if session.rgba_image is None:
            raise HTTPException(status_code=400, detail="Generate mask first")

        print(f"🎨 Generating background: {request.background_prompt}")

        # Load inpainting model (load_inpaint_pipeline returns None on failure)
        pipeline = load_inpaint_pipeline()
        if pipeline is None:
            raise HTTPException(status_code=503, detail="Inpainting model unavailable")

        # Prepare images: the inpaint mask is the *background*, i.e. pixels
        # where alpha < 128. White (255) marks regions the pipeline repaints.
        rgba_image = Image.fromarray(session.rgba_image)
        alpha = np.array(rgba_image.split()[-1])
        mask = np.where(alpha < 128, 255, 0).astype(np.uint8)
        mask = Image.fromarray(mask).convert("L")
        rgb_image = rgba_image.convert("RGB")

        # SDXL's VAE requires dimensions divisible by 8; round down so
        # arbitrarily sized uploads don't fail the pipeline's input checks
        width = rgb_image.width - rgb_image.width % 8
        height = rgb_image.height - rgb_image.height % 8

        # Generate with a dedicated seeded generator for reproducibility
        generator = torch.Generator().manual_seed(int(request.seed))
        result = pipeline(
            image=rgb_image,
            mask_image=mask,
            prompt=request.background_prompt,
            negative_prompt=request.negative_prompt,
            guidance_scale=request.guidance_scale,
            generator=generator,
            num_inference_steps=int(request.num_steps),
            width=width,
            height=height
        ).images[0]

        # Save result
        result_id = str(uuid.uuid4())
        result_path = os.path.join(RESULT_DIR, f"{result_id}.png")
        result.save(result_path)

        # Schedule cleanup; background tasks run after the response is sent
        background_tasks.add_task(cleanup_files, [result_path])

        return FileResponse(
            result_path,
            media_type="image/png",
            headers={
                "X-Prompt": request.background_prompt,
                "X-Result-ID": result_id
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )

@app.get("/preview/{session_id}")
async def get_preview(session_id: str):
    """Get current image preview"""
    try:
        session = get_session(session_id)
        if session.original_image is None:
            raise HTTPException(status_code=400, detail="No image uploaded")
        img_pil = Image.fromarray(session.original_image)
        img_bytes = pil_to_bytes(img_pil)
        return StreamingResponse(io.BytesIO(img_bytes), media_type="image/png")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "sam_loaded": sam_predictor is not None,
        "inpaint_loaded": inpaint_pipeline is not None,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }
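
# Local dev entry point: a minimal sketch assuming uvicorn is installed
# (pip install uvicorn). Production setups would typically run
# `uvicorn app:app --host 0.0.0.0 --port 8000` directly instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)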