Spaces:
Running
Running
| import torch | |
| import torch.nn.functional as F | |
| from typing import Dict, Tuple, Optional | |
| import network | |
class Predictor:
    """
    Wrapper for ScribblePrompt Unet model

    On construction the checkpoint path is validated, the UNet is built,
    its weights are loaded, the model is switched to eval mode and moved to
    CUDA when available (CPU otherwise).
    """

    def __init__(self, path: str, verbose: bool = False):
        """
        Args:
            path: location of the checkpoint file (``str`` or ``pathlib.Path``)
            verbose: when True, print the loaded checkpoint and the prompts
                received by ``predict``
        Raises:
            FileNotFoundError: if ``path`` does not exist
        """
        # Local import so this block does not depend on the file's import header.
        from pathlib import Path

        self.verbose = verbose

        # Accept plain strings as the annotation promises; the original called
        # ``path.exists()`` directly, which only works on Path objects.
        path = Path(path)
        if not path.exists():
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            raise FileNotFoundError(f"Checkpoint {path} does not exist")
        self.path = path

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_model()
        self.load()
        self.model.eval()
        self.to_device()

    def build_model(self):
        """
        Build the model
        """
        self.model = network.UNet(
            in_channels=5,
            out_channels=1,
            features=[192, 192, 192, 192],
        )

    def load(self):
        """
        Load the state of the model from a checkpoint file.
        """
        with self.path.open("rb") as f:
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            state = torch.load(f, map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        if self.verbose:
            print(
                f"Loaded checkpoint from {self.path} to {self.device}"
            )

    def to_device(self):
        """
        Move the model to cpu or gpu
        """
        # Keep ``self.device`` a torch.device (as set in __init__) rather than
        # replacing it with a bare string.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def predict(self, prompts: Dict[str,any], img_features: Optional[torch.Tensor] = None, multimask_mode: bool = False):
        """
        Make predictions!

        Args:
            prompts: dictionary holding ``img`` and, optionally, ``point_coords``,
                ``point_labels``, ``box``, ``scribble`` and ``mask_input``
            img_features: not used by this predictor
            multimask_mode: not used by this predictor
        Returns:
            mask (torch.Tensor): sigmoid of the model output, resized to the
                last two dimensions of ``prompts['img']`` and squeezed
                (H x W for a single image)
            img_features: always None for this model
            low_res_mask (torch.Tensor): raw model output (logits) on the cpu
        """
        if self.verbose:
            print("point_coords", prompts.get("point_coords", None))
            print("point_labels", prompts.get("point_labels", None))
            print("box", prompts.get("box", None))
            print("img", prompts.get("img").shape, prompts.get("img").min(), prompts.get("img").max())
            if prompts.get("scribble") is not None:
                print("scribble", prompts.get("scribble", None).shape, prompts.get("scribble").min(), prompts.get("scribble").max())

        # Remember the input resolution before the prompts are rescaled.
        original_shape = prompts.get('img').shape[-2:]

        # Rescale to 128 x 128
        prompts = rescale_inputs(prompts)

        # Prepare inputs for ScribblePrompt unet (1 x 5 x 128 x 128)
        x = prepare_inputs(prompts).float()

        with torch.no_grad():
            yhat = self.model(x.to(self.device)).cpu()

        mask = torch.sigmoid(yhat)

        # Resize for app resolution
        mask = F.interpolate(mask, size=original_shape, mode='bilinear').squeeze()

        # mask: H x W, yhat: 1 x 1 x H x W
        return mask, None, yhat
| # ----------------------------------------------------------------------------- | |
| # Prepare inputs | |
| # ----------------------------------------------------------------------------- | |
def rescale_inputs(inputs: Dict[str,any], res=128):
    """
    Rescale the image and its spatial prompts to a ``res`` x ``res`` grid.

    The caller's dictionary is NOT modified: a shallow copy is made and the
    rescaled entries are stored in the copy, so the return value must be used.
    When the image is already ``res`` x ``res`` every entry is passed through
    unchanged.

    Args:
        inputs: dictionary with ``img`` (last two dims are H, W) and optional
            ``scribble``, ``box`` ([x1, y1, x2, y2] values) and
            ``point_coords`` ([x, y] values)
        res: target side length
    Returns:
        dict: copy of ``inputs`` with ``img``/``scribble`` bilinearly resized
            and ``box``/``point_coords`` scaled and truncated to int
    """
    # Shallow copy so the caller's prompts are not silently overwritten.
    scaled = dict(inputs)

    h, w = scaled['img'].shape[-2:]
    if h == res and w == res:
        return scaled

    scaled['img'] = F.interpolate(scaled['img'], size=(res, res), mode='bilinear')

    if scaled.get('scribble') is not None:
        scaled['scribble'] = F.interpolate(scaled['scribble'], size=(res, res), mode='bilinear')

    x_scale = res / w
    y_scale = res / h

    if scaled.get("box") is not None:
        # View each box as two (x, y) corners so both can be scaled at once.
        corners = scaled["box"].clone().reshape(-1, 2, 2)
        corners[..., 0] = corners[..., 0] * x_scale
        corners[..., 1] = corners[..., 1] * y_scale
        scaled['box'] = corners.reshape(1, -1, 4).int()

    if scaled.get("point_coords") is not None:
        points = scaled["point_coords"].clone()
        points[..., 0] = points[..., 0] * x_scale
        points[..., 1] = points[..., 1] * y_scale
        scaled['point_coords'] = points.int()

    return scaled
def prepare_inputs(inputs: Dict[str,torch.Tensor], device = None) -> torch.Tensor:
    """
    Prepare inputs for ScribblePrompt Unet

    Concatenates, along the channel dimension and in this order: the image,
    the bounding-box mask, the two-channel click/scribble encoding and the
    previous-prediction channel. Missing prompts are replaced by zeros.
    Every channel is moved to ``device`` before concatenation so prompts that
    live on a different device than the image do not break ``torch.cat``.

    Args:
        inputs: dictionary with ``img`` and optional ``box``,
            ``point_coords`` (requires ``point_labels``), ``scribble`` and
            ``mask_input``
        device: target device; defaults to the device of ``inputs['img']``
    Returns:
        x (torch.Tensor): B x 5 x H x W
    """
    img = inputs['img']
    if device is None:
        device = img.device
    img = img.to(device)
    shape = tuple(img.shape[-2:])

    box = inputs.get("box")
    if box is not None:
        # Embed bounding box
        # Input: B x 1 x 4
        # Output: B x 1 x H x W
        box_embed = bbox_shaded(box, shape=shape, device=device)
    else:
        box_embed = torch.zeros(img.shape, device=device)

    point_coords = inputs.get("point_coords")
    if point_coords is not None:
        # Encode points
        # B x 2 x H x W
        # click_onehot allocates on the coordinates' device, so move explicitly.
        scribble_click_embed = click_onehot(point_coords, inputs['point_labels'], shape=shape).to(device)
    else:
        scribble_click_embed = torch.zeros((img.shape[0], 2) + shape, device=device)

    scribble = inputs.get("scribble")
    if scribble is not None:
        # Combine scribbles with click encoding
        # B x 2 x H x W
        scribble_click_embed = torch.clamp(scribble_click_embed + scribble.to(device), min=0.0, max=1.0)

    mask_input = inputs.get('mask_input')
    if mask_input is not None:
        # Previous prediction
        mask_input = mask_input.to(device)
    else:
        # Initialize empty channel for mask input
        mask_input = torch.zeros(img.shape, device=device)

    # B x 5 x H x W
    return torch.cat((img, box_embed, scribble_click_embed, mask_input), dim=-3)
| # ----------------------------------------------------------------------------- | |
| # Encode clicks and bounding boxes | |
| # ----------------------------------------------------------------------------- | |
def click_onehot(point_coords, point_labels, shape: Tuple[int,int] = (128,128), indexing='xy'):
    """
    Represent clicks as two HxW binary masks (one for positive clicks and one for negative)
    with 1 at the click locations and 0 otherwise

    Coordinates are truncated to integer indices, so float coordinate tensors
    are accepted as well as integer ones.

    Args:
        point_coords (torch.Tensor): BxNx2 tensor of xy coordinates
        point_labels (torch.Tensor): BxN tensor of labels (0 or 1)
        shape (tuple): output shape
        indexing (str): 'xy' treats the last axis of point_coords as
            (column, row); 'uv' treats it as (row, column)
    Returns:
        embed (torch.Tensor): Bx2xHxW tensor
    """
    assert indexing in ['xy','uv'], f"Invalid indexing: {indexing}"
    assert len(point_coords.shape) == 3, "point_coords must be BxNx2"
    assert point_coords.shape[-1] == 2, "point_coords must be BxNx2"
    assert point_labels.shape[-1] == point_coords.shape[1], "point_labels must be BxN"
    assert len(shape)==2, f"shape must be 2D: {shape}"

    device = point_coords.device
    batch_size, n_points = point_coords.shape[0], point_coords.shape[1]

    embed = torch.zeros((batch_size, 2) + tuple(shape), device=device)
    labels = point_labels.flatten().float().to(device)

    # One batch index per click: [0,...,0, 1,...,1, ...], each repeated n_points times.
    batch_idx = torch.arange(batch_size, device=device).repeat_interleave(n_points)

    # Index tensors must be integral; casting here lets float coordinates through.
    flat_coords = point_coords.reshape(-1, 2).long()
    first, second = flat_coords[:, 0], flat_coords[:, 1]

    if indexing == 'xy':
        rows, cols = second, first
    else:
        rows, cols = first, second

    # Channel 0 marks positive clicks, channel 1 marks negative clicks.
    embed[batch_idx, 0, rows, cols] = labels
    embed[batch_idx, 1, rows, cols] = 1.0 - labels

    return embed
def bbox_shaded(boxes, shape: Tuple[int,int] = (128,128), device='cpu'):
    """
    Represent bounding boxes as a binary mask with 1 inside boxes and 0 otherwise

    Corners may be given in any order. Coordinates are truncated to integers
    and negative values are clamped to 0, so a box that starts outside the
    image is clipped instead of being dropped by negative-index slicing.
    The max edge is exclusive (pixels y_min..y_max-1, x_min..x_max-1 are set).

    Args:
        boxes (torch.Tensor): Bx1x4 [x1, y1, x2, y2]
        shape (tuple): (H, W) of the output mask
        device: device on which the mask is allocated
    Returns:
        bbox_embed (torch.Tensor): Bx1xHxW according to shape
    """
    assert len(shape)==2, "shape must be 2D"

    # Integer corner values on the cpu so they can be used as plain slice bounds.
    corner_values = torch.as_tensor(boxes).to(torch.int64).cpu()
    batch_size = corner_values.shape[0]

    bbox_embed = torch.zeros((batch_size, 1) + tuple(shape), device=device, dtype=torch.float32)

    for i, image_boxes in enumerate(corner_values.tolist()):
        for x1, y1, x2, y2 in image_boxes:
            # Clamp at 0: a negative slice bound would wrap around to the far edge.
            x_min, x_max = max(0, min(x1, x2)), max(0, max(x1, x2))
            y_min, y_max = max(0, min(y1, y2)), max(0, max(y1, y2))
            bbox_embed[i, 0, y_min:y_max, x_min:x_max] = 1.0

    return bbox_embed