yurista committed
Commit 580ae5e · verified · 1 Parent(s): a8206f3

Update app.py

Files changed (1)
  1. app.py +526 -166
app.py CHANGED
@@ -1,21 +1,23 @@
- from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks
- from fastapi.responses import FileResponse, JSONResponse
+ # app.py - AI Background Changer API
+
+ from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks, HTTPException
+ from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
  from fastapi.middleware.cors import CORSMiddleware
- from fastapi.staticfiles import StaticFiles
+ from pydantic import BaseModel
+ from typing import Optional, List
  import numpy as np
  import os
  import uuid
  import cv2
  import torch
- from PIL import Image
+ from PIL import Image, ImageDraw
+ import io
  import json
- from typing import List, Optional
  from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
- from huggingface_hub import hf_hub_download
  from diffusers import StableDiffusionXLInpaintPipeline

  # ======================================================
- # FASTAPI SETUP
+ # CONFIGURATION
  # ======================================================
  app = FastAPI(title="AI Background Changer API")
 
@@ -31,82 +33,105 @@ app.add_middleware(
  # DIRECTORIES
  # ======================================================
  UPLOAD_DIR = "uploads"
- PROCESSED_DIR = "processed"
- MODELS_DIR = "models"
+ RESULT_DIR = "results"
  os.makedirs(UPLOAD_DIR, exist_ok=True)
- os.makedirs(PROCESSED_DIR, exist_ok=True)
- os.makedirs(MODELS_DIR, exist_ok=True)

- SAM_CHECKPOINT = os.path.join(MODELS_DIR, "sam_vit_b_01ec64.pth")
+ MODEL_DIR = "models"
+ SAM_CHECKPOINT = os.path.join(MODEL_DIR, "sam_vit_b_01ec64.pth")

  # ======================================================
- # GLOBAL STATE
+ # GLOBAL STATE (In-memory session storage)
  # ======================================================
- global_models = {
-     "sam_predictor": None,
-     "inpaint_pipeline": None,
-     "device": "cuda" if torch.cuda.is_available() else "cpu"
- }
-
  sessions = {}

- # ======================================================
- # DOWNLOAD SAM FROM HUGGINGFACE
- # ======================================================
- def download_sam_checkpoint():
-     if not os.path.exists(SAM_CHECKPOINT):
-         print("Downloading SAM checkpoint from HuggingFace...")
-         hf_hub_download(
-             repo_id="yurista/AI-Background-Maker",
-             filename="models/sam_vit_b_01ec64.pth",
-             local_dir="models",
-             local_dir_use_symlinks=False
-         )
-         print("SAM checkpoint downloaded!")
-     return SAM_CHECKPOINT
+ class SessionData:
+     def __init__(self):
+         self.original_image = None
+         self.current_mask = None
+         self.rgba_image = None
+         self.image_set = False
+         self.box_points = []
+         self.positive_points = []
+         self.negative_points = []

  # ======================================================
- # LOAD MODELS
+ # MODELS
  # ======================================================
- def load_sam_model():
-     download_sam_checkpoint()
+ sam_predictor = None
+ inpaint_pipeline = None

-     if global_models["sam_predictor"] is None:
-         print(f"Loading SAM model from {SAM_CHECKPOINT}...")
+ def load_sam_model():
+     global sam_predictor
+     if sam_predictor is None:
+         if not os.path.exists(SAM_CHECKPOINT):
+             raise FileNotFoundError(f"SAM model not found at {SAM_CHECKPOINT}")
+
          sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT)
-         sam.to(global_models["device"])
-         predictor = SamPredictor(sam)
-         global_models["sam_predictor"] = predictor
-         print("SAM model loaded!")
-     return global_models["sam_predictor"]
-
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         sam.to(device=device)
+
+         sam_predictor = SamPredictor(sam)
+         print(f"✅ SAM loaded on {device}")
+
+     return sam_predictor

  def load_inpaint_pipeline():
-     if global_models["inpaint_pipeline"] is None:
-         print("Loading SDXL Inpaint model...")
-         pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
+     global inpaint_pipeline
+     if inpaint_pipeline is None:
+         print("🎨 Loading Stable Diffusion XL Inpainting...")
+         inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
              "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+             torch_dtype=torch.float16,
+             variant="fp16",
          )
-         pipeline = pipeline.to(global_models["device"])
-         global_models["inpaint_pipeline"] = pipeline
-         print("SDXL Inpaint loaded!")
-     return global_models["inpaint_pipeline"]
+
+         if torch.cuda.is_available():
+             inpaint_pipeline = inpaint_pipeline.to("cuda")
+
+         print("✅ Inpainting model loaded")
+
+     return inpaint_pipeline
+
+ # ======================================================
+ # PYDANTIC MODELS
+ # ======================================================
+ class ClickRequest(BaseModel):
+     session_id: str
+     x: int
+     y: int
+     mode: str  # "box", "positive", "negative"
+
+ class MaskRequest(BaseModel):
+     session_id: str
+
+ class BackgroundRequest(BaseModel):
+     session_id: str
+     background_prompt: str
+     negative_prompt: Optional[str] = "low quality, blurry"
+     guidance_scale: Optional[float] = 7.0
+     num_steps: Optional[int] = 30
+     seed: Optional[int] = 42

  # ======================================================
- # HELPERS
+ # HELPER FUNCTIONS
  # ======================================================
- def save_image(image_array, suffix=""):
-     unique_id = str(uuid.uuid4())
-     filename = f"{unique_id}{suffix}.png"
-     path = os.path.join(PROCESSED_DIR, filename)
+ async def cleanup_files(paths: list):
+     for path in paths:
+         if os.path.exists(path):
+             os.remove(path)
+             print(f"🧹 File deleted: {path}")

-     if isinstance(image_array, np.ndarray):
-         cv2.imwrite(path, cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR))
-     else:
-         image_array.save(path)
+ def get_session(session_id: str) -> SessionData:
+     if session_id not in sessions:
+         sessions[session_id] = SessionData()
+     return sessions[session_id]

-     return path
+ def pil_to_bytes(img: Image.Image, format="PNG") -> bytes:
+     buf = io.BytesIO()
+     img.save(buf, format=format)
+     buf.seek(0)
+     return buf.getvalue()

  # ======================================================
  # ENDPOINTS
@@ -114,113 +139,448 @@ def save_image(image_array, suffix=""):
  @app.get("/")
  def home():
      return {
-         "message": "API Ready",
-         "endpoints": ["upload", "segment/box", "segment/points", "segment/auto-smart"]
+         "message": "✅ AI Background Changer API is running!",
+         "endpoints": {
+             "upload": "POST /upload - Upload image",
+             "click": "POST /click - Add prompt point",
+             "generate_mask": "POST /generate_mask - Generate segmentation mask",
+             "auto_segment": "POST /auto_segment - Auto segment all objects",
+             "auto_smart_select": "POST /auto_smart_select - Auto select foreground",
+             "reset": "POST /reset - Reset session prompts",
+             "generate_background": "POST /generate_background - Change background",
+             "get_preview": "GET /preview/{session_id} - Get current preview"
+         }
      }

  @app.post("/upload")
  async def upload_image(file: UploadFile = File(...)):
-     ext = os.path.splitext(file.filename)[1]
-     uid = str(uuid.uuid4())
-     path = os.path.join(UPLOAD_DIR, f"{uid}{ext}")
-
-     with open(path, "wb") as f:
-         f.write(await file.read())
-
-     img = cv2.imread(path)
-     if img is None:
-         return JSONResponse(status_code=400, content={"error": "Invalid image"})
-
-     rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-
-     predictor = load_sam_model()
-     predictor.set_image(rgb)
-
-     sid = str(uuid.uuid4())
-     sessions[sid] = {
-         "image_path": path,
-         "mask": None,
-         "rgba_path": None,
-         "image_shape": rgb.shape
-     }
-
-     return {"session_id": sid, "image_shape": rgb.shape[:2]}
-
- @app.post("/segment/box")
- async def segment_box(
-     session_id: str = Form(...),
-     x1: int = Form(...),
-     y1: int = Form(...),
-     x2: int = Form(...),
-     y2: int = Form(...)
- ):
-     if session_id not in sessions:
-         return JSONResponse(status_code=404, content={"error": "Session not found"})
-
-     sess = sessions[session_id]
-     img = cv2.imread(sess["image_path"])
-     rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-
-     predictor = load_sam_model()
-     predictor.set_image(rgb)
-
-     box = np.array([min(x1,x2), min(y1,y2), max(x1,x2), max(y1,y2)])
-
-     masks, scores, _ = predictor.predict(box=box, multimask_output=True)
-     idx = np.argmax(scores)
-     mask = masks[idx]
-
-     sess["mask"] = mask.tolist()
-
-     rgba = np.dstack((rgb, (mask * 255).astype(np.uint8)))
-     rgba_path = save_image(rgba, "_rgba")
-     sess["rgba_path"] = rgba_path
-
-     segmented_path = save_image(rgba, "_segmented")
-
-     return {
-         "session_id": session_id,
-         "score": float(scores[idx]),
-         "segmented_preview": f"/files/{os.path.basename(segmented_path)}"
-     }
-
-
- @app.post("/generate-background")
+     """Upload image and initialize SAM"""
+     try:
+         # Create session
+         session_id = str(uuid.uuid4())
+         session = get_session(session_id)
+
+         # Read image
+         contents = await file.read()
+         nparr = np.frombuffer(contents, np.uint8)
+         image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+         if image is None:
+             raise HTTPException(status_code=400, detail="Invalid image file")
+
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         # Store in session
+         session.original_image = image.copy()
+         session.image_set = False
+
+         # Load SAM and set image
+         predictor = load_sam_model()
+         predictor.set_image(image)
+         session.image_set = True
+
+         print(f"📤 Image uploaded: {image.shape}, session: {session_id}")
+
+         return {
+             "session_id": session_id,
+             "width": int(image.shape[1]),
+             "height": int(image.shape[0]),
+             "message": "Image uploaded successfully"
+         }
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/click")
+ async def add_click(request: ClickRequest):
+     """Add click point (box, positive, or negative)"""
+     try:
+         session = get_session(request.session_id)
+
+         if not session.image_set:
+             raise HTTPException(status_code=400, detail="Upload image first")
+
+         if request.mode == "box":
+             if len(session.box_points) < 2:
+                 session.box_points.append([request.x, request.y])
+             else:
+                 session.box_points = [[request.x, request.y]]
+         elif request.mode == "positive":
+             session.positive_points.append([request.x, request.y])
+         elif request.mode == "negative":
+             session.negative_points.append([request.x, request.y])
+
+         # Create preview with points
+         img_pil = Image.fromarray(session.original_image.copy())
+         draw = ImageDraw.Draw(img_pil)
+
+         # Draw box
+         if len(session.box_points) > 0:
+             if len(session.box_points) == 1:
+                 x, y = session.box_points[0]
+                 draw.ellipse([x-8, y-8, x+8, y+8], fill=(0, 150, 255), outline=(255, 255, 255), width=3)
+             else:
+                 x1, y1 = session.box_points[0]
+                 x2, y2 = session.box_points[1]
+                 draw.rectangle([x1, y1, x2, y2], outline=(0, 150, 255), width=4)
+                 draw.ellipse([x1-8, y1-8, x1+8, y1+8], fill=(0, 150, 255), outline=(255, 255, 255), width=3)
+                 draw.ellipse([x2-8, y2-8, x2+8, y2+8], fill=(0, 150, 255), outline=(255, 255, 255), width=3)
+
+         # Draw positive points
+         for px, py in session.positive_points:
+             draw.ellipse([px-10, py-10, px+10, py+10], fill=(0, 255, 0), outline=(255, 255, 255), width=3)
+
+         # Draw negative points
+         for nx, ny in session.negative_points:
+             draw.ellipse([nx-10, ny-10, nx+10, ny+10], fill=(255, 0, 0), outline=(255, 255, 255), width=3)
+
+         img_bytes = pil_to_bytes(img_pil)
+
+         return StreamingResponse(io.BytesIO(img_bytes), media_type="image/png")
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/generate_mask")
+ async def generate_mask(request: MaskRequest):
+     """Generate segmentation mask from prompts"""
+     try:
+         session = get_session(request.session_id)
+
+         if not session.image_set:
+             raise HTTPException(status_code=400, detail="Upload image first")
+
+         has_box = len(session.box_points) == 2
+         has_points = len(session.positive_points) > 0
+
+         if not has_box and not has_points:
+             raise HTTPException(status_code=400, detail="Add prompts first (box or points)")
+
+         predictor = load_sam_model()
+
+         box = None
+         points = None
+         labels = None
+
+         # Box prompt
+         if has_box:
+             x1, y1 = session.box_points[0]
+             x2, y2 = session.box_points[1]
+             box = np.array([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
+
+         # Point prompts
+         if has_points or session.negative_points:
+             point_list = []
+             label_list = []
+
+             for px, py in session.positive_points:
+                 point_list.append([px, py])
+                 label_list.append(1)
+
+             for nx, ny in session.negative_points:
+                 point_list.append([nx, ny])
+                 label_list.append(0)
+
+             points = np.array(point_list)
+             labels = np.array(label_list)
+
+         # Predict
+         masks, scores, logits = predictor.predict(
+             point_coords=points,
+             point_labels=labels,
+             box=box,
+             multimask_output=True
+         )
+
+         best_idx = np.argmax(scores)
+         mask = masks[best_idx]
+         score = scores[best_idx]
+
+         session.current_mask = mask
+
+         # Create RGBA
+         image = session.original_image
+         rgba_image = np.dstack((image, (mask * 255).astype(np.uint8)))
+         session.rgba_image = rgba_image
+
+         # Convert to PIL
+         segmented_pil = Image.fromarray(rgba_image)
+
+         mask_visual = (mask * 255).astype(np.uint8)
+         mask_visual_rgb = np.stack([mask_visual] * 3, axis=-1)
+         mask_pil = Image.fromarray(mask_visual_rgb)
+
+         # Statistics
+         mask_area = mask.sum()
+         total_area = image.shape[0] * image.shape[1]
+         area_percentage = (mask_area / total_area) * 100
+
+         # Return segmented image
+         img_bytes = pil_to_bytes(segmented_pil)
+
+         return StreamingResponse(
+             io.BytesIO(img_bytes),
+             media_type="image/png",
+             headers={
+                 "X-Score": str(float(score)),
+                 "X-Area-Percentage": str(float(area_percentage)),
+                 "X-Method": "box" if has_box else "points"
+             }
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/auto_segment")
+ async def auto_segment_all(request: MaskRequest):
+     """Auto segment and show all objects"""
+     try:
+         session = get_session(request.session_id)
+
+         if not session.image_set:
+             raise HTTPException(status_code=400, detail="Upload image first")
+
+         predictor = load_sam_model()
+         sam_model = predictor.model
+
+         mask_generator = SamAutomaticMaskGenerator(
+             model=sam_model,
+             points_per_side=32,
+             pred_iou_thresh=0.86,
+             stability_score_thresh=0.92,
+             min_mask_region_area=100
+         )
+
+         masks = mask_generator.generate(session.original_image)
+         masks = sorted(masks, key=lambda x: x['area'], reverse=True)
+
+         # Create visualization
+         image = session.original_image.copy()
+         h, w = image.shape[:2]
+         overlay = np.zeros((h, w, 3), dtype=np.uint8)
+
+         colors = [
+             [255, 0, 0], [0, 255, 0], [0, 0, 255],
+             [255, 255, 0], [255, 0, 255], [0, 255, 255],
+             [128, 0, 0], [0, 128, 0], [0, 0, 128]
+         ]
+
+         objects_info = []
+         for i, mask_data in enumerate(masks[:9]):
+             mask = mask_data['segmentation']
+             color = colors[i % len(colors)]
+             overlay[mask] = color
+
+             area_pct = (mask_data['area'] / (h * w)) * 100
+             objects_info.append({
+                 "index": i + 1,
+                 "area_percentage": float(area_pct),
+                 "color": color
+             })
+
+         result = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)
+         result_pil = Image.fromarray(result)
+
+         img_bytes = pil_to_bytes(result_pil)
+
+         return StreamingResponse(
+             io.BytesIO(img_bytes),
+             media_type="image/png",
+             headers={
+                 "X-Objects-Found": str(len(masks)),
+                 "X-Objects-Info": json.dumps(objects_info)
+             }
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/auto_smart_select")
+ async def auto_smart_select(request: MaskRequest):
+     """Automatically select the best foreground object"""
+     try:
+         session = get_session(request.session_id)
+
+         if not session.image_set:
+             raise HTTPException(status_code=400, detail="Upload image first")
+
+         predictor = load_sam_model()
+         sam_model = predictor.model
+
+         mask_generator = SamAutomaticMaskGenerator(
+             model=sam_model,
+             points_per_side=32,
+             pred_iou_thresh=0.88,
+             stability_score_thresh=0.93,
+             min_mask_region_area=500
+         )
+
+         masks = mask_generator.generate(session.original_image)
+
+         if len(masks) == 0:
+             raise HTTPException(status_code=400, detail="No objects detected")
+
+         # Smart selection
+         image = session.original_image
+         h, w = image.shape[:2]
+         center_x, center_y = w // 2, h // 2
+
+         def score_mask(mask_data):
+             mask = mask_data['segmentation']
+             area = mask_data['area']
+             stability = mask_data['stability_score']
+
+             y_coords, x_coords = np.where(mask)
+             if len(x_coords) == 0:
+                 return 0
+
+             mask_center_x = x_coords.mean()
+             mask_center_y = y_coords.mean()
+
+             dist_from_center = np.sqrt((mask_center_x - center_x)**2 + (mask_center_y - center_y)**2)
+             max_dist = np.sqrt(center_x**2 + center_y**2)
+             center_score = 1 - (dist_from_center / max_dist)
+
+             area_ratio = area / (h * w)
+             if area_ratio > 0.8:
+                 size_score = 0.1
+             elif area_ratio < 0.02:
+                 size_score = 0.3
+             else:
+                 size_score = min(area_ratio * 5, 1.0)
+
+             total_score = (center_score * 0.4 + size_score * 0.4 + stability * 0.2)
+             return total_score
+
+         scored_masks = [(score_mask(m), m) for m in masks]
+         scored_masks.sort(reverse=True, key=lambda x: x[0])
+
+         best_score, best_mask_data = scored_masks[0]
+         best_mask = best_mask_data['segmentation']
+
+         session.current_mask = best_mask
+
+         rgba_image = np.dstack((image, (best_mask * 255).astype(np.uint8)))
+         session.rgba_image = rgba_image
+
+         segmented_pil = Image.fromarray(rgba_image)
+
+         mask_area = best_mask.sum()
+         area_percentage = (mask_area / (h * w)) * 100
+
+         img_bytes = pil_to_bytes(segmented_pil)
+
+         return StreamingResponse(
+             io.BytesIO(img_bytes),
+             media_type="image/png",
+             headers={
+                 "X-Selection-Score": str(float(best_score)),
+                 "X-Stability": str(float(best_mask_data['stability_score'])),
+                 "X-Area-Percentage": str(float(area_percentage)),
+                 "X-Total-Objects": str(len(masks))
+             }
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/reset")
+ async def reset_prompts(request: MaskRequest):
+     """Reset all prompts but keep the image"""
+     try:
+         session = get_session(request.session_id)
+
+         session.box_points = []
+         session.positive_points = []
+         session.negative_points = []
+         session.current_mask = None
+
+         return {"message": "Session reset successfully"}
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/generate_background")
  async def generate_background(
-     session_id: str = Form(...),
-     prompt: str = Form("futuristic city"),
-     negative_prompt: str = Form("blurry"),
-     guidance_scale: float = Form(7.0),
-     num_steps: int = Form(30)
+     background_tasks: BackgroundTasks,
+     request: BackgroundRequest
  ):
-     if session_id not in sessions:
-         return JSONResponse(status_code=404, content={"error": "Session not found"})
-
-     sess = sessions[session_id]
-     if sess["rgba_path"] is None:
-         return JSONResponse(status_code=400, content={"error": "Mask not generated"})
-
-     rgba = Image.open(sess["rgba_path"])
-     rgb = rgba.convert("RGB")
-     alpha = np.array(rgba.split()[-1])
-     mask = Image.fromarray(np.where(alpha < 128, 255, 0).astype(np.uint8))
-
-     pipe = load_inpaint_pipeline()
-
-     result = pipe(
-         image=rgb,
-         mask_image=mask,
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=float(guidance_scale),
-         num_inference_steps=int(num_steps)
-     ).images[0]
-
-     out_path = save_image(result, "_bg")
-     return FileResponse(out_path, media_type="image/png")
-
- @app.get("/status/{session_id}")
- def status(session_id: str):
-     return sessions.get(session_id, {"error": "not found"})
-
- app.mount("/files", StaticFiles(directory=PROCESSED_DIR), name="files")
+     """Generate new background using Stable Diffusion XL Inpainting"""
+     try:
+         session = get_session(request.session_id)
+
+         if session.rgba_image is None:
+             raise HTTPException(status_code=400, detail="Generate mask first")
+
+         print(f"🎨 Generating background: {request.background_prompt}")
+
+         # Load inpainting model
+         pipeline = load_inpaint_pipeline()
+
+         # Prepare images
+         rgba_image = Image.fromarray(session.rgba_image)
+         alpha = np.array(rgba_image.split()[-1])
+         mask = np.where(alpha < 128, 255, 0).astype(np.uint8)
+         mask = Image.fromarray(mask).convert("L")
+
+         rgb_image = rgba_image.convert("RGB")
+
+         # Generate
+         generator = torch.manual_seed(int(request.seed))
+         result = pipeline(
+             image=rgb_image,
+             mask_image=mask,
+             prompt=request.background_prompt,
+             negative_prompt=request.negative_prompt,
+             guidance_scale=request.guidance_scale,
+             generator=generator,
+             num_inference_steps=int(request.num_steps),
+             width=rgb_image.width,
+             height=rgb_image.height
+         ).images[0]
+
+         # Save result
+         result_id = str(uuid.uuid4())
+         result_path = os.path.join(RESULT_DIR, f"{result_id}.png")
+         result.save(result_path)
+
+         # Schedule cleanup
+         background_tasks.add_task(cleanup_files, [result_path])
+
+         return FileResponse(
+             result_path,
+             media_type="image/png",
+             headers={
+                 "X-Prompt": request.background_prompt,
+                 "X-Result-ID": result_id
+             }
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.get("/preview/{session_id}")
+ async def get_preview(session_id: str):
+     """Get current image preview"""
+     try:
+         session = get_session(session_id)
+
+         if session.original_image is None:
+             raise HTTPException(status_code=400, detail="No image uploaded")
+
+         img_pil = Image.fromarray(session.original_image)
+         img_bytes = pil_to_bytes(img_pil)
+
+         return StreamingResponse(io.BytesIO(img_bytes), media_type="image/png")
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.get("/health")
+ def health_check():
+     return {
+         "status": "healthy",
+         "sam_loaded": sam_predictor is not None,
+         "inpaint_loaded": inpaint_pipeline is not None,
+         "device": "cuda" if torch.cuda.is_available() else "cpu"
+     }
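
Usage note (not part of the commit): the reworked flow is upload → click → generate_mask → generate_background, and because the mask and preview endpoints stream PNG bodies, their metadata travels in X-* response headers instead of JSON. Below is a minimal client sketch under stated assumptions: the API runs at http://localhost:8000, photo.jpg is a placeholder input file, the requests package is installed, and the click coordinates and prompt are purely illustrative.

# client_example.py - hypothetical end-to-end walkthrough of the new API
import requests

BASE = "http://localhost:8000"  # assumption: local dev server

# 1. Upload an image; the server creates a session and runs SAM's set_image
with open("photo.jpg", "rb") as f:  # placeholder file name
    r = requests.post(f"{BASE}/upload", files={"file": f})
r.raise_for_status()
session_id = r.json()["session_id"]

# 2. Add a positive point prompt (JSON body matches the ClickRequest model)
requests.post(f"{BASE}/click",
              json={"session_id": session_id, "x": 320, "y": 240, "mode": "positive"})

# 3. Generate the mask; the PNG preview streams back, the score in X-Score
r = requests.post(f"{BASE}/generate_mask", json={"session_id": session_id})
print("mask score:", r.headers.get("X-Score"))

# 4. Inpaint a new background with SDXL and save the returned PNG
r = requests.post(f"{BASE}/generate_background",
                  json={"session_id": session_id,
                        "background_prompt": "a sunny beach at golden hour"})
with open("result.png", "wb") as out:
    out.write(r.content)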