# inference.py
import torch
import torch.nn.functional as F
import spaces
import json
import urllib.request
import cv2
import numpy as np
from PIL import Image
from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation

id2label = json.load(urllib.request.urlopen(
    "https://huggingface.co/datasets/EPFL-ECEO/coralscapes/resolve/main/id2label.json"))
label2color = json.load(urllib.request.urlopen(
    "https://huggingface.co/datasets/EPFL-ECEO/coralscapes/resolve/main/label2color.json"))

# Load model from HF (swap this with your own if you want)
HF_MODEL_ID = "EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024"


def create_segmentation_overlay(pred, id2label, label2color, image, alpha=0.25):
    """
    Colorizes the segmentation prediction and creates an overlay image.

    Args:
        pred: The segmentation prediction (HxW numpy array of class IDs).
        id2label: Dictionary mapping class IDs to labels.
        label2color: Dictionary mapping labels to RGB colors.
        image: The original PIL Image.
        alpha: Blend weight of the color mask over the original image.

    Returns:
        A PIL Image representing the overlay of the original image and the colorized segmentation mask.
    """
    H, W = pred.shape
    rgb = np.zeros((H, W, 3), dtype=np.uint8)
    # Get unique class IDs present in the prediction
    unique_classes = np.unique(pred)
    # Create a mapping from class ID to color
    id2color = {int(class_id): label2color[label] for class_id, label in id2label.items()}
    # Default color for class IDs without a known label (black)
    default_color = [0, 0, 0]
    # Colorize the mask one class at a time
    for class_id in unique_classes:
        # Get the color for the current class ID, falling back to default_color
        rgb_c = id2color.get(int(class_id), default_color)
        # Assign the color to the pixels with the current class ID
        rgb[pred == class_id] = rgb_c
    mask_rgb = Image.fromarray(rgb)
    # Alpha-blend the color mask over the original image
    overlay = Image.blend(image.convert("RGBA"), mask_rgb.convert("RGBA"), alpha=alpha)
    return overlay
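
# Minimal usage sketch for the overlay helper. The image and prediction below
# are synthetic placeholders (not part of the app); if class 1 is missing from
# id2label it simply falls back to the default (black) color:
#
#   demo_image = Image.new("RGB", (64, 64), (30, 120, 200))
#   demo_pred = np.zeros((64, 64), dtype=np.uint8)
#   demo_pred[:, 32:] = 1  # right half assigned to class 1
#   demo_overlay = create_segmentation_overlay(demo_pred, id2label, label2color, demo_image)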


def resize_image(image, target_size=1024):
    """
    Resizes the image so that its smaller side equals target_size,
    preserving the aspect ratio.
    """
    w_img, h_img = image.size  # PIL's Image.size is (width, height)
    if w_img < h_img:
        new_w, new_h = target_size, int(h_img * (target_size / w_img))
    else:
        new_w, new_h = int(w_img * (target_size / h_img)), target_size
    resized_img = image.resize((new_w, new_h))
    return resized_img
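
# Worked example of the resize rule: a 2000x500 (w x h) image becomes
# 4096x1024, so the smaller side (the height) lands exactly on 1024:
#   resize_image(Image.new("RGB", (2000, 500))).size == (4096, 1024)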


class CoralSegModel:
    def __init__(self, device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SegformerImageProcessorFast.from_pretrained(HF_MODEL_ID)
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            HF_MODEL_ID,
            dtype=torch.bfloat16,
        ).to(self.device)
        self.model.eval()

    def segment_image(self, image, preprocessor, model, crop_size=(1024, 1024), num_classes=40) -> np.ndarray:
        """
        Finds an optimal stride based on the image size and aspect ratio to create
        overlapping sliding windows of crop_size (1024x1024 by default), which are
        then fed into the model; per-pixel logits are averaged across windows.
        """
        h_crop, w_crop = crop_size
        img = torch.Tensor(np.array(resize_image(image, target_size=1024)).transpose(2, 0, 1)).unsqueeze(0)
        img = img.to(self.device, torch.bfloat16)
        batch_size, _, h_img, w_img = img.size()
        h_grids = int(np.round(3 / 2 * h_img / h_crop)) if h_img > h_crop else 1
        w_grids = int(np.round(3 / 2 * w_img / w_crop)) if w_img > w_crop else 1
        h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop
        w_stride = int((w_img - w_crop + w_grids - 1) / (w_grids - 1)) if w_grids > 1 else w_crop
        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift the window back so it always covers a full crop inside the image
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                with torch.no_grad():
                    if preprocessor:
                        inputs = preprocessor(crop_img, return_tensors="pt", device=self.device)
                        inputs["pixel_values"] = inputs["pixel_values"].to(self.device, torch.bfloat16)
                        outputs = model(**inputs)
                    else:
                        # Without a preprocessor the crop is already a pixel tensor
                        outputs = model(pixel_values=crop_img.to(self.device, torch.bfloat16))
                resized_logits = F.interpolate(
                    outputs.logits[0].unsqueeze(dim=0), size=crop_img.shape[-2:], mode="bilinear", align_corners=False
                )
                # Paste the window logits back at their position in the full-size map
                preds += F.pad(resized_logits,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))
                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0, "every pixel must be covered by at least one window"
        preds = preds / count_mat
        preds = preds.argmax(dim=1)
        # Resize the class map back to the original image resolution
        preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode="nearest")
        label_pred = preds.squeeze().cpu().numpy()
        return label_pred
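
    # Worked example of the stride math above (illustrative numbers only):
    # for a 1024x2048 (h x w) resized image with 1024x1024 crops,
    #   w_grids  = round(3/2 * 2048 / 1024) = 3
    #   w_stride = int((2048 - 1024 + 3 - 1) / (3 - 1)) = 513
    # so windows start at x = 0, 513, and 1024 (the last one shifted back so it
    # ends exactly at the image edge), and overlapping logits are averaged.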

    @spaces.GPU  # ZeroGPU Spaces: request a GPU for the duration of this call
    def predict_map_and_overlay(self, frame_bgr: np.ndarray):
        """
        Returns:
            pred_map: HxW (uint8/int) with class indices in [0..C-1]
            overlay: HxWx3 RGB uint8 (blended color mask over original)
            rgb: HxWx3 RGB uint8 original frame (for AnnotatedImage base)
        """
        # OpenCV frames arrive as BGR; convert to RGB before building the PIL image
        rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        pil = Image.fromarray(rgb)
        pred = self.segment_image(pil, self.processor, self.model)
        # Convert the PIL RGBA overlay to the HxWx3 RGB uint8 array promised above
        overlay_rgb = np.array(create_segmentation_overlay(pred, id2label, label2color, pil, 0.5).convert("RGB"))
        return pred, overlay_rgb, rgb
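

# End-to-end smoke test sketch. "reef.jpg" is a placeholder path, not a file
# shipped with this Space; run this module directly to try the pipeline:
if __name__ == "__main__":
    model = CoralSegModel()
    frame_bgr = cv2.imread("reef.jpg")  # OpenCV loads images as BGR
    pred_map, overlay, rgb = model.predict_map_and_overlay(frame_bgr)
    print("classes present:", np.unique(pred_map))
    Image.fromarray(overlay).save("reef_overlay.png")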