LhatMjnk committed
Commit 5311f6e · verified · 1 Parent(s): 8c5d28f

upload inference and app python scripts

Files changed (2):
  1. app.py +124 -0
  2. inference.py +44 -0
app.py ADDED
@@ -0,0 +1,124 @@
# app.py
import time
import cv2
import numpy as np
import gradio as gr

from inference import CoralSegModel

model = CoralSegModel()

############################
# Helpers
############################
def _safe_read(cap):
    ok, frame = cap.read()
    if not ok or frame is None:
        return None
    return frame

############################
# 1) Remote stream (server-pull)
############################
def remote_stream(rtsp_or_http_url: str, skip_every_n=1):
    """
    Generator that yields processed frames to the streaming gr.Image output.
    - rtsp_or_http_url: e.g., rtsp://..., an http:// MJPEG endpoint, or a video file URL
    """
    if not rtsp_or_http_url:
        yield None
        return

    cap = cv2.VideoCapture(rtsp_or_http_url)
    if not cap.isOpened():
        yield None
        return

    idx = 0
    try:
        while True:
            frame = _safe_read(cap)
            if frame is None:
                break

            if skip_every_n > 1 and (idx % skip_every_n) != 0:
                idx += 1
                continue

            processed = model.predict_overlay(frame)
            # predict_overlay returns BGR (OpenCV order); gr.Image expects RGB,
            # so flip the channel order before yielding.
            yield processed[:, :, ::-1]
            idx += 1
            # Tiny sleep to lower CPU usage a bit (tune this)
            time.sleep(0.001)
    finally:
        cap.release()

def uploaded_video_stream(video_file, skip_every_n=1):
    """
    Gradio passes the uploaded file path (string) for gr.Video.
    We open it with OpenCV and yield processed frames to stream.
    """
    if not video_file:
        yield None
        return

    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        yield None
        return

    idx = 0
    try:
        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                break
            if skip_every_n > 1 and (idx % skip_every_n) != 0:
                idx += 1
                continue
            processed = model.predict_overlay(frame)
            yield processed[:, :, ::-1]  # BGR -> RGB for gr.Image
            idx += 1
            # tiny sleep to reduce CPU spikes; tune as needed
            time.sleep(0.001)
    finally:
        cap.release()

############################
# UI
############################
with gr.Blocks(title="CoralScapes Streaming Segmentation") as demo:
    gr.Markdown("# CoralScapes Streaming Segmentation")
    gr.Markdown(
        "Two modes: **Remote Stream** (paste an RTSP/HTTP/MJPEG URL) or **Upload Video**."
    )

    with gr.Tab("Remote Stream (RTSP/HTTP)"):
        url = gr.Textbox(
            label="Stream URL (rtsp://..., http://...)", placeholder="rtsp://user:pass@ip:port/..."
        )
        skip = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame (perf tweak)")
        out_image = gr.Image(label="Segmented Stream", streaming=True)
        start_btn = gr.Button("Start")
        stop_btn = gr.Button("Stop")

        def _start(url_value, n):
            # yield from makes this a generator function, which is what
            # Gradio needs in order to stream the frames one by one.
            yield from remote_stream(url_value, int(n))

        start_evt = start_btn.click(_start, inputs=[url, skip], outputs=out_image)
        # cancels=[...] actually stops the running generator; clearing the
        # output alone would leave the stream running in the background.
        stop_btn.click(lambda: None, inputs=None, outputs=out_image, cancels=[start_evt])

    with gr.Tab("Upload Video"):
        gr.Markdown("Upload a video file; the server will stream segmented frames back in real time.")
        vid_in = gr.Video(sources=["upload"], format="mp4", label="Input Video")
        out_image2 = gr.Image(label="Segmented Output (streaming)", streaming=True)
        start_btn2 = gr.Button("Process")
        stop_btn2 = gr.Button("Stop")

        skip2 = gr.Slider(1, 5, value=1, step=1, label="Process every Nth frame")
        start_evt2 = start_btn2.click(uploaded_video_stream, inputs=[vid_in, skip2], outputs=out_image2)
        stop_btn2.click(lambda: None, inputs=None, outputs=out_image2, cancels=[start_evt2])

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
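
Side note: remote_stream and uploaded_video_stream share the same read/skip/process loop. A possible shared helper both could delegate to (a sketch only, not part of this commit; app.py above keeps the two loops separate):

# Hypothetical refactor sketch -- not in the committed app.py.
import time
import cv2

def _stream_frames(model, source, skip_every_n=1):
    """Yield RGB overlay frames from any source cv2.VideoCapture accepts
    (RTSP/HTTP URL or a local file path)."""
    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        yield None
        return
    idx = 0
    try:
        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                break
            if skip_every_n > 1 and (idx % skip_every_n) != 0:
                idx += 1
                continue
            # predict_overlay returns BGR; flip to RGB for gr.Image.
            yield model.predict_overlay(frame)[:, :, ::-1]
            idx += 1
            time.sleep(0.001)  # small pause to ease CPU load
    finally:
        cap.release()

With this in place, remote_stream(url, n) and uploaded_video_stream(path, n) would both reduce to yield from _stream_frames(model, source, n).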
inference.py ADDED
@@ -0,0 +1,44 @@
# inference.py
import torch
import numpy as np
from PIL import Image
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# Load model from HF (swap this with your own if you want)
HF_MODEL_ID = "EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024"

class CoralSegModel:
    def __init__(self, device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SegformerImageProcessor.from_pretrained(HF_MODEL_ID)
        self.model = SegformerForSemanticSegmentation.from_pretrained(HF_MODEL_ID).to(self.device)
        self.model.eval()

        # Build a simple color palette for masks (fallback to 40 classes if the
        # config provides no labels): one color per class id, random-ish but
        # stable across runs thanks to the fixed seed.
        id2label = self.model.config.id2label
        num_classes = len(id2label) if id2label else 40
        rng = np.random.RandomState(0)
        self.palette = rng.randint(0, 255, size=(num_classes, 3)).astype(np.uint8)

    @torch.inference_mode()
    def predict_overlay(self, frame_bgr: np.ndarray, alpha: float = 0.45) -> np.ndarray:
        """
        frame_bgr: np.ndarray HxWx3 in BGR (as read by OpenCV)
        returns: np.ndarray HxWx3 in BGR (input blended with the colorized mask)
        """
        # Convert BGR -> RGB PIL
        rgb = frame_bgr[:, :, ::-1]
        pil = Image.fromarray(rgb)

        inputs = self.processor(images=pil, return_tensors="pt").to(self.device)
        outputs = self.model(**inputs)
        logits = outputs.logits  # [B, C, h, w]
        # Upsample logits back to the input resolution; pil.size is (W, H)
        # while interpolate expects (H, W), hence the [::-1].
        upsampled = torch.nn.functional.interpolate(
            logits, size=pil.size[::-1], mode="bilinear", align_corners=False
        )
        pred = upsampled.argmax(dim=1)[0].detach().cpu().numpy().astype(np.uint8)  # HxW

        color_mask = self.palette[pred]  # HxWx3 (RGB)
        overlay_rgb = (rgb * (1 - alpha) + color_mask * alpha).astype(np.uint8)
        overlay_bgr = overlay_rgb[:, :, ::-1]
        return overlay_bgr
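
A minimal usage sketch for the class above (the image paths are placeholders, not part of the commit):

# Hypothetical quick check -- not part of this commit.
import cv2
from inference import CoralSegModel

model = CoralSegModel()
frame = cv2.imread("reef.jpg")          # BGR, exactly what predict_overlay expects
overlay = model.predict_overlay(frame)  # BGR overlay at the input resolution
cv2.imwrite("reef_overlay.png", overlay)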