import numpy as np
import cv2
import torch
import torchvision.transforms.functional as TF
import sys as _sys
from keyword import iskeyword as _iskeyword
from operator import itemgetter as _itemgetter

from segment_anything import SamPredictor
from comfy import model_management
################################################################################
### namedtuple
################################################################################

try:
    from _collections import _tuplegetter
except ImportError:
    _tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc)
def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None):
    """Returns a new subclass of tuple with named fields.

    >>> Point = namedtuple('Point', ['x', 'y'])
    >>> Point.__doc__                   # docstring for the new class
    'Point(x, y)'
    >>> p = Point(11, y=22)             # instantiate with positional args or keywords
    >>> p[0] + p[1]                     # indexable like a plain tuple
    33
    >>> x, y = p                        # unpack like a regular tuple
    >>> x, y
    (11, 22)
    >>> p.x + p.y                       # fields also accessible by name
    33
    >>> d = p._asdict()                 # convert to a dictionary
    >>> d['x']
    11
    >>> Point(**d)                      # convert from a dictionary
    Point(x=11, y=22)
    >>> p._replace(x=100)               # _replace() is like str.replace() but targets named fields
    Point(x=100, y=22)

    """

    # Validate the field names.  At the user's option, either generate an error
    # message or automatically replace the field name with a valid name.
    if isinstance(field_names, str):
        field_names = field_names.replace(',', ' ').split()
    field_names = list(map(str, field_names))
    typename = _sys.intern(str(typename))

    if rename:
        seen = set()
        for index, name in enumerate(field_names):
            if (not name.isidentifier()
                or _iskeyword(name)
                or name.startswith('_')
                or name in seen):
                field_names[index] = f'_{index}'
            seen.add(name)

    for name in [typename] + field_names:
        if type(name) is not str:
            raise TypeError('Type names and field names must be strings')
        if not name.isidentifier():
            raise ValueError('Type names and field names must be valid '
                             f'identifiers: {name!r}')
        if _iskeyword(name):
            raise ValueError('Type names and field names cannot be a '
                             f'keyword: {name!r}')

    seen = set()
    for name in field_names:
        if name.startswith('_') and not rename:
            raise ValueError('Field names cannot start with an underscore: '
                             f'{name!r}')
        if name in seen:
            raise ValueError(f'Encountered duplicate field name: {name!r}')
        seen.add(name)

    field_defaults = {}
    if defaults is not None:
        defaults = tuple(defaults)
        if len(defaults) > len(field_names):
            raise TypeError('Got more default values than field names')
        field_defaults = dict(reversed(list(zip(reversed(field_names),
                                                reversed(defaults)))))

    # Variables used in the methods and docstrings
    field_names = tuple(map(_sys.intern, field_names))
    num_fields = len(field_names)
    arg_list = ', '.join(field_names)
    if num_fields == 1:
        arg_list += ','
    repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')'
    tuple_new = tuple.__new__
    _dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip

    # Create all the named tuple methods to be added to the class namespace
    namespace = {
        '_tuple_new': tuple_new,
        '__builtins__': {},
        '__name__': f'namedtuple_{typename}',
    }
    code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
    __new__ = eval(code, namespace)
    __new__.__name__ = '__new__'
    __new__.__doc__ = f'Create new instance of {typename}({arg_list})'
    if defaults is not None:
        __new__.__defaults__ = defaults
    @classmethod
    def _make(cls, iterable):
        result = tuple_new(cls, iterable)
        if _len(result) != num_fields:
            raise TypeError(f'Expected {num_fields} arguments, got {len(result)}')
        return result

    _make.__func__.__doc__ = (f'Make a new {typename} object from a sequence '
                              'or iterable')
    def _replace(self, /, **kwds):
        result = self._make(_map(kwds.pop, field_names, self))
        if kwds:
            raise ValueError(f'Got unexpected field names: {list(kwds)!r}')
        return result

    _replace.__doc__ = (f'Return a new {typename} object replacing specified '
                        'fields with new values')

    def __repr__(self):
        'Return a nicely formatted representation string'
        return self.__class__.__name__ + repr_fmt % self

    def _asdict(self):
        'Return a new dict which maps field names to their values.'
        return _dict(_zip(self._fields, self))

    def __getnewargs__(self):
        'Return self as a plain tuple.  Used by copy and pickle.'
        return _tuple(self)

    # Modify function metadata to help with introspection and debugging
    for method in (
        __new__,
        _make.__func__,
        _replace,
        __repr__,
        _asdict,
        __getnewargs__,
    ):
        method.__qualname__ = f'{typename}.{method.__name__}'

    # Build-up the class namespace dictionary
    # and use type() to build the result class
    class_namespace = {
        '__doc__': f'{typename}({arg_list})',
        '__slots__': (),
        '_fields': field_names,
        '_field_defaults': field_defaults,
        '__new__': __new__,
        '_make': _make,
        '_replace': _replace,
        '__repr__': __repr__,
        '_asdict': _asdict,
        '__getnewargs__': __getnewargs__,
        '__match_args__': field_names,
    }
    for index, name in enumerate(field_names):
        doc = _sys.intern(f'Alias for field number {index}')
        class_namespace[name] = _tuplegetter(index, doc)

    result = type(typename, (tuple,), class_namespace)

    # For pickling to work, the __module__ variable needs to be set to the frame
    # where the named tuple is created.  Bypass this step in environments where
    # sys._getframe is not defined (Jython for example) or sys._getframe is not
    # defined for arguments greater than 0 (IronPython), or where the user has
    # specified a particular module.
    if module is None:
        try:
            module = _sys._getframe(1).f_globals.get('__name__', '__main__')
        except (AttributeError, ValueError):
            pass
    if module is not None:
        result.__module__ = module

    return result
SEG = namedtuple("SEG",
                 ['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'],
                 defaults=[None])
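# Illustrative usage (hypothetical values; `defaults=[None]` only covers the
# final `control_net_wrapper` field, so the other six must be supplied):
#   >>> seg = SEG(cropped_image=None,
#   ...           cropped_mask=np.zeros((64, 64), dtype=np.float32),
#   ...           confidence=0.9, crop_region=[0, 0, 64, 64],
#   ...           bbox=[8, 8, 56, 56], label='person')
#   >>> seg.control_net_wrapper is None
#   True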
def crop_ndarray4(npimg, crop_region):
    # npimg is in NHWC layout (batch, height, width, channel);
    # crop_region is [x1, y1, x2, y2].
    x1, y1, x2, y2 = crop_region
    return npimg[:, y1:y2, x1:x2, :]

# Tensors share the same slicing semantics, so the tensor variant is an alias.
crop_tensor4 = crop_ndarray4
def crop_ndarray2(npimg, crop_region):
    # npimg is a 2D (height, width) array; crop_region is [x1, y1, x2, y2].
    x1, y1, x2, y2 = crop_region
    return npimg[y1:y2, x1:x2]

def crop_image(image, crop_region):
    return crop_tensor4(image, crop_region)
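# Illustrative crop (hypothetical shapes): an [x1, y1, x2, y2] region applied
# to a 4D NHWC batch.
#   >>> img = torch.zeros((1, 128, 128, 3))
#   >>> crop_image(img, [16, 32, 80, 96]).shape
#   torch.Size([1, 64, 64, 3])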
def normalize_region(limit, startp, size):
    # Clamp an interval of `size` starting at `startp` so it lies within [0, limit].
    if startp < 0:
        new_endp = min(limit, size)
        new_startp = 0
    elif startp + size > limit:
        new_startp = max(0, limit - size)
        new_endp = limit
    else:
        new_startp = startp
        new_endp = min(limit, startp + size)

    return int(new_startp), int(new_endp)
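# Example: a 64-wide window starting at -10 is shifted back inside a 128-wide axis,
# and one overrunning the end is pulled back.
#   >>> normalize_region(128, -10, 64)
#   (0, 64)
#   >>> normalize_region(128, 100, 64)
#   (64, 128)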
def make_crop_region(w, h, bbox, crop_factor, crop_min_size=None):
    # Expand the bbox by crop_factor around its center, then clamp to the image.
    x1, y1, x2, y2 = bbox

    bbox_w = x2 - x1
    bbox_h = y2 - y1

    crop_w = bbox_w * crop_factor
    crop_h = bbox_h * crop_factor

    if crop_min_size is not None:
        crop_w = max(crop_min_size, crop_w)
        crop_h = max(crop_min_size, crop_h)

    kernel_x = x1 + bbox_w / 2
    kernel_y = y1 + bbox_h / 2

    new_x1 = int(kernel_x - crop_w / 2)
    new_y1 = int(kernel_y - crop_h / 2)

    # make sure position in (w,h)
    new_x1, new_x2 = normalize_region(w, new_x1, crop_w)
    new_y1, new_y2 = normalize_region(h, new_y1, crop_h)

    return [new_x1, new_y1, new_x2, new_y2]
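# Example: doubling a 40x40 bbox centered at (60, 60) inside a 200x200 image.
#   >>> make_crop_region(200, 200, [40, 40, 80, 80], 2.0)
#   [20, 20, 100, 100]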
def create_segmasks(results):
    # Only indices 1 (bboxes), 2 (segmentation masks) and 3 (confidences) of the
    # upstream detector output are used here.
    bboxs = results[1]
    segms = results[2]
    confidence = results[3]

    results = []
    for i in range(len(segms)):
        item = (bboxs[i], segms[i].astype(np.float32), confidence[i])
        results.append(item)
    return results
def dilate_masks(segmasks, dilation_factor, iter=1):
    if dilation_factor == 0:
        return segmasks

    dilated_masks = []
    kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
    kernel = cv2.UMat(kernel)

    for i in range(len(segmasks)):
        cv2_mask = segmasks[i][1]
        cv2_mask = cv2.UMat(cv2_mask)

        # A positive factor dilates (grows) the mask; a negative factor erodes it.
        # Note: iterations must be passed by keyword, since the third positional
        # argument of cv2.dilate/cv2.erode is the output array, not the count.
        if dilation_factor > 0:
            dilated_mask = cv2.dilate(cv2_mask, kernel, iterations=iter)
        else:
            dilated_mask = cv2.erode(cv2_mask, kernel, iterations=iter)

        dilated_mask = dilated_mask.get()
        item = (segmasks[i][0], dilated_mask, segmasks[i][2])
        dilated_masks.append(item)

    return dilated_masks
def is_same_device(a, b):
    a_device = torch.device(a) if isinstance(a, str) else a
    b_device = torch.device(b) if isinstance(b, str) else b
    return a_device.type == b_device.type and a_device.index == b_device.index
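# Example: string specs and torch.device objects compare interchangeably.
#   >>> is_same_device('cuda:0', torch.device('cuda', 0))
#   True
#   >>> is_same_device('cpu', 'cuda:0')
#   False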
class SafeToGPU:
    def __init__(self, size):
        self.size = size

    def to_device(self, obj, device):
        if is_same_device(device, 'cpu'):
            obj.to(device)
        else:
            if is_same_device(obj.device, 'cpu'):  # cpu to gpu
                # Ask for ~30% headroom over the model size before moving.
                model_management.free_memory(self.size * 1.3, device)
                if model_management.get_free_memory(device) > self.size * 1.3:
                    try:
                        obj.to(device)
                    except Exception:
                        print(f"WARN: The model was not moved to '{device}' due to insufficient memory. [1]")
                else:
                    print(f"WARN: The model was not moved to '{device}' due to insufficient memory. [2]")
def center_of_bbox(bbox):
    w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
    return bbox[0] + w / 2, bbox[1] + h / 2
def sam_predict(predictor, points, plabs, bbox, threshold):
    point_coords = None if not points else np.array(points)
    point_labels = None if not plabs else np.array(plabs)

    box = np.array([bbox]) if bbox is not None else None

    cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box)

    total_masks = []

    selected = False
    max_score = 0
    max_mask = None
    for idx in range(len(scores)):
        if scores[idx] > max_score:
            max_score = scores[idx]
            max_mask = cur_masks[idx]

        if scores[idx] >= threshold:
            selected = True
            total_masks.append(cur_masks[idx])

    # If no mask cleared the threshold, fall back to the single best-scoring one.
    if not selected and max_mask is not None:
        total_masks.append(max_mask)

    return total_masks
def make_2d_mask(mask):
    if len(mask.shape) == 4:
        return mask.squeeze(0).squeeze(0)
    elif len(mask.shape) == 3:
        return mask.squeeze(0)
    return mask
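# Example: batch/channel singleton dims are stripped down to (H, W).
#   >>> make_2d_mask(torch.zeros((1, 1, 32, 32))).shape
#   torch.Size([32, 32])
#   >>> make_2d_mask(torch.zeros((1, 32, 32))).shape
#   torch.Size([32, 32])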
def gen_detection_hints_from_mask_area(x, y, mask, threshold, use_negative):
    mask = make_2d_mask(mask)

    points = []
    plabs = []

    # minimum sampling step >= 3
    y_step = max(3, int(mask.shape[0] / 20))
    x_step = max(3, int(mask.shape[1] / 20))

    for i in range(0, len(mask), y_step):
        for j in range(0, len(mask[i]), x_step):
            if mask[i][j] > threshold:
                points.append((x + j, y + i))
                plabs.append(1)
            elif use_negative and mask[i][j] == 0:
                points.append((x + j, y + i))
                plabs.append(0)

    return points, plabs
def gen_negative_hints(w, h, x1, y1, x2, y2):
    npoints = []
    nplabs = []

    # minimum sampling step >= 3
    # (rows step with height, columns step with width)
    y_step = max(3, int(h / 20))
    x_step = max(3, int(w / 20))

    for i in range(10, h - 10, y_step):
        for j in range(10, w - 10, x_step):
            # Sample background points outside the (padded) bbox.
            if not (x1 - 10 <= j <= x2 + 10 and y1 - 10 <= i <= y2 + 10):
                npoints.append((j, i))
                nplabs.append(0)

    return npoints, nplabs
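# Example sketch (hypothetical 200x100 image with a bbox near its center):
# the returned points are (x, y) pairs labeled 0, i.e. background hints for SAM.
#   >>> npoints, nplabs = gen_negative_hints(200, 100, 80, 30, 120, 70)
#   >>> set(nplabs)
#   {0}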
def generate_detection_hints(image, seg, center, detection_hint, dilated_bbox, mask_hint_threshold, use_small_negative,
                             mask_hint_use_negative):
    [x1, y1, x2, y2] = dilated_bbox

    points = []
    plabs = []

    if detection_hint == "center-1":
        points.append(center)
        plabs = [1]  # 1 = foreground point, 0 = background point

    elif detection_hint == "horizontal-2":
        gap = (x2 - x1) / 3
        points.append((x1 + gap, center[1]))
        points.append((x1 + gap * 2, center[1]))
        plabs = [1, 1]

    elif detection_hint == "vertical-2":
        gap = (y2 - y1) / 3
        points.append((center[0], y1 + gap))
        points.append((center[0], y1 + gap * 2))
        plabs = [1, 1]

    elif detection_hint == "rect-4":
        x_gap = (x2 - x1) / 3
        y_gap = (y2 - y1) / 3
        points.append((x1 + x_gap, center[1]))
        points.append((x1 + x_gap * 2, center[1]))
        points.append((center[0], y1 + y_gap))
        points.append((center[0], y1 + y_gap * 2))
        plabs = [1, 1, 1, 1]

    elif detection_hint == "diamond-4":
        x_gap = (x2 - x1) / 3
        y_gap = (y2 - y1) / 3
        points.append((x1 + x_gap, y1 + y_gap))
        points.append((x1 + x_gap * 2, y1 + y_gap))
        points.append((x1 + x_gap, y1 + y_gap * 2))
        points.append((x1 + x_gap * 2, y1 + y_gap * 2))
        plabs = [1, 1, 1, 1]

    elif detection_hint == "mask-point-bbox":
        center = center_of_bbox(seg.bbox)
        points.append(center)
        plabs = [1]

    elif detection_hint == "mask-area":
        points, plabs = gen_detection_hints_from_mask_area(seg.crop_region[0], seg.crop_region[1],
                                                           seg.cropped_mask,
                                                           mask_hint_threshold, use_small_negative)

    if mask_hint_use_negative == "Outter":
        # image is HWC, so shape[1] is the width and shape[0] is the height.
        npoints, nplabs = gen_negative_hints(image.shape[1], image.shape[0],
                                             seg.crop_region[0], seg.crop_region[1],
                                             seg.crop_region[2], seg.crop_region[3])

        points += npoints
        plabs += nplabs

    return points, plabs
def combine_masks2(masks):
    if len(masks) == 0:
        return None

    initial_cv2_mask = np.array(masks[0]).astype(np.uint8)
    combined_cv2_mask = initial_cv2_mask

    for i in range(1, len(masks)):
        cv2_mask = np.array(masks[i]).astype(np.uint8)

        if combined_cv2_mask.shape == cv2_mask.shape:
            combined_cv2_mask = cv2.bitwise_or(combined_cv2_mask, cv2_mask)
        else:
            # Skip masks whose shape does not match the accumulated mask.
            pass

    return torch.from_numpy(combined_cv2_mask)
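# Example: two binary masks of the same shape are OR-combined into one tensor.
#   >>> a = np.zeros((8, 8), dtype=np.uint8); a[:4] = 1
#   >>> b = np.zeros((8, 8), dtype=np.uint8); b[4:] = 1
#   >>> combine_masks2([a, b]).sum().item()
#   64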
def dilate_mask(mask, dilation_factor, iter=1):
    if dilation_factor == 0:
        return make_2d_mask(mask)

    mask = make_2d_mask(mask)

    kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)

    mask = cv2.UMat(mask)
    kernel = cv2.UMat(kernel)

    # A positive factor dilates the mask; a negative factor erodes it.
    if dilation_factor > 0:
        result = cv2.dilate(mask, kernel, iterations=iter)
    else:
        result = cv2.erode(mask, kernel, iterations=iter)

    return result.get()
def convert_and_stack_masks(masks):
    if len(masks) == 0:
        return None

    mask_tensors = []
    for mask in masks:
        mask_array = np.array(mask, dtype=np.uint8)
        mask_tensor = torch.from_numpy(mask_array)
        mask_tensors.append(mask_tensor)

    stacked_masks = torch.stack(mask_tensors, dim=0)
    stacked_masks = stacked_masks.unsqueeze(1)

    return stacked_masks
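# Example: three (H, W) masks become a (3, 1, H, W) uint8 tensor.
#   >>> ms = [np.ones((16, 16)) for _ in range(3)]
#   >>> convert_and_stack_masks(ms).shape
#   torch.Size([3, 1, 16, 16])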
def merge_and_stack_masks(stacked_masks, group_size):
    if stacked_masks is None:
        return None

    num_masks = stacked_masks.size(0)
    merged_masks = []

    # OR-merge masks in groups of `group_size` along the batch dimension.
    for i in range(0, num_masks, group_size):
        subset_masks = stacked_masks[i:i + group_size]
        merged_mask = torch.any(subset_masks, dim=0)
        merged_masks.append(merged_mask)

    if len(merged_masks) > 0:
        merged_masks = torch.stack(merged_masks, dim=0)

    return merged_masks
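# Example: seven stacked masks merged in groups of three yield ceil(7/3) = 3 masks.
#   >>> stacked = convert_and_stack_masks([np.ones((8, 8))] * 7)
#   >>> merge_and_stack_masks(stacked, group_size=3).shape
#   torch.Size([3, 1, 8, 8])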
def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
                            threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
    if sam_model.is_auto_mode:
        device = model_management.get_torch_device()
        sam_model.safe_to.to_device(sam_model, device=device)

    try:
        predictor = SamPredictor(sam_model)
        image = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
        predictor.set_image(image, "RGB")

        total_masks = []

        use_small_negative = mask_hint_use_negative == "Small"

        # seg_shape = segs[0]
        segs = segs[1]
        if detection_hint == "mask-points":
            points = []
            plabs = []

            for i in range(len(segs)):
                bbox = segs[i].bbox
                center = center_of_bbox(bbox)
                points.append(center)

                # small point is background, big point is foreground
                if use_small_negative and bbox[2] - bbox[0] < 10:
                    plabs.append(0)
                else:
                    plabs.append(1)

            detected_masks = sam_predict(predictor, points, plabs, None, threshold)
            total_masks += detected_masks
        else:
            for i in range(len(segs)):
                bbox = segs[i].bbox
                center = center_of_bbox(bbox)
                x1 = max(bbox[0] - bbox_expansion, 0)
                y1 = max(bbox[1] - bbox_expansion, 0)
                x2 = min(bbox[2] + bbox_expansion, image.shape[1])
                y2 = min(bbox[3] + bbox_expansion, image.shape[0])

                dilated_bbox = [x1, y1, x2, y2]

                points, plabs = generate_detection_hints(image, segs[i], center, detection_hint, dilated_bbox,
                                                         mask_hint_threshold, use_small_negative,
                                                         mask_hint_use_negative)

                detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold)
                total_masks += detected_masks

        # merge every collected mask
        mask = combine_masks2(total_masks)
    finally:
        if sam_model.is_auto_mode:
            sam_model.cpu()

    mask_working_device = torch.device("cpu")

    if mask is not None:
        mask = mask.float()
        mask = dilate_mask(mask.cpu().numpy(), dilation)
        mask = torch.from_numpy(mask)
        mask = mask.to(device=mask_working_device)
    else:
        # Extract height and width from the HWC image
        height, width, _ = image.shape
        mask = torch.zeros(
            (height, width), dtype=torch.float32, device=mask_working_device
        )  # empty mask

    stacked_masks = convert_and_stack_masks(total_masks)

    return (mask, merge_and_stack_masks(stacked_masks, group_size=3))
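# Usage sketch (hedged: `sam_model` must be a SAM model loaded elsewhere, and
# `segs` the (shape, list-of-SEG) pair produced by an upstream detector):
#   >>> combined, per_group = make_sam_mask_segmented(
#   ...     sam_model, segs, image, detection_hint="center-1", dilation=0,
#   ...     threshold=0.93, bbox_expansion=0, mask_hint_threshold=0.7,
#   ...     mask_hint_use_negative="False")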
def tensor2mask(t: torch.Tensor) -> torch.Tensor:
    size = t.size()
    if len(size) < 4:
        return t

    if size[3] == 1:
        return t[:, :, :, 0]
    elif size[3] == 4:
        # Prefer the alpha channel as the mask; if alpha is uniformly 1,
        # fall through to the RGB grayscale behavior below.
        if torch.min(t[:, :, :, 3]).item() != 1.:
            return t[:, :, :, 3]

    return TF.rgb_to_grayscale(tensor2rgb(t).permute(0, 3, 1, 2), num_output_channels=1)[:, 0, :, :]
def tensor2rgb(t: torch.Tensor) -> torch.Tensor:
    size = t.size()
    if len(size) < 4:
        return t.unsqueeze(3).repeat(1, 1, 1, 3)
    if size[3] == 1:
        return t.repeat(1, 1, 1, 3)
    elif size[3] == 4:
        return t[:, :, :, :3]
    else:
        return t
def tensor2rgba(t: torch.Tensor) -> torch.Tensor:
    size = t.size()
    if len(size) < 4:
        return t.unsqueeze(3).repeat(1, 1, 1, 4)
    elif size[3] == 1:
        return t.repeat(1, 1, 1, 4)
    elif size[3] == 3:
        # Append an opaque alpha channel, matching the input's dtype and device.
        alpha_tensor = torch.ones((size[0], size[1], size[2], 1), dtype=t.dtype, device=t.device)
        return torch.cat((t, alpha_tensor), dim=3)
    else:
        return t
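# Example round-trips between channel layouts (hypothetical 4D NHWC tensors):
#   >>> rgb = torch.rand((1, 64, 64, 3))
#   >>> tensor2rgba(rgb).shape
#   torch.Size([1, 64, 64, 4])
#   >>> tensor2mask(tensor2rgba(rgb)).shape   # uniform alpha falls back to grayscale
#   torch.Size([1, 64, 64])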