LhatMjnk committed on
Commit 664f6dd · verified · 1 Parent(s): 14ebdf3

Update app.py

Files changed (1): app.py +139 -72
app.py CHANGED
@@ -1,124 +1,191 @@
- # app.py
- import time
  import cv2
  import numpy as np
  import gradio as gr

- from inference import CoralSegModel
-
  model = CoralSegModel()

- ############################
- # Helpers
- ############################
  def _safe_read(cap):
      ok, frame = cap.read()
-     if not ok or frame is None:
-         return None
-     return frame
-
- ############################
- # 1) Remote stream (server-pull)
- ############################
- def remote_stream(rtsp_or_http_url: str, skip_every_n=1):
-     """
-     Generator that yields processed frames for gr.Video streaming.
-     - rtsp_or_http_url: e.g., rtsp://..., http://mjpeg..., or a video file URL
-     """
-     if not rtsp_or_http_url:
-         yield None
          return
-
-     cap = cv2.VideoCapture(rtsp_or_http_url)
      if not cap.isOpened():
-         yield None
          return
-
      idx = 0
      try:
          while True:
              frame = _safe_read(cap)
              if frame is None:
                  break
-
-             if skip_every_n > 1 and (idx % skip_every_n) != 0:
                  idx += 1
                  continue
-
-             processed = model.predict_overlay(frame)
-             # IMPORTANT: Gradio 5 streaming expects raw numpy frames (H, W, 3); BGR/RGB both supported for display
-             yield processed
              idx += 1
-             # Lower CPU usage a bit (tune this)
-             # time.sleep(0.001)
      finally:
          cap.release()

- def uploaded_video_stream(video_file, skip_every_n=1):
-     """
-     Gradio passes the uploaded file path (string) for gr.Video.
-     We open it with OpenCV and yield processed frames to stream.
-     """
      if not video_file:
-         yield None
          return
-
      cap = cv2.VideoCapture(video_file)
      if not cap.isOpened():
-         yield None
          return
-
      idx = 0
      try:
          while True:
              ok, frame = cap.read()
              if not ok or frame is None:
                  break
-             if skip_every_n > 1 and (idx % skip_every_n) != 0:
                  idx += 1
                  continue
-             processed = model.predict_overlay(frame)
-             yield processed
              idx += 1
-             # tiny sleep to reduce CPU spikes; tune as needed
-             # time.sleep(0.001)
      finally:
          cap.release()

- ############################
  # UI
- ############################
  with gr.Blocks(title="CoralScapes Streaming Segmentation") as demo:
      gr.Markdown("# CoralScapes Streaming Segmentation")
      gr.Markdown(
-         "Two modes: **Remote Stream** (paste RTSP/HTTP/MJPEG URL) or **Upload Video**."
      )

      with gr.Tab("Remote Stream (RTSP/HTTP)"):
-         url = gr.Textbox(
-             label="Stream URL (rtsp://..., http://...)", placeholder="rtsp://user:pass@ip:port/..."
-         )
-         skip = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame (perf tweak)")
-         out_image = gr.Image(label="Segmented Stream", streaming=True)  # Changed to Image
-         start_btn = gr.Button("Start")
-         stop_btn = gr.Button("Stop")
-
-         def _start(url_value, n):
-             return remote_stream(url_value, int(n))
-
-         start_btn.click(_start, inputs=[url, skip], outputs=out_image)
-         stop_btn.click(lambda: None, inputs=None, outputs=out_image)

      with gr.Tab("Upload Video"):
-         gr.Markdown("Upload a video file; the server will stream segmented frames back in real time.")
-         vid_in = gr.Video(sources=["upload"], format="mp4", label="Input Video")
-         out_image = gr.Image(label="Segmented Output (streaming)", streaming=True)  # Changed to Image
-         start_btn2 = gr.Button("Process")
-         stop_btn2 = gr.Button("Stop")

-         skip2 = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")
-         start_btn2.click(uploaded_video_stream, inputs=[vid_in, skip2], outputs=out_image)
-         stop_btn2.click(lambda: None, inputs=None, outputs=out_image)

  if __name__ == "__main__":
-     demo.queue().launch(server_name="0.0.0.0", server_port=7860)
 
+ from PIL import Image
  import cv2
  import numpy as np
  import gradio as gr

+ from inference import CoralSegModel, id2label, label2color, create_segmentation_overlay
  model = CoralSegModel()

+ # ---- helpers ----
  def _safe_read(cap):
      ok, frame = cap.read()
+     return frame if ok and frame is not None else None
+
+ def build_annotations(pred_map: np.ndarray, selected: list[str]) -> list[tuple[np.ndarray, str]]:
+     """Return [(mask, label), ...] where each mask is a 0/1 float HxW array for AnnotatedImage."""
+     if pred_map is None or not selected:
+         return []
+
+     # Create reverse mapping: label_name -> class_id
+     label2id = {label: int(id_str) for id_str, label in id2label.items()}
+
+     anns = []
+     for label_name in selected:
+         if label_name not in label2id:
+             continue  # skip unknown labels
+
+         class_id = label2id[label_name]  # convert label name to class ID
+         mask = (pred_map == class_id).astype(np.float32)
+         if mask.sum() > 0:
+             anns.append((mask, label_name))  # use the label name for display
+     return anns
+
+ # ==============================
+ # STREAMING EVENT FUNCTIONS
+ # ==============================
+ # IMPORTANT: make the event functions themselves generators.
+ # Also: include the States as outputs so we can update them every frame.
+ def remote_start(url: str, n: int, pred_state, base_state):
+     if not url:
          return
+     cap = cv2.VideoCapture(url)
      if not cap.isOpened():
          return
      idx = 0
      try:
          while True:
              frame = _safe_read(cap)
              if frame is None:
                  break
+             if n > 1 and (idx % n) != 0:
                  idx += 1
                  continue
+             pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(frame)
+             # yield live image + updated States' *values*
+             yield overlay_rgb, pred_map, base_rgb
              idx += 1
      finally:
          cap.release()

+ def upload_start(video_file: str, n: int):
      if not video_file:
          return
      cap = cv2.VideoCapture(video_file)
      if not cap.isOpened():
          return
      idx = 0
      try:
          while True:
              ok, frame = cap.read()
              if not ok or frame is None:
                  break
+             if n > 1 and (idx % n) != 0:
                  idx += 1
                  continue
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             pred_map, overlay_rgb, base_rgb = model.predict_map_and_overlay(frame)
+             yield overlay_rgb, pred_map, base_rgb
              idx += 1
      finally:
          cap.release()

+ # ==============================
+ # SNAPSHOT / TOGGLES (non-streaming)
+ # ==============================
+ # NOTE: When you pass gr.State as an input, you receive the *value*, not the wrapper.
+ def make_snapshot(selected_labels, pred_map, base_rgb, alpha=0.25):
+     if pred_map is None or base_rgb is None:
+         return gr.update()
+     # rebuild overlay to match the live look
+     overlay = create_segmentation_overlay(pred_map, id2label, label2color, Image.fromarray(base_rgb), alpha=alpha)
+     ann = build_annotations(pred_map, selected_labels or [])
+     return (overlay, ann)  # (base_image, [(mask, label), ...])
+
+ # ==============================
  # UI
+ # ==============================
  with gr.Blocks(title="CoralScapes Streaming Segmentation") as demo:
      gr.Markdown("# CoralScapes Streaming Segmentation")
      gr.Markdown(
+         "Left: **live stream** (fast). Right: **snapshot** with **hover labels** and **per-class toggles**."
      )

      with gr.Tab("Remote Stream (RTSP/HTTP)"):
+         with gr.Row():
+             with gr.Column(scale=2):
+                 # States start as None. We UPDATE them on every frame by returning them as outputs.
+                 pred_state_remote = gr.State(None)  # holds last pred_map (HxW np.uint8)
+                 base_state_remote = gr.State(None)  # holds last base_rgb (HxWx3 uint8)
+
+                 live_remote = gr.Image(label="Live segmented stream")
+                 start_btn = gr.Button("Start")
+
+                 snap_btn_remote = gr.Button("📸 Snapshot (hover-able)")
+                 hover_remote = gr.AnnotatedImage(label="Snapshot (hover to see label)")
+
+             with gr.Column(scale=1):
+                 url = gr.Textbox(label="Stream URL", placeholder="rtsp://user:pass@ip:port/…")
+                 skip = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")
+
+                 toggles_remote = gr.CheckboxGroup(
+                     choices=list(id2label.values()), value=list(id2label.values()),
+                     label="Toggle classes in snapshot",
+                 )
+
+         start_btn.click(
+             remote_start,
+             inputs=[url, skip, pred_state_remote, base_state_remote],
+             outputs=[live_remote, pred_state_remote, base_state_remote],
+             queue=True,  # be explicit; required for generator streaming
+         )
+
+         snap_btn_remote.click(
+             make_snapshot,
+             inputs=[toggles_remote, pred_state_remote, base_state_remote],
+             outputs=[hover_remote],
+         )
+         toggles_remote.change(
+             make_snapshot,
+             inputs=[toggles_remote, pred_state_remote, base_state_remote],
+             outputs=[hover_remote],
+         )

      with gr.Tab("Upload Video"):
+         with gr.Row():
+             # Left column: live output, snapshot button, and snapshot view
+             with gr.Column(scale=2):
+                 # States live in the same column as live_upload
+                 pred_state_upload = gr.State(None)
+                 base_state_upload = gr.State(None)
+
+                 live_upload = gr.Image(label="Live segmented output")
+                 start_btn2 = gr.Button("Process")
+
+                 snap_btn_upload = gr.Button("📸 Snapshot (hover-able)")
+                 hover_upload = gr.AnnotatedImage(label="Snapshot (hover to see label)")
+
+             # Right column: video input, slider, and class toggles
+             with gr.Column(scale=1):
+                 vid_in = gr.Video(sources=["upload"], format="mp4", label="Input Video")
+                 skip2 = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")
+
+                 toggles_upload = gr.CheckboxGroup(
+                     choices=list(id2label.values()), value=list(id2label.values()),
+                     label="Toggle classes in snapshot",
+                 )
+
+         # Event handlers remain the same
+         start_btn2.click(
+             upload_start,
+             inputs=[vid_in, skip2],
+             outputs=[live_upload, pred_state_upload, base_state_upload],
+             queue=True,
+         )

+         snap_btn_upload.click(
+             make_snapshot,
+             inputs=[toggles_upload, pred_state_upload, base_state_upload],
+             outputs=[hover_upload],
+         )
+
+         toggles_upload.change(
+             make_snapshot,
+             inputs=[toggles_upload, pred_state_upload, base_state_upload],
+             outputs=[hover_upload],
+         )

  if __name__ == "__main__":
+     demo.queue().launch(share=True)
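
The new app.py imports `CoralSegModel`, `id2label`, `label2color`, and `create_segmentation_overlay` from `inference`, which is not shown in this commit. A minimal stand-in consistent with the call sites above; the two classes, colors, and the all-zero "prediction" are placeholders, not the Space's actual model:

```python
# inference.py stand-in (NOT part of this commit): a sketch matching how
# app.py calls into it. Class names, colors, and inference are placeholders.
from typing import Dict, Tuple

import numpy as np
from PIL import Image

# app.py does `int(id_str)` on the keys, so they are strings here too.
id2label: Dict[str, str] = {"0": "background", "1": "coral"}
label2color: Dict[str, Tuple[int, int, int]] = {"background": (0, 0, 0), "coral": (255, 64, 64)}

def create_segmentation_overlay(pred_map, id2label, label2color, base_image, alpha=0.25):
    """Blend per-class colors over base_image; signature taken from the call in app.py."""
    base = np.asarray(base_image).astype(np.float32)
    color = np.zeros_like(base)
    for id_str, label in id2label.items():
        color[pred_map == int(id_str)] = label2color[label]
    blended = (1 - alpha) * base + alpha * color
    return Image.fromarray(blended.astype(np.uint8))

class CoralSegModel:
    def predict_map_and_overlay(self, frame: np.ndarray):
        """Return (pred_map HxW uint8, overlay_rgb HxWx3 uint8, base_rgb HxWx3 uint8)."""
        pred_map = np.zeros(frame.shape[:2], dtype=np.uint8)  # placeholder: all background
        overlay = create_segmentation_overlay(pred_map, id2label, label2color, Image.fromarray(frame))
        return pred_map, np.asarray(overlay), frame
```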
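The core pattern the commit adopts (per its own comments) is making the click handler a generator and listing the `gr.State` components in `outputs`, so each yield refreshes both the live image and the stored values. A toy sketch of that pattern, with dummy frames in place of model output:

```python
# Streaming sketch: generator event handler + gr.State in outputs.
import numpy as np
import gradio as gr

def stream(n_frames):
    for i in range(int(n_frames)):
        # dummy frame standing in for a segmented overlay
        frame = np.full((240, 320, 3), (i * 8) % 256, dtype=np.uint8)
        yield frame, frame  # (live image, latest-frame State value)

with gr.Blocks() as demo:
    last_frame = gr.State(None)  # refreshed on every yield
    n = gr.Number(value=30, label="Frames")
    live = gr.Image(label="Live")
    btn = gr.Button("Start")
    btn.click(stream, inputs=[n], outputs=[live, last_frame], queue=True)

if __name__ == "__main__":
    demo.queue().launch()
```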
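For the snapshot side, `make_snapshot` returns the tuple format `gr.AnnotatedImage` accepts: `(base_image, [(mask, label), ...])`, where each mask is an HxW array marking one class's pixels and hovering a region shows its label. A self-contained sketch with a synthetic mask:

```python
# AnnotatedImage value sketch: one base image plus one labeled mask region.
import numpy as np
import gradio as gr

base = np.zeros((100, 100, 3), dtype=np.uint8)  # black base image
mask = np.zeros((100, 100), dtype=np.float32)
mask[25:75, 25:75] = 1.0                        # one highlighted class region

with gr.Blocks() as demo:
    gr.AnnotatedImage(value=(base, [(mask, "coral")]), label="Hover to see label")

if __name__ == "__main__":
    demo.launch()
```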