yurista committed (verified)
Commit 83d7af9 · 1 Parent(s): d2340e0

Create app.py

Files changed (1):
  app.py  +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
+ # Dependency installs (notebook-style "!pip" commands are not valid in a plain
+ # Python app.py; on a Hugging Face Space, pin these in requirements.txt instead):
+ #   pip install git+https://github.com/facebookresearch/segment-anything.git
+ #   pip install opencv-python pillow matplotlib
+ #   pip uninstall -y diffusers transformers
+ #   pip install diffusers==0.30.3 transformers==4.44.2 accelerate
+ #   pip install gradio
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image, ImageDraw
+ import cv2
+ from segment_anything import sam_model_registry, SamPredictor
+ from diffusers import StableDiffusionXLInpaintPipeline
+ import os
+
+ # ============================================================
+ # SAM SMART FOREGROUND SELECTOR - OFFLINE MODEL VERSION
+ # ============================================================
+
+ # Global state shared across Gradio callbacks
+ global_state = {
+     "sam_predictor": None,
+     "original_image": None,
+     "current_mask": None,
+     "rgba_image": None,
+     "image_set": False,
+     # Reserved for manual box/point prompts (not yet wired into the UI)
+     "box_points": [],
+     "positive_points": [],
+     "negative_points": [],
+     "auto_masks": []
+ }
+
+ # ------------------------------------------------------------
+ # 🧩 Load SAM Model (Offline from /models/)
+ # ------------------------------------------------------------
+ def load_sam_model(sam_checkpoint="models/sam_vit_b_01ec64.pth"):
+     """Load the SAM model from the local models/ folder."""
+     if global_state["sam_predictor"] is None:
+         if not os.path.exists(sam_checkpoint):
+             raise FileNotFoundError(
+                 f"❌ Model file not found: {sam_checkpoint}\n"
+                 f"Make sure 'sam_vit_b_01ec64.pth' has been uploaded to the /models/ folder"
+             )
+
+         print(f"✅ Loading SAM model from {sam_checkpoint}...")
+         sam = sam_model_registry["vit_b"](checkpoint=sam_checkpoint)
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         sam.to(device=device)
+         predictor = SamPredictor(sam)
+         global_state["sam_predictor"] = predictor
+         print(f"✅ SAM loaded successfully on {device}")
+     return global_state["sam_predictor"]
+
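If the checkpoint is not already present in models/, it could be fetched at startup. A minimal sketch, assuming the public ViT-B checkpoint URL listed in the segment-anything README is still valid:

import os
import urllib.request

def ensure_sam_checkpoint(path="models/sam_vit_b_01ec64.pth"):
    # Hypothetical helper: download the ViT-B checkpoint if it is missing.
    # URL taken from the segment-anything README; verify before relying on it.
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        urllib.request.urlretrieve(url, path)
    return path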
+ # ------------------------------------------------------------
+ # 🖼️ Step 1: Upload Image
+ # ------------------------------------------------------------
+ def process_upload(image):
+     if image is None:
+         return None, "⚠️ Please upload an image first!", None, None
+
+     image_np = np.array(image.convert("RGB"))  # ensure 3-channel RGB for SAM
+     global_state["original_image"] = image_np.copy()
+     predictor = load_sam_model()
+     predictor.set_image(image_np)
+     global_state["image_set"] = True
+
+     status = "✅ Image uploaded successfully!\n\nClick 'Auto Smart Select' or use Box/Point to pick an object."
+     return Image.fromarray(image_np), status, None, None
+
+ # ------------------------------------------------------------
+ # ✨ Step 2: Auto Smart Select (automatic object selection)
+ # ------------------------------------------------------------
+ def auto_smart_select():
+     if not global_state["image_set"]:
+         return None, None, "⚠️ Please upload an image first!"
+
+     from segment_anything import SamAutomaticMaskGenerator
+
+     predictor = global_state["sam_predictor"]
+     sam_model = predictor.model
+
+     mask_generator = SamAutomaticMaskGenerator(
+         model=sam_model,
+         points_per_side=32,
+         pred_iou_thresh=0.88,
+         stability_score_thresh=0.93,
+         min_mask_region_area=500
+     )
+
+     masks = mask_generator.generate(global_state["original_image"])
+     if not masks:
+         return None, None, "❌ No objects detected."
+
+     # Pick the best mask based on its area and distance from the image center
+     h, w = global_state["original_image"].shape[:2]
+     center_x, center_y = w // 2, h // 2
+
+     def score_mask(mask_data):
+         mask = mask_data['segmentation']
+         y, x = np.where(mask)
+         if len(x) == 0:
+             return 0
+         dist = np.sqrt((x.mean() - center_x)**2 + (y.mean() - center_y)**2)
+         dist_score = 1 - (dist / np.sqrt(center_x**2 + center_y**2))
+         area_ratio = mask_data['area'] / (h * w)
+         size_score = 1 - abs(area_ratio - 0.15)  # ideal area is roughly 15% of the image
+         return dist_score * 0.6 + size_score * 0.4
+
+     best_mask = max(masks, key=score_mask)['segmentation']
+     rgba_image = np.dstack((global_state["original_image"], (best_mask * 255).astype(np.uint8)))
+     global_state["current_mask"] = best_mask
+     global_state["rgba_image"] = rgba_image
+
+     segmented_preview = Image.fromarray(rgba_image)
+     mask_preview = Image.fromarray((best_mask * 255).astype(np.uint8))
+
+     status = "✅ Foreground selected automatically!\nClick 'Generate Background' to replace the background."
+     return segmented_preview, mask_preview, status
+
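global_state also reserves positive_points / negative_points for manual prompts, but nothing consumes them yet. A minimal sketch of how such clicks could be passed to SamPredictor.predict (select_with_points is a hypothetical helper, not wired into the UI above):

def select_with_points():
    # Hypothetical helper: segment from manually collected click prompts.
    pos = global_state["positive_points"]
    neg = global_state["negative_points"]
    if not pos:
        return None  # need at least one foreground click
    predictor = global_state["sam_predictor"]
    coords = np.array(pos + neg, dtype=np.float32)      # (x, y) pixel coordinates
    labels = np.array([1] * len(pos) + [0] * len(neg))  # 1 = foreground, 0 = background
    masks, scores, _ = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=True,
    )
    return masks[np.argmax(scores)]  # keep the highest-scoring candidate mask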
+ # ------------------------------------------------------------
+ # 🎨 Step 3: Replace the Background with a Prompt
+ # ------------------------------------------------------------
+ def change_background(background_prompt, negative_prompt, guidance_scale, num_steps, seed):
+     if global_state["rgba_image"] is None:
+         return None, "⚠️ Generate a mask first!"
+
+     rgba_image = Image.fromarray(global_state["rgba_image"])
+     pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
+         "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+         torch_dtype=torch.float16,
+         variant="fp16"
+     )
+     if torch.cuda.is_available():
+         pipeline = pipeline.to("cuda")
+
+     # Inpaint only where the foreground mask is empty (i.e. the background)
+     alpha = np.array(rgba_image.split()[-1])
+     mask = np.where(alpha < 128, 255, 0).astype(np.uint8)
+     mask = Image.fromarray(mask).convert("L")
+     rgb_image = rgba_image.convert("RGB")
+
+     generator = torch.manual_seed(int(seed))
+     result = pipeline(
+         image=rgb_image,
+         mask_image=mask,
+         prompt=background_prompt,
+         negative_prompt=negative_prompt,
+         guidance_scale=guidance_scale,
+         generator=generator,
+         num_inference_steps=int(num_steps),
+         width=rgb_image.width,
+         height=rgb_image.height
+     ).images[0]
+
+     status = f"✅ Background replaced: {background_prompt}"
+     return result, status
+
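from_pretrained above runs on every button click, re-initializing the whole SDXL pipeline each time. A sketch of loading it once and caching it in global_state (the "inpaint_pipeline" key is an assumed addition, not part of the original dict):

def get_inpaint_pipeline():
    # Load the SDXL inpainting pipeline once and reuse it across requests.
    if global_state.get("inpaint_pipeline") is None:
        pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        if torch.cuda.is_available():
            pipe = pipe.to("cuda")
        global_state["inpaint_pipeline"] = pipe
    return global_state["inpaint_pipeline"]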
+ # ------------------------------------------------------------
+ # 🧱 Gradio UI
+ # ------------------------------------------------------------
+ def create_gradio_interface():
+     with gr.Blocks(title="AI Background Maker (Offline SAM)") as demo:
+         gr.Markdown("# 🖼️ AI Background Maker\nUpload a photo → Select the object → Replace the background ✨")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 input_image = gr.Image(label="Upload Image", type="pil")
+                 upload_btn = gr.Button("📥 Load Image", variant="primary")
+                 auto_smart_btn = gr.Button("🤖 Auto Smart Select (No Click!)", variant="primary")
+                 bg_prompt = gr.Textbox(value="sunset beach, cinematic, 8k", label="Background Prompt")
+                 neg_prompt = gr.Textbox(value="blurry, low quality", label="Negative Prompt")
+                 guidance = gr.Slider(1, 15, value=7, step=0.5, label="Guidance Scale")
+                 steps = gr.Slider(10, 50, value=30, step=5, label="Steps")
+                 seed = gr.Number(value=42, label="Seed")
+                 gen_bg = gr.Button("🚀 Generate Background", variant="secondary")
+
+             with gr.Column(scale=1):
+                 status = gr.Textbox(label="Status", lines=5)
+                 segmented = gr.Image(label="Foreground")
+                 mask = gr.Image(label="Mask")
+                 result = gr.Image(label="Final Output")
+                 result_status = gr.Textbox(label="Output Status", lines=2)
+
+         upload_btn.click(process_upload, inputs=[input_image],
+                          outputs=[segmented, status, mask, result])
+         auto_smart_btn.click(auto_smart_select, outputs=[segmented, mask, status])
+         gen_bg.click(change_background,
+                      inputs=[bg_prompt, neg_prompt, guidance, steps, seed],
+                      outputs=[result, result_status])
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_gradio_interface()
+     demo.launch(share=True, debug=True, server_name="0.0.0.0", server_port=7860)