File size: 7,873 Bytes
83d7af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# Environment setup. The leading "!" is IPython/Jupyter shell-escape syntax, so
# these lines only work when executed as a notebook cell; a plain Python
# interpreter rejects them as a syntax error.
# NOTE(review): if this file is meant to run as a regular script, these
# installs presumably belong in a requirements file instead - confirm.
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python pillow matplotlib
# Remove any already-installed diffusers/transformers before installing the
# pinned versions below (presumably to avoid a version clash with whatever the
# runtime ships by default).
!pip uninstall -y diffusers transformers
!pip install diffusers==0.30.3 transformers==4.44.2 accelerate
!pip install gradio

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
import cv2
from segment_anything import sam_model_registry, SamPredictor
from diffusers import StableDiffusionXLInpaintPipeline
import os

# ============================================================
# SAM SMART FOREGROUND SELECTOR - OFFLINE MODEL VERSION
# ============================================================

# Module-wide mutable state shared by the Gradio callbacks below. It is a plain
# dict, so every callback reads and writes the same single session.
global_state = {
    # Cached SamPredictor; created lazily by load_sam_model().
    "sam_predictor": None,
    # NumPy copy of the most recently uploaded image (set by process_upload()).
    "original_image": None,
    # Segmentation mask of the selected foreground (set by auto_smart_select()).
    "current_mask": None,
    # Original image with the mask stacked on as an alpha channel
    # (set by auto_smart_select(), consumed by change_background()).
    "rgba_image": None,
    # True once an image has been handed to the SAM predictor.
    "image_set": False,
    # The remaining slots are not touched by the functions in this section;
    # presumably reserved for box/point prompting - TODO confirm.
    "box_points": [],
    "positive_points": [],
    "negative_points": [],
    "auto_masks": [],
}

# ------------------------------------------------------------
# 🧩 Load SAM Model (Offline from /models/)
# ------------------------------------------------------------
def load_sam_model(sam_checkpoint="models/sam_vit_b_01ec64.pth"):
    """Return the shared SAM predictor, building it on first use.

    The predictor is cached in ``global_state["sam_predictor"]``. Once it has
    been created, later calls return the cached object and ignore
    ``sam_checkpoint``.

    Args:
        sam_checkpoint: Path to the local ViT-B SAM checkpoint file.

    Returns:
        The cached ``SamPredictor`` instance.

    Raises:
        FileNotFoundError: If no predictor is cached yet and
            ``sam_checkpoint`` does not exist on disk.
    """
    cached = global_state["sam_predictor"]
    if cached is not None:
        return cached

    # Fail early with a readable message instead of a loader traceback.
    checkpoint_missing = not os.path.exists(sam_checkpoint)
    if checkpoint_missing:
        raise FileNotFoundError(
            f"❌ File model tidak ditemukan: {sam_checkpoint}\n"
            f"Pastikan kamu sudah upload 'sam_vit_b_01ec64.pth' ke folder /models/"
        )

    print(f"✅ Loading SAM model from {sam_checkpoint}...")
    build_vit_b = sam_model_registry["vit_b"]
    model = build_vit_b(checkpoint=sam_checkpoint)

    # Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    model.to(device=target_device)

    global_state["sam_predictor"] = SamPredictor(model)
    print(f"✅ SAM loaded successfully on {target_device}")
    return global_state["sam_predictor"]

# ------------------------------------------------------------
# 🖼️ Step 1: Upload Image
# ------------------------------------------------------------
def process_upload(image):
    """Register a newly uploaded image with SAM and reset per-image state.

    Args:
        image: The uploaded picture (a PIL image from the Gradio input), or
            ``None`` when nothing has been uploaded.

    Returns:
        A 4-tuple ``(preview, status, mask, result)`` matching the Gradio
        outputs wired to this callback. ``mask`` and ``result`` are always
        ``None`` so that stale previews are cleared; ``preview`` is ``None``
        when no image was supplied.

    Raises:
        FileNotFoundError: Propagated from ``load_sam_model()`` when the SAM
            checkpoint is missing.
    """
    if image is None:
        return None, "⚠️ Upload gambar terlebih dahulu!", None, None

    # auto_smart_select() stacks exactly one alpha channel onto this array, so
    # it must be 3-channel RGB. Normalise RGBA / grayscale / palette uploads.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    image_np = np.array(image)

    # Invalidate everything derived from the previous image *before* doing any
    # work that can fail. Otherwise "Generate Background" could still run on
    # the old cut-out, or image_set could stay True for a half-loaded image.
    global_state["image_set"] = False
    global_state["current_mask"] = None
    global_state["rgba_image"] = None

    global_state["original_image"] = image_np.copy()
    predictor = load_sam_model()
    predictor.set_image(image_np)
    global_state["image_set"] = True

    status = "✅ Gambar berhasil diupload!\n\nKlik 'Auto Smart Select' atau gunakan Box/Point untuk memilih objek."
    return Image.fromarray(image_np), status, None, None

# ------------------------------------------------------------
# ✨ Step 2: Auto Smart Select (pick the object automatically)
# ------------------------------------------------------------
def auto_smart_select():
    """Pick a foreground object automatically from the uploaded image.

    Runs SAM's automatic mask generator over the stored image, scores every
    candidate mask by how central it is and how close its area is to 15% of
    the frame, and keeps the highest-scoring one. The chosen mask and the
    resulting RGBA cut-out are stored in ``global_state``.

    Returns:
        A 3-tuple ``(segmented_preview, mask_preview, status)``. Both previews
        are ``None`` when no image has been uploaded or no mask was found.
    """
    if not global_state["image_set"]:
        return None, None, "⚠️ Upload gambar terlebih dahulu!"

    from segment_anything import SamAutomaticMaskGenerator

    source = global_state["original_image"]
    generator = SamAutomaticMaskGenerator(
        model=global_state["sam_predictor"].model,
        points_per_side=32,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.93,
        min_mask_region_area=500
    )

    candidates = generator.generate(source)
    if not candidates:
        return None, None, "❌ Tidak ada objek terdeteksi."

    # Rank candidates by closeness to the image centre and by area.
    height, width = source.shape[:2]
    mid_x, mid_y = width // 2, height // 2
    # Distance from the centre to a corner, used to normalise centroid distance.
    corner_dist = np.sqrt(mid_x**2 + mid_y**2)
    frame_area = height * width

    def score_mask(candidate):
        """Return a weighted centrality/size score for one SAM candidate."""
        rows, cols = np.where(candidate['segmentation'])
        if cols.size == 0:
            return 0
        offset = np.sqrt((cols.mean() - mid_x)**2 + (rows.mean() - mid_y)**2)
        centrality = 1 - (offset / corner_dist)
        coverage = candidate['area'] / frame_area
        size_fit = 1 - abs(coverage - 0.15)  # ideal coverage is about 15%
        return centrality * 0.6 + size_fit * 0.4

    winner = max(candidates, key=score_mask)
    best_mask = winner['segmentation']

    # The boolean mask becomes a 0/255 alpha channel appended to the image.
    alpha = (best_mask * 255).astype(np.uint8)
    rgba_image = np.dstack((source, alpha))
    global_state["current_mask"] = best_mask
    global_state["rgba_image"] = rgba_image

    segmented_preview = Image.fromarray(rgba_image)
    mask_preview = Image.fromarray(alpha)

    status = "✅ Foreground otomatis terpilih!\nKlik 'Generate Background' untuk mengganti latar belakang."
    return segmented_preview, mask_preview, status

# ------------------------------------------------------------
# 🎨 Step 3: Replace the Background from a Prompt
# ------------------------------------------------------------
def _get_inpaint_pipeline():
    """Return the SDXL inpainting pipeline, loading it only on first use."""
    pipeline = global_state.get("inpaint_pipeline")
    if pipeline is None:
        use_cuda = torch.cuda.is_available()
        pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            # Half precision is only appropriate on the GPU; fall back to
            # float32 when the pipeline has to run on the CPU.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            variant="fp16",
        )
        if use_cuda:
            pipeline = pipeline.to("cuda")
        # Cache the pipeline so later clicks do not reload the weights.
        global_state["inpaint_pipeline"] = pipeline
    return pipeline


def change_background(background_prompt, negative_prompt, guidance_scale, num_steps, seed):
    """Repaint the background of the selected cut-out from a text prompt.

    Pixels whose alpha is below 128 in ``global_state["rgba_image"]`` are
    treated as background and handed to the SDXL inpainting pipeline as the
    region to regenerate.

    Args:
        background_prompt: Text describing the desired background.
        negative_prompt: Text describing what to avoid.
        guidance_scale: Classifier-free guidance strength.
        num_steps: Number of inference steps (converted with ``int``).
        seed: Seed for the random generator; ``None`` means unseeded.

    Returns:
        A 2-tuple ``(result_image, status)``. ``result_image`` is ``None``
        when no foreground mask has been generated yet.
    """
    if global_state["rgba_image"] is None:
        return None, "⚠️ Generate mask dulu!"

    rgba_image = Image.fromarray(global_state["rgba_image"])
    pipeline = _get_inpaint_pipeline()

    # Inpaint mask: white (255) where the cut-out is transparent, i.e. background.
    alpha = np.array(rgba_image.split()[-1])
    mask = np.where(alpha < 128, 255, 0).astype(np.uint8)
    mask = Image.fromarray(mask).convert("L")
    rgb_image = rgba_image.convert("RGB")

    # The pipeline rejects sizes that are not multiples of 8, so run it on the
    # nearest smaller valid size and scale the result back afterwards.
    original_size = rgb_image.size
    run_width = max(8, original_size[0] - original_size[0] % 8)
    run_height = max(8, original_size[1] - original_size[1] % 8)
    if (run_width, run_height) != original_size:
        rgb_image = rgb_image.resize((run_width, run_height), Image.LANCZOS)
        # Nearest-neighbour keeps the mask strictly black/white.
        mask = mask.resize((run_width, run_height), Image.NEAREST)

    # A cleared seed field arrives as None; run unseeded instead of crashing.
    generator = torch.manual_seed(int(seed)) if seed is not None else None
    result = pipeline(
        image=rgb_image,
        mask_image=mask,
        prompt=background_prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        generator=generator,
        num_inference_steps=int(num_steps),
        width=run_width,
        height=run_height
    ).images[0]

    if result.size != original_size:
        result = result.resize(original_size, Image.LANCZOS)

    status = f"✅ Background diganti: {background_prompt}"
    return result, status

# ------------------------------------------------------------
# 🧱 Gradio UI
# ------------------------------------------------------------
def create_gradio_interface():
    """Build the Gradio Blocks UI and wire its buttons to the callbacks.

    The left column holds the inputs (image upload, prompts, sampler settings
    and action buttons); the right column holds the status box and the three
    image previews.

    Returns:
        The assembled ``gr.Blocks`` application, not yet launched.
    """
    with gr.Blocks(title="AI Background Maker (Offline SAM)") as demo:
        gr.Markdown("# 🖼️ AI Background Maker\nUpload foto → Pilih objek → Ganti background ✨")

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(label="Upload Gambar", type="pil")
                upload_btn = gr.Button("📥 Load Image", variant="primary")
                auto_smart_btn = gr.Button("🤖 Auto Smart Select (No Click!)", variant="primary")
                bg_prompt = gr.Textbox(value="sunset beach, cinematic, 8k", label="Background Prompt")
                neg_prompt = gr.Textbox(value="blurry, low quality", label="Negative Prompt")
                # Gradio declares `step` as keyword-only on gr.Slider, so
                # passing it as a fourth positional argument raises TypeError.
                guidance = gr.Slider(minimum=1, maximum=15, value=7, step=0.5, label="Guidance Scale")
                steps = gr.Slider(minimum=10, maximum=50, value=30, step=5, label="Steps")
                seed = gr.Number(value=42, label="Seed")
                gen_bg = gr.Button("🚀 Generate Background", variant="secondary")

            with gr.Column(scale=1):
                status = gr.Textbox(label="Status", lines=5)
                segmented = gr.Image(label="Foreground")
                mask = gr.Image(label="Mask")
                result = gr.Image(label="Final Output")
                result_status = gr.Textbox(label="Output Status", lines=2)

        # Loading an image also clears the mask and final-output previews,
        # because process_upload returns None for those two outputs.
        upload_btn.click(process_upload, inputs=[input_image],
                         outputs=[segmented, status, mask, result])
        auto_smart_btn.click(auto_smart_select, outputs=[segmented, mask, status])
        gen_bg.click(change_background,
                     inputs=[bg_prompt, neg_prompt, guidance, steps, seed],
                     outputs=[result, result_status])
    return demo

if __name__ == "__main__":
    # Build the UI and serve it on port 7860 on all network interfaces.
    # share=True additionally asks Gradio for a public share link, and
    # debug=True turns on Gradio's debug mode for the launch.
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        debug=True,
    )