# app.py - AI Background Changer API
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks, HTTPException
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import numpy as np
import os
import uuid
import cv2
import torch
from PIL import Image, ImageDraw
import io
import json
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from diffusers import StableDiffusionXLInpaintPipeline
# ======================================================
# CONFIGURATION
# ======================================================
app = FastAPI(title="AI Background Changer API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ======================================================
# DIRECTORIES
# ======================================================
UPLOAD_DIR = "uploads"
RESULT_DIR = "results"
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
MODEL_DIR = "models"
SAM_CHECKPOINT = os.path.join(MODEL_DIR, "sam_vit_b_01ec64.pth")
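# The vit_b checkpoint is not bundled with this app; it can be fetched from
# the official segment-anything release (URL current as of the SAM v1
# release — verify before relying on it):
#   wget -P models https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth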
# ======================================================
# GLOBAL STATE (In-memory session storage)
# ======================================================
sessions = {}
class SessionData:
    def __init__(self):
        self.original_image = None
        self.current_mask = None
        self.rgba_image = None
        self.image_set = False
        self.box_points = []
        self.positive_points = []
        self.negative_points = []
# ======================================================
# MODELS
# ======================================================
sam_predictor = None
inpaint_pipeline = None
def load_sam_model():
    global sam_predictor
    if sam_predictor is None:
        if not os.path.exists(SAM_CHECKPOINT):
            raise FileNotFoundError(f"SAM model not found at {SAM_CHECKPOINT}")
        sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        sam.to(device=device)
        sam_predictor = SamPredictor(sam)
        print(f"✅ SAM loaded on {device}")
    return sam_predictor
def load_inpaint_pipeline():
    global inpaint_pipeline
    if inpaint_pipeline is None:
        print("🎨 Loading Stable Diffusion XL Inpainting...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 weights require CUDA; fall back to fp32 on CPU
        dtype = torch.float16 if device == "cuda" else torch.float32
        variant = "fp16" if device == "cuda" else None
        try:
            inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
                "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                torch_dtype=dtype,
                variant=variant,
                use_safetensors=True
            )
            inpaint_pipeline = inpaint_pipeline.to(device)
            print(f"✅ Inpainting model loaded on {device}")
        except Exception as e:
            print("❌ Failed to load SDXL Inpainting:", str(e))
            inpaint_pipeline = None
    return inpaint_pipeline
# ======================================================
# PYDANTIC MODELS
# ======================================================
class ClickRequest(BaseModel):
    session_id: str
    x: int
    y: int
    mode: str  # "box", "positive", "negative"

class MaskRequest(BaseModel):
    session_id: str

class BackgroundRequest(BaseModel):
    session_id: str
    background_prompt: str
    negative_prompt: Optional[str] = "low quality, blurry"
    guidance_scale: Optional[float] = 7.0
    num_steps: Optional[int] = 30
    seed: Optional[int] = 42
# ======================================================
# HELPER FUNCTIONS
# ======================================================
async def cleanup_files(paths: list):
    for path in paths:
        if os.path.exists(path):
            os.remove(path)
            print(f"🧹 File deleted: {path}")

def get_session(session_id: str) -> SessionData:
    if session_id not in sessions:
        sessions[session_id] = SessionData()
    return sessions[session_id]

def pil_to_bytes(img: Image.Image, format="PNG") -> bytes:
    buf = io.BytesIO()
    img.save(buf, format=format)
    buf.seek(0)
    return buf.getvalue()
# ======================================================
# ENDPOINTS
# ======================================================
@app.get("/")
def home():
return {
"message": "✅ AI Background Changer API is running!",
"endpoints": {
"upload": "POST /upload - Upload image",
"click": "POST /click - Add prompt point",
"generate_mask": "POST /generate_mask - Generate segmentation mask",
"auto_segment": "POST /auto_segment - Auto segment all objects",
"auto_smart_select": "POST /auto_smart_select - Auto select foreground",
"reset": "POST /reset - Reset session prompts",
"generate_background": "POST /generate_background - Change background",
"get_preview": "GET /preview/{session_id} - Get current preview"
}
}
@app.post("/upload")
async def upload_image(file: UploadFile = File(...)):
"""Upload image and initialize SAM"""
try:
# Create session
session_id = str(uuid.uuid4())
session = get_session(session_id)
# Read image
contents = await file.read()
nparr = np.frombuffer(contents, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if image is None:
raise HTTPException(status_code=400, detail="Invalid image file")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Store in session
session.original_image = image.copy()
session.image_set = False
# Load SAM and set image
predictor = load_sam_model()
predictor.set_image(image)
session.image_set = True
print(f"✅ Image uploaded: {image.shape}, session: {session_id}")
return {
"session_id": session_id,
"width": int(image.shape[1]),
"height": int(image.shape[0]),
"message": "Image uploaded successfully"
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
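# Example client call (illustrative only; adjust host and port to your
# deployment — this module does not bind a port itself):
#   curl -X POST http://localhost:8000/upload -F "file=@photo.jpg"
# → {"session_id": "<uuid>", "width": 1024, "height": 768, ...}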
@app.post("/click")
async def add_click(request: ClickRequest):
"""Add click point (box, positive, or negative)"""
try:
session = get_session(request.session_id)
if not session.image_set:
raise HTTPException(status_code=400, detail="Upload image first")
if request.mode == "box":
if len(session.box_points) < 2:
session.box_points.append([request.x, request.y])
else:
session.box_points = [[request.x, request.y]]
elif request.mode == "positive":
session.positive_points.append([request.x, request.y])
elif request.mode == "negative":
session.negative_points.append([request.x, request.y])
# Create preview with points
img_pil = Image.fromarray(session.original_image.copy())
draw = ImageDraw.Draw(img_pil)
# Draw box
if len(session.box_points) > 0:
if len(session.box_points) == 1:
x, y = session.box_points[0]
draw.ellipse([x-8, y-8, x+8, y+8], fill=(0, 150, 255), outline=(255, 255, 255), width=3)
else:
x1, y1 = session.box_points[0]
x2, y2 = session.box_points[1]
draw.rectangle([x1, y1, x2, y2], outline=(0, 150, 255), width=4)
draw.ellipse([x1-8, y1-8, x1+8, y1+8], fill=(0, 150, 255), outline=(255, 255, 255), width=3)
draw.ellipse([x2-8, y2-8, x2+8, y2+8], fill=(0, 150, 255), outline=(255, 255, 255), width=3)
# Draw positive points
for px, py in session.positive_points:
draw.ellipse([px-10, py-10, px+10, py+10], fill=(0, 255, 0), outline=(255, 255, 255), width=3)
# Draw negative points
for nx, ny in session.negative_points:
draw.ellipse([nx-10, ny-10, nx+10, ny+10], fill=(255, 0, 0), outline=(255, 255, 255), width=3)
img_bytes = pil_to_bytes(img_pil)
return StreamingResponse(io.BytesIO(img_bytes), media_type="image/png")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
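# Example (illustrative): add a positive foreground point at (200, 150)
#   curl -X POST http://localhost:8000/click \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": "<uuid>", "x": 200, "y": 150, "mode": "positive"}' \
#        --output preview.png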
@app.post("/generate_mask")
async def generate_mask(request: MaskRequest):
"""Generate segmentation mask from prompts"""
try:
session = get_session(request.session_id)
if not session.image_set:
raise HTTPException(status_code=400, detail="Upload image first")
has_box = len(session.box_points) == 2
has_points = len(session.positive_points) > 0
if not has_box and not has_points:
raise HTTPException(status_code=400, detail="Add prompts first (box or points)")
predictor = load_sam_model()
box = None
points = None
labels = None
# Box prompt
if has_box:
x1, y1 = session.box_points[0]
x2, y2 = session.box_points[1]
box = np.array([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
# Point prompts
if has_points or session.negative_points:
point_list = []
label_list = []
for px, py in session.positive_points:
point_list.append([px, py])
label_list.append(1)
for nx, ny in session.negative_points:
point_list.append([nx, ny])
label_list.append(0)
points = np.array(point_list)
labels = np.array(label_list)
# Predict
masks, scores, logits = predictor.predict(
point_coords=points,
point_labels=labels,
box=box,
multimask_output=True
)
best_idx = np.argmax(scores)
mask = masks[best_idx]
score = scores[best_idx]
session.current_mask = mask
# Create RGBA
image = session.original_image
rgba_image = np.dstack((image, (mask * 255).astype(np.uint8)))
session.rgba_image = rgba_image
# Convert to PIL
segmented_pil = Image.fromarray(rgba_image)
mask_visual = (mask * 255).astype(np.uint8)
mask_visual_rgb = np.stack([mask_visual] * 3, axis=-1)
mask_pil = Image.fromarray(mask_visual_rgb)
# Statistics
mask_area = mask.sum()
total_area = image.shape[0] * image.shape[1]
area_percentage = (mask_area / total_area) * 100
# Return segmented image
img_bytes = pil_to_bytes(segmented_pil)
return StreamingResponse(
io.BytesIO(img_bytes),
media_type="image/png",
headers={
"X-Score": str(float(score)),
"X-Area-Percentage": str(float(area_percentage)),
"X-Method": "box" if has_box else "points"
}
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
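# Example (illustrative, using the third-party `requests` package):
#   import requests
#   r = requests.post("http://localhost:8000/generate_mask",
#                     json={"session_id": "<uuid>"})
#   print(r.headers["X-Score"], r.headers["X-Area-Percentage"])
#   open("cutout.png", "wb").write(r.content)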
@app.post("/auto_segment")
async def auto_segment_all(request: MaskRequest):
"""Auto segment and show all objects"""
try:
session = get_session(request.session_id)
if not session.image_set:
raise HTTPException(status_code=400, detail="Upload image first")
predictor = load_sam_model()
sam_model = predictor.model
mask_generator = SamAutomaticMaskGenerator(
model=sam_model,
points_per_side=32,
pred_iou_thresh=0.86,
stability_score_thresh=0.92,
min_mask_region_area=100
)
masks = mask_generator.generate(session.original_image)
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
# Create visualization
image = session.original_image.copy()
h, w = image.shape[:2]
overlay = np.zeros((h, w, 3), dtype=np.uint8)
colors = [
[255, 0, 0], [0, 255, 0], [0, 0, 255],
[255, 255, 0], [255, 0, 255], [0, 255, 255],
[128, 0, 0], [0, 128, 0], [0, 0, 128]
]
objects_info = []
for i, mask_data in enumerate(masks[:9]):
mask = mask_data['segmentation']
color = colors[i % len(colors)]
overlay[mask] = color
area_pct = (mask_data['area'] / (h * w)) * 100
objects_info.append({
"index": i + 1,
"area_percentage": float(area_pct),
"color": color
})
result = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)
result_pil = Image.fromarray(result)
img_bytes = pil_to_bytes(result_pil)
return StreamingResponse(
io.BytesIO(img_bytes),
media_type="image/png",
headers={
"X-Objects-Found": str(len(masks)),
"X-Objects-Info": json.dumps(objects_info)
}
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
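# Example (illustrative): parse the per-object metadata returned in headers
#   import json, requests
#   r = requests.post("http://localhost:8000/auto_segment",
#                     json={"session_id": "<uuid>"})
#   for obj in json.loads(r.headers["X-Objects-Info"]):
#       print(obj["index"], f"{obj['area_percentage']:.1f}%", obj["color"])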
@app.post("/auto_smart_select")
async def auto_smart_select(request: MaskRequest):
"""Automatically select the best foreground object"""
try:
session = get_session(request.session_id)
if not session.image_set:
raise HTTPException(status_code=400, detail="Upload image first")
predictor = load_sam_model()
sam_model = predictor.model
mask_generator = SamAutomaticMaskGenerator(
model=sam_model,
points_per_side=32,
pred_iou_thresh=0.88,
stability_score_thresh=0.93,
min_mask_region_area=500
)
masks = mask_generator.generate(session.original_image)
if len(masks) == 0:
raise HTTPException(status_code=400, detail="No objects detected")
# Smart selection
image = session.original_image
h, w = image.shape[:2]
center_x, center_y = w // 2, h // 2
def score_mask(mask_data):
mask = mask_data['segmentation']
area = mask_data['area']
stability = mask_data['stability_score']
y_coords, x_coords = np.where(mask)
if len(x_coords) == 0:
return 0
mask_center_x = x_coords.mean()
mask_center_y = y_coords.mean()
dist_from_center = np.sqrt((mask_center_x - center_x)**2 + (mask_center_y - center_y)**2)
max_dist = np.sqrt(center_x**2 + center_y**2)
center_score = 1 - (dist_from_center / max_dist)
area_ratio = area / (h * w)
if area_ratio > 0.8:
size_score = 0.1
elif area_ratio < 0.02:
size_score = 0.3
else:
size_score = min(area_ratio * 5, 1.0)
total_score = (center_score * 0.4 + size_score * 0.4 + stability * 0.2)
return total_score
scored_masks = [(score_mask(m), m) for m in masks]
scored_masks.sort(reverse=True, key=lambda x: x[0])
best_score, best_mask_data = scored_masks[0]
best_mask = best_mask_data['segmentation']
session.current_mask = best_mask
rgba_image = np.dstack((image, (best_mask * 255).astype(np.uint8)))
session.rgba_image = rgba_image
segmented_pil = Image.fromarray(rgba_image)
mask_area = best_mask.sum()
area_percentage = (mask_area / (h * w)) * 100
img_bytes = pil_to_bytes(segmented_pil)
return StreamingResponse(
io.BytesIO(img_bytes),
media_type="image/png",
headers={
"X-Selection-Score": str(float(best_score)),
"X-Stability": str(float(best_mask_data['stability_score'])),
"X-Area-Percentage": str(float(area_percentage)),
"X-Total-Objects": str(len(masks))
}
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/reset")
async def reset_prompts(request: MaskRequest):
"""Reset all prompts but keep the image"""
try:
session = get_session(request.session_id)
session.box_points = []
session.positive_points = []
session.negative_points = []
session.current_mask = None
return {"message": "Session reset successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate_background")
async def generate_background(
background_tasks: BackgroundTasks,
request: BackgroundRequest
):
"""Generate new background using Stable Diffusion XL Inpainting"""
try:
session = get_session(request.session_id)
if session.rgba_image is None:
raise HTTPException(status_code=400, detail="Generate mask first")
print(f"🎨 Generating background: {request.background_prompt}")
# Load inpainting model
pipeline = load_inpaint_pipeline()
# Prepare images
rgba_image = Image.fromarray(session.rgba_image)
alpha = np.array(rgba_image.split()[-1])
mask = np.where(alpha < 128, 255, 0).astype(np.uint8)
mask = Image.fromarray(mask).convert("L")
rgb_image = rgba_image.convert("RGB")
# Generate
generator = torch.manual_seed(int(request.seed))
result = pipeline(
image=rgb_image,
mask_image=mask,
prompt=request.background_prompt,
negative_prompt=request.negative_prompt,
guidance_scale=request.guidance_scale,
generator=generator,
num_inference_steps=int(request.num_steps),
width=rgb_image.width,
height=rgb_image.height
).images[0]
# Save result
result_id = str(uuid.uuid4())
result_path = os.path.join(RESULT_DIR, f"{result_id}.png")
result.save(result_path)
# Schedule cleanup
background_tasks.add_task(cleanup_files, [result_path])
return FileResponse(
result_path,
media_type="image/png",
headers={
"X-Prompt": request.background_prompt,
"X-Result-ID": result_id
}
)
except Exception as e:
return JSONResponse(
status_code=500,
content={"error": str(e)}
)
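# Example (illustrative):
#   curl -X POST http://localhost:8000/generate_background \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": "<uuid>", "background_prompt": "a sunny beach at golden hour"}' \
#        --output result.png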
@app.get("/preview/{session_id}")
async def get_preview(session_id: str):
"""Get current image preview"""
try:
session = get_session(session_id)
if session.original_image is None:
raise HTTPException(status_code=400, detail="No image uploaded")
img_pil = Image.fromarray(session.original_image)
img_bytes = pil_to_bytes(img_pil)
return StreamingResponse(io.BytesIO(img_bytes), media_type="image/png")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
def health_check():
return {
"status": "healthy",
"sam_loaded": sam_predictor is not None,
"inpaint_loaded": inpaint_pipeline is not None,
"device": "cuda" if torch.cuda.is_available() else "cpu"
}
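# Local development entry point (a minimal sketch — assumes uvicorn is
# installed and port 8000 is free; hosted deployments typically launch the
# app via their own server command instead):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)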