diff --git a/README.md b/README.md
index 758e3be788552de668ae0964e4f8812a45c8f6cb..c02cf011506a1b534ec255d540149c881bdf80ce 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,14 @@
---
-title: FluoGen
-emoji: π
-colorFrom: indigo
-colorTo: pink
+title: FluoGen Demo
+emoji: π
+colorFrom: red
+colorTo: red
sdk: gradio
sdk_version: 5.49.1
app_file: app.py
pinned: false
license: mit
-short_description: TMP
+short_description: 'Demo space of FluoGen: An Open-Source Generative Foundation '
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda41e1a96f67f80f8dd4ebb6e5f742d9c9d69d2
--- /dev/null
+++ b/app.py
@@ -0,0 +1,664 @@
+import torch
+import numpy as np
+import gradio as gr
+from PIL import Image
+import os
+import json
+import glob
+import random
+import tifffile
+import re
+import imageio
+from torchvision import transforms, models
+import accelerate
+import shutil
+import time
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import OrderedDict
+
+
+# --- Imports from both scripts ---
+from diffusers import DDPMScheduler, DDIMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer
+from accelerate.state import AcceleratorState
+from transformers.utils import ContextManagers
+
+# --- Custom Model Imports ---
+from models.pipeline_ddpm_text_encoder import DDPMPipeline
+from models.unet_2d import UNet2DModel
+from models.controlnet import ControlNetModel
+from models.unet_2d_condition import UNet2DConditionModel
+from models.pipeline_controlnet import DDPMControlnetPipeline
+
+# --- New Import for Segmentation ---
+from cellpose import models as cellpose_models
+from cellpose import plot as cellpose_plot
+from huggingface_hub import snapshot_download
+
+# --- 0. Configuration & Constants ---
+# --- General ---
+MODEL_TITLE = "🔬 FluoGen: AI-Powered Fluorescence Microscopy Suite"
+MODEL_DESCRIPTION = """
+**Paper**: *Generative AI empowering fluorescence microscopy imaging and analysis*
+
+Select a task from the tabs below: generate new images from text, enhance existing images using super-resolution, denoise them, generate training data from segmentation masks, perform cell segmentation, or classify cell images.
+"""
+DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
+LOGO_PATH = "utils/logo_0801_2.png"
+
+# --- Global switch to control example saving ---
+SAVE_EXAMPLES = False
+
+# --- Base directory for all models ---
+# NOTE: All model paths below are relative to MODELS_ROOT_DIR, which points at a local
+# snapshot of the checkpoint collection downloaded from the Hugging Face Hub at startup.
+REPO_ID = "rayquaza384mega/FluoGen-demo-test-ckpts"
+MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID)  # local folder containing all checkpoints
+
+
+# --- Tab 1: Mask-to-Image Config (Formerly Segmentation-to-Image) ---
+M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
+M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
+
+# --- Tab 2: Text-to-Image Config ---
+T2I_PROMPTS = ["F-actin of COS-7", "ER of COS-7", "Mitochondria of BPAE", "Nucleus of BPAE", "ER of HeLa", "Microtubules of HeLa"]
+T2I_EXAMPLE_IMG_DIR = "example_images"
+T2I_CHECKPOINT = 285000
+T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
+T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-{T2I_CHECKPOINT}"
+
+# --- Tab 3, 4: ControlNet-based Tasks Config ---
+CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
+CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
+
+# Super-Resolution Models
+SR_CONTROLNET_MODELS = {
+ "Checkpoint CCPs": f"{MODELS_ROOT_DIR}/ControlNet_SR/CCPs/checkpoint-100000",
+ "Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000",
+}
+SR_EXAMPLE_IMG_DIR = "example_images_sr"
+
+# Denoising Model
+DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000"
+DN_PROMPT_RULES = {'MICE': 'mouse brain tissues', 'FISH': 'zebrafish embryos', 'BPAE_B': 'nucleus of BPAE', 'BPAE_R': 'mitochondria of BPAE', 'BPAE_G': 'F-actin of BPAE'}
+DN_EXAMPLE_IMG_DIR = "example_images_dn"
+
+# --- Tab 5: Cell Segmentation Config ---
+SEG_MODELS = {
+ "DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100",
+ "DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300",
+ "DSB Model": f"{MODELS_ROOT_DIR}/Cellpose/DSB_baseline/CP_dsb_baseline_ratio_1_epoch_0135",
+ "DSB Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DSB_FluoGen/CP_dsb_ten_epoch_0135",
+}
+SEG_EXAMPLE_IMG_DIR = "example_images_seg"
+
+
+# --- Tab 6: Classification Config ---
+CLS_MODEL_PATHS = OrderedDict({
+ "5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_re",
+ #"10shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_10_shot_re",
+ #"15shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_15_shot_re",
+ #"20shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_20_shot_re",
+ "5shot+FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug_re",
+ #"10shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_10_shot_aug_re",
+ #"15shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_15_shot_aug_re",
+ #"20shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_20_shot_aug_re",
+})
+CLS_CLASS_NAMES = ['dap', 'erdak', 'giant', 'gpp130', 'h4b4', 'mc151', 'nucle', 'phal', 'tfr', 'tubul']
+CLS_EXAMPLE_IMG_DIR = "example_images_cls"
+
+
+# --- Helper Functions ---
+def sanitize_prompt_for_filename(prompt):
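+    """Map a prompt such as "F-actin of COS-7" to a safe example filename, e.g. "f-actin_cos-7.png"."""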
+    prompt = prompt.lower()
+    prompt = re.sub(r'\s+of\s+', '_', prompt)
+    prompt = re.sub(r'[^a-z0-9-_]+', '', prompt)
+ return f"{prompt}.png"
+
+def min_max_norm(x):
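+    """Rescale an array to the [0, 1] range; returns zeros for (near-)constant inputs."""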
+    x = x.astype(np.float32)
+    min_val, max_val = np.min(x), np.max(x)
+ if max_val - min_val < 1e-8: return np.zeros_like(x)
+ return (x - min_val) / (max_val - min_val)
+
+def numpy_to_pil(image_np, target_mode="RGB"):
+ # If the input is already a PIL image, just ensure mode and return
+ if isinstance(image_np, Image.Image):
+ if target_mode == "RGB" and image_np.mode != "RGB": return image_np.convert("RGB")
+ if target_mode == "L" and image_np.mode != "L": return image_np.convert("L")
+ return image_np
+
+ # Handle numpy array conversion
+    squeezed_np = np.squeeze(image_np)
+ if squeezed_np.dtype == np.uint8:
+ # If it's already uint8, it's likely in the 0-255 range.
+ image_8bit = squeezed_np
+ else:
+ # Normalize and scale for other types
+ normalized_np = min_max_norm(squeezed_np)
+ image_8bit = (normalized_np * 255).astype(np.uint8)
+
+ pil_image = Image.fromarray(image_8bit)
+
+ if target_mode == "RGB" and pil_image.mode != "RGB": pil_image = pil_image.convert("RGB")
+ elif target_mode == "L" and pil_image.mode != "L": pil_image = pil_image.convert("L")
+ return pil_image
+
+# --- 1. Model Loading ---
+print("--- Initializing FluoGen Application ---")
+t2i_pipe, controlnet_pipe = None, None
+try:
+ print("Loading Text-to-Image model...")
+ t2i_noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=True, timestep_spacing="trailing")
+ t2i_unet = UNet2DModel.from_pretrained(T2I_UNET_PATH, subfolder="unet")
+ t2i_text_encoder = CLIPTextModel.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="text_encoder").to(DEVICE)
+ t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer")
+ t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler, text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer)
+ t2i_pipe.to(DEVICE)
+ print("β Text-to-Image model loaded successfully!")
+except Exception as e:
+ print(f"!!!!!! FATAL: Text-to-Image Model Loading Failed !!!!!!\nError: {e}")
+
+try:
+ print("Loading shared ControlNet pipeline components...")
+ controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
+ default_controlnet_path = M2I_CONTROLNET_PATH # Start with the first tab's model
+ controlnet_controlnet = ControlNetModel.from_pretrained(default_controlnet_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
+ controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
+ controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
+ with ContextManagers([]):
+ controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE)
+ controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet, scheduler=controlnet_scheduler, text_encoder=controlnet_text_encoder, tokenizer=controlnet_tokenizer)
+ controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE)
+ controlnet_pipe.current_controlnet_path = default_controlnet_path
+ print("β Shared ControlNet pipeline loaded successfully!")
+except Exception as e:
+ print(f"!!!!!! FATAL: ControlNet Pipeline Loading Failed !!!!!!\nError: {e}")
+
+# --- 2. Core Logic Functions ---
+def swap_controlnet(pipe, target_path):
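+    """Reload the pipeline's ControlNet from `target_path` only when it differs from the checkpoint currently loaded."""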
+ if os.path.normpath(getattr(pipe, 'current_controlnet_path', '')) != os.path.normpath(target_path):
+ print(f"Swapping ControlNet model to: {target_path}")
+ try:
+ pipe.controlnet = ControlNetModel.from_pretrained(target_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
+ pipe.current_controlnet_path = target_path
+ except Exception as e:
+ raise gr.Error(f"Failed to load ControlNet model '{target_path}'. Error: {e}")
+ return pipe
+
+def generate_t2i(prompt, num_inference_steps):
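+    """Run the text-to-image DDPM pipeline for the given prompt and return the result as a PIL image."""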
+ if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
+ print(f"\nTask started... | Prompt: '{prompt}' | Steps: {num_inference_steps}")
+ image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
+ generated_image = numpy_to_pil(image_np)
+ print("β Image generated")
+ if SAVE_EXAMPLES:
+ example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
+ if not os.path.exists(example_filepath):
+ generated_image.save(example_filepath); print(f"β New T2I example saved: {example_filepath}")
+ return generated_image
+
+def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed):
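+    """Generate `num_images` synthetic fluorescence images conditioned on an uploaded segmentation mask (TIF) and a cell-type prompt."""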
+ if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
+ if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask TIF file.")
+ if not cell_type or not cell_type.strip(): raise gr.Error("Please enter a cell type.")
+
+ if SAVE_EXAMPLES:
+ input_path = mask_file_obj.name
+ filename = os.path.basename(input_path)
+ dest_path = os.path.join(M2I_EXAMPLE_IMG_DIR, filename)
+ if not os.path.exists(dest_path):
+ shutil.copy(input_path, dest_path)
+ print(f"β New Mask-to-Image example saved: {dest_path}")
+
+ pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
+ try:
+ mask_np = tifffile.imread(mask_file_obj.name)
+ except Exception as e:
+ raise gr.Error(f"Failed to read the TIF file. Error: {e}")
+
+ input_display_image = numpy_to_pil(mask_np, "L")
+ mask_normalized = min_max_norm(mask_np)
+ image_tensor = torch.from_numpy(mask_normalized.astype(np.float32))
+ image_tensor = image_tensor.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
+ image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
+
+ prompt = f"nuclei of {cell_type.strip()}"
+ print(f"\nTask started... | Task: Mask-to-Image | Prompt: '{prompt}' | Steps: {steps} | Images: {num_images}")
+
+ generated_images_pil = []
+ for i in range(int(num_images)):
+ current_seed = int(seed) + i
+ generator = torch.Generator(device=DEVICE).manual_seed(current_seed)
+ with torch.autocast("cuda"):
+ output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
+ pil_image = numpy_to_pil(output_np)
+ generated_images_pil.append(pil_image)
+ print(f"β Generated image {i+1}/{int(num_images)}")
+
+ return input_display_image, generated_images_pil
+
+def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed):
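+    """Super-resolve a 9-channel low-resolution TIF stack with the selected ControlNet checkpoint; returns the input's average projection and the super-resolved image."""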
+ if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
+ if low_res_file_obj is None: raise gr.Error("Please upload a low-resolution TIF file.")
+
+ if SAVE_EXAMPLES:
+ input_path = low_res_file_obj.name
+ filename = os.path.basename(input_path)
+ dest_path = os.path.join(SR_EXAMPLE_IMG_DIR, filename)
+ if not os.path.exists(dest_path):
+ shutil.copy(input_path, dest_path)
+ print(f"β New SR example saved: {dest_path}")
+
+ target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
+ if not target_path: raise gr.Error(f"ControlNet model '{controlnet_model_name}' not found.")
+
+ pipe = swap_controlnet(controlnet_pipe, target_path)
+ try:
+ image_stack_np = tifffile.imread(low_res_file_obj.name)
+ except Exception as e:
+ raise gr.Error(f"Failed to read the TIF file. Error: {e}")
+
+ if image_stack_np.ndim != 3 or image_stack_np.shape[-3] != 9:
+ raise gr.Error(f"Invalid TIF shape. Expected 9 channels (shape 9, H, W), but got {image_stack_np.shape}.")
+
+ average_projection_np = np.mean(image_stack_np, axis=0)
+ input_display_image = numpy_to_pil(average_projection_np, "L")
+
+ image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
+ image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
+
+ generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
+ with torch.autocast("cuda"):
+ output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
+
+ return input_display_image, numpy_to_pil(output_np)
+
+def run_denoising(noisy_image_np, image_type, steps, seed):
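+    """Denoise a single-channel image with the denoising ControlNet, using the selected image type to build the prompt."""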
+ if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
+ if noisy_image_np is None: raise gr.Error("Please upload a noisy image.")
+
+ if SAVE_EXAMPLES:
+ timestamp = int(time.time() * 1000)
+ filename = f"dn_input_{image_type}_{timestamp}.tif"
+ dest_path = os.path.join(DN_EXAMPLE_IMG_DIR, filename)
+ try:
+ img_to_save = noisy_image_np.astype(np.uint8) if noisy_image_np.dtype != np.uint8 else noisy_image_np
+ tifffile.imwrite(dest_path, img_to_save)
+ print(f"β New Denoising example saved: {dest_path}")
+ except Exception as e:
+ print(f"β Failed to save denoising example: {e}")
+
+ pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
+ prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
+ print(f"\nTask started... | Task: Denoising | Prompt: '{prompt}' | Steps: {steps}")
+
+ image_tensor = torch.from_numpy(noisy_image_np.astype(np.float32) / 255.0)
+ image_tensor = image_tensor.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
+ image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
+
+ generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
+ with torch.autocast("cuda"):
+ output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
+
+ return numpy_to_pil(noisy_image_np, "L"), numpy_to_pil(output_np)
+
+def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
+ """
+ Runs cell segmentation and creates a dark red overlay.
+ """
+ if input_image_np is None:
+ raise gr.Error("Please upload an image to segment.")
+
+ model_path = SEG_MODELS.get(model_name)
+ if not model_path:
+ raise gr.Error(f"Segmentation model '{model_name}' not found.")
+
+ if not os.path.exists(model_path):
+ raise gr.Error(f"Model file not found at path: {model_path}. Please check the configuration.")
+
+ print(f"\nTask started... | Task: Cell Segmentation | Model: '{model_name}'")
+
+ # 1. Load Cellpose Model
+ try:
+ use_gpu = torch.cuda.is_available()
+ model = cellpose_models.CellposeModel(gpu=use_gpu, pretrained_model=model_path)
+ except Exception as e:
+ raise gr.Error(f"Failed to load Cellpose model. Error: {e}")
+
+ diameter_to_use = model.diam_labels if diameter == 0 else float(diameter)
+ print(f"Using Diameter: {diameter_to_use}")
+
+ # 2. Run model evaluation
+ try:
+ masks, _, _ = model.eval(
+ [input_image_np],
+ channels=[0, 0],
+ diameter=diameter_to_use,
+ flow_threshold=flow_threshold,
+ cellprob_threshold=cellprob_threshold
+ )
+ mask_output = masks[0]
+ except Exception as e:
+ raise gr.Error(f"Cellpose model evaluation failed. Error: {e}")
+
+ # 3. Create custom dark red overlay
+ # Ensure input image is uint8 and 3-channel for blending
+ original_rgb = numpy_to_pil(input_image_np, "RGB")
+ original_rgb_np = np.array(original_rgb)
+
+ # Create a blank layer for the red mask
+ red_mask_layer = np.zeros_like(original_rgb_np)
+ dark_red_color = [139, 0, 0]
+
+ # Apply the red color where the mask is present
+ is_mask_pixels = mask_output > 0
+ red_mask_layer[is_mask_pixels] = dark_red_color
+
+ # Blend the original image with the red mask layer
+ alpha = 0.4 # Opacity of the mask
+ blended_image_np = ((1 - alpha) * original_rgb_np + alpha * red_mask_layer).astype(np.uint8)
+
+ # 4. Save example if enabled
+ if SAVE_EXAMPLES:
+ timestamp = int(time.time() * 1000)
+ filename = f"seg_input_{timestamp}.tif"
+ dest_path = os.path.join(SEG_EXAMPLE_IMG_DIR, filename)
+ try:
+ img_to_save = input_image_np.astype(np.uint8) if input_image_np.dtype != np.uint8 else input_image_np
+ tifffile.imwrite(dest_path, img_to_save)
+ print(f"β New Segmentation example saved: {dest_path}")
+ except Exception as e:
+ print(f"β Failed to save segmentation example: {e}")
+
+ print("β Segmentation complete")
+
+ return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended_image_np, "RGB")
+
+def run_classification(input_image_np, model_name):
+ """
+ Runs classification on a single image using a pre-trained ResNet50 model.
+ """
+ if input_image_np is None:
+ raise gr.Error("Please upload an image to classify.")
+
+ model_dir = CLS_MODEL_PATHS.get(model_name)
+ if not model_dir:
+ raise gr.Error(f"Classification model '{model_name}' not found.")
+
+ model_path = os.path.join(model_dir, "best_resnet50.pth")
+ if not os.path.exists(model_path):
+ raise gr.Error(f"Model file not found at {model_path}. Please check the configuration.")
+
+ print(f"\nTask started... | Task: Classification | Model: '{model_name}'")
+
+ # 1. Load Model
+ try:
+ model = models.resnet50(weights=None)
+ num_features = model.fc.in_features
+ model.fc = nn.Linear(num_features, len(CLS_CLASS_NAMES))
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+ model.to(DEVICE)
+ model.eval()
+ except Exception as e:
+ raise gr.Error(f"Failed to load classification model. Error: {e}")
+
+ # 2. Preprocess Image
+ # Grayscale numpy -> RGB PIL -> transform -> tensor
+ input_pil = numpy_to_pil(input_image_np, "RGB")
+
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # ResNet needs 3-channel norm
+ ])
+ input_tensor = transform_test(input_pil).unsqueeze(0).to(DEVICE)
+
+ # 3. Perform Inference
+ with torch.no_grad():
+ outputs = model(input_tensor)
+ probabilities = F.softmax(outputs, dim=1).squeeze().cpu().numpy()
+
+ # 4. Format output for Gradio Label component
+ confidences = {name: float(prob) for name, prob in zip(CLS_CLASS_NAMES, probabilities)}
+
+ # 5. Save example
+ if SAVE_EXAMPLES:
+ timestamp = int(time.time() * 1000)
+ filename = f"cls_input_{timestamp}.png" # Save as png for compatibility
+ dest_path = os.path.join(CLS_EXAMPLE_IMG_DIR, filename)
+ try:
+ input_pil.save(dest_path)
+ print(f"β New Classification example saved: {dest_path}")
+ except Exception as e:
+ print(f"β Failed to save classification example: {e}")
+
+ print("β Classification complete")
+
+ return numpy_to_pil(input_image_np, "L"), confidences
+
+
+# --- 3. Gradio UI Layout ---
+print("Building Gradio interface...")
+# Create directories for all example types
+os.makedirs(M2I_EXAMPLE_IMG_DIR, exist_ok=True)
+os.makedirs(T2I_EXAMPLE_IMG_DIR, exist_ok=True)
+os.makedirs(SR_EXAMPLE_IMG_DIR, exist_ok=True)
+os.makedirs(DN_EXAMPLE_IMG_DIR, exist_ok=True)
+os.makedirs(SEG_EXAMPLE_IMG_DIR, exist_ok=True)
+os.makedirs(CLS_EXAMPLE_IMG_DIR, exist_ok=True)
+
+# --- Load examples ---
+filename_to_prompt_map = { sanitize_prompt_for_filename(prompt): prompt for prompt in T2I_PROMPTS }
+t2i_gallery_examples = []
+for filename in os.listdir(T2I_EXAMPLE_IMG_DIR):
+ if filename in filename_to_prompt_map:
+ filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, filename)
+ prompt = filename_to_prompt_map[filename]
+ t2i_gallery_examples.append((filepath, prompt))
+
+def load_image_examples(example_dir, is_stack=False):
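+    """Collect (preview PIL image, file path) pairs from an example directory; multi-frame stacks are averaged for display when is_stack is True."""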
+ examples = []
+ if not os.path.exists(example_dir): return examples
+ for f in sorted(os.listdir(example_dir)):
+ if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
+ filepath = os.path.join(example_dir, f)
+ try:
+ if f.lower().endswith(('.tif', '.tiff')):
+ img_np = tifffile.imread(filepath)
+ else:
+ img_np = np.array(Image.open(filepath).convert("L"))
+
+ if is_stack and img_np.ndim == 3:
+ img_np = np.mean(img_np, axis=0)
+
+ display_img = numpy_to_pil(img_np, "L")
+ examples.append((display_img, filepath))
+ except Exception as e:
+ print(f"Warning: Could not load gallery image {filepath}. Error: {e}")
+ return examples
+
+m2i_gallery_examples = load_image_examples(M2I_EXAMPLE_IMG_DIR)
+sr_gallery_examples = load_image_examples(SR_EXAMPLE_IMG_DIR, is_stack=True)
+dn_gallery_examples = load_image_examples(DN_EXAMPLE_IMG_DIR)
+seg_gallery_examples = load_image_examples(SEG_EXAMPLE_IMG_DIR)
+cls_gallery_examples = load_image_examples(CLS_EXAMPLE_IMG_DIR)
+
+# --- Universal event handlers ---
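+# Gallery items are (image, caption) pairs: the caption stores either the prompt text or the source file path,
+# so selecting an item simply forwards its caption to the corresponding input component.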
+def select_example_prompt(evt: gr.SelectData):
+ return evt.value['caption']
+
+def select_example_input_file(evt: gr.SelectData):
+ return evt.value['caption']
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+ with gr.Row():
+ gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False)
+ gr.Markdown(f"# {MODEL_TITLE}\n{MODEL_DESCRIPTION}")
+
+ with gr.Tabs():
+ # --- TAB 1: Mask-to-Image ---
+ with gr.Tab("Mask-to-Image", id="mask2img"):
+ gr.Markdown("""
+ ### Instructions
+ 1. Upload a single-channel segmentation mask (`.tif` file), or select one from the examples gallery below.
+ 2. Enter the corresponding 'Cell Type' (e.g., 'CoNSS', 'HeLa') to create the prompt.
+ 3. Select how many sample images you want to generate.
+ 4. Adjust 'Inference Steps' and 'Seed' as needed.
+ 5. Click 'Generate Training Samples' to start the process.
+ 6. The 'Generated Samples' will appear in the main gallery, with the 'Input Mask' shown below for reference.
+ """) # Content hidden for brevity
+ with gr.Row(variant="panel"):
+ with gr.Column(scale=1, min_width=350):
+ m2i_input_file = gr.File(label="Upload Segmentation Mask (.tif)", file_types=['.tif', '.tiff'])
+ m2i_cell_type_input = gr.Textbox(label="Cell Type (for prompt)", placeholder="e.g., CoNSS, HeLa, MCF-7")
+ m2i_num_images_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Images to Generate")
+ m2i_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
+ m2i_seed_input = gr.Number(label="Seed", value=42)
+ m2i_generate_button = gr.Button("Generate Training Samples", variant="primary")
+ with gr.Column(scale=2):
+ m2i_output_gallery = gr.Gallery(label="Generated Samples", columns=5, object_fit="contain", height="auto")
+ m2i_input_display = gr.Image(label="Input Mask", type="pil", interactive=False)
+ m2i_gallery = gr.Gallery(value=m2i_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
+
+ # --- TAB 2: Text-to-Image ---
+ with gr.Tab("Text-to-Image Generation", id="txt2img"):
+ gr.Markdown("""
+ ### Instructions
+ 1. Select a desired prompt from the dropdown menu.
+ 2. Adjust the 'Inference Steps' slider to control generation quality.
+ 3. Click the 'Generate' button to create a new image.
+ 4. Explore the 'Examples' gallery; clicking an image will load its prompt.
+
+            **Notice:** This model currently supports 3566 prompt categories. However, training data for many cell structures and cell lines is still limited. We welcome contributions of additional data sources to improve the model.
+            """)
+ with gr.Row(variant="panel"):
+ with gr.Column(scale=1, min_width=350):
+ t2i_prompt_input = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Select a Prompt")
+ t2i_steps_slider = gr.Slider(minimum=10, maximum=200, step=1, value=50, label="Inference Steps")
+ t2i_generate_button = gr.Button("Generate", variant="primary")
+ with gr.Column(scale=2):
+ t2i_generated_output = gr.Image(label="Generated Image", type="pil", interactive=False)
+ t2i_gallery = gr.Gallery(value=t2i_gallery_examples, label="Examples (Click an image to use its prompt)", columns=6, object_fit="contain", height="auto")
+
+ # --- TAB 3: Super-Resolution ---
+ with gr.Tab("Super-Resolution", id="super_res"):
+ gr.Markdown("""
+ ### Instructions
+ 1. Upload a low-resolution 9-channel TIF stack, or select one from the examples.
+ 2. Select a 'Super-Resolution Model' from the dropdown.
+ 3. Enter a descriptive 'Prompt' related to the image content (e.g., 'CCPs of COS-7').
+ 4. Adjust 'Inference Steps' and 'Seed' as needed.
+ 5. Click 'Generate Super-Resolution' to process the image.
+
+ **Notice:** This model was trained on the **BioSR** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
+ """) # Content hidden for brevity
+ with gr.Row(variant="panel"):
+ with gr.Column(scale=1, min_width=350):
+ sr_input_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff'])
+ sr_model_selector = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Select Super-Resolution Model")
+ sr_prompt_input = gr.Textbox(label="Prompt (e.g., structure name)", value="CCPs of COS-7")
+ sr_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
+ sr_seed_input = gr.Number(label="Seed", value=42)
+ sr_generate_button = gr.Button("Generate Super-Resolution", variant="primary")
+ with gr.Column(scale=2):
+ with gr.Row():
+ sr_input_display = gr.Image(label="Input (Average Projection)", type="pil", interactive=False)
+ sr_output_image = gr.Image(label="Super-Resolved Image", type="pil", interactive=False)
+ sr_gallery = gr.Gallery(value=sr_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
+
+ # --- TAB 4: Denoising ---
+ with gr.Tab("Denoising", id="denoising"):
+ gr.Markdown("""
+ ### Instructions
+ 1. Upload a noisy single-channel image, or select one from the examples.
+ 2. Select the 'Image Type' from the dropdown to provide context for the model.
+ 3. Adjust 'Inference Steps' and 'Seed' as needed.
+ 4. Click 'Denoise Image' to reduce the noise.
+
+ **Notice:** This model was trained on the **FMD** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
+ """) # Content hidden for brevity
+ with gr.Row(variant="panel"):
+ with gr.Column(scale=1, min_width=350):
+ dn_input_image = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L")
+ dn_image_type_selector = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Select Image Type (for Prompt)")
+ dn_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
+ dn_seed_input = gr.Number(label="Seed", value=42)
+ dn_generate_button = gr.Button("Denoise Image", variant="primary")
+ with gr.Column(scale=2):
+ with gr.Row():
+ dn_original_display = gr.Image(label="Original Noisy Image", type="pil", interactive=False)
+ dn_output_image = gr.Image(label="Denoised Image", type="pil", interactive=False)
+ dn_gallery = gr.Gallery(value=dn_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
+
+ # --- TAB 5: Cell Segmentation ---
+ with gr.Tab("Cell Segmentation", id="segmentation"):
+ gr.Markdown("""
+ ### Instructions
+ 1. Upload a single-channel image for segmentation, or select one from the examples.
+ 2. Select a 'Segmentation Model' from the dropdown menu.
+ 3. Set the expected 'Diameter' of the cells in pixels. Set to 0 to let the model automatically estimate it.
+ 4. Adjust 'Flow Threshold' and 'Cell Probability Threshold' for finer control.
+ 5. Click 'Segment Cells'. The result will be shown as a dark red overlay on the original image.
+ """)
+ with gr.Row(variant="panel"):
+ with gr.Column(scale=1, min_width=350):
+ gr.Markdown("### 1. Inputs & Controls")
+ seg_input_image = gr.Image(type="numpy", label="Upload Image for Segmentation", image_mode="L")
+ seg_model_selector = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Select Segmentation Model")
+ seg_diameter_input = gr.Number(label="Cell Diameter (pixels, 0=auto)", value=30)
+ seg_flow_slider = gr.Slider(minimum=0.0, maximum=3.0, step=0.1, value=0.4, label="Flow Threshold")
+ seg_cellprob_slider = gr.Slider(minimum=-6.0, maximum=6.0, step=0.5, value=0.0, label="Cell Probability Threshold")
+ seg_generate_button = gr.Button("Segment Cells", variant="primary")
+ with gr.Column(scale=2):
+ gr.Markdown("### 2. Results")
+ with gr.Row():
+ seg_original_display = gr.Image(label="Original Image", type="pil", interactive=False)
+ seg_output_image = gr.Image(label="Segmented Image (Overlay)", type="pil", interactive=False)
+ seg_gallery = gr.Gallery(value=seg_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
+
+ # --- NEW TAB 6: Classification ---
+ with gr.Tab("Classification", id="classification"):
+ gr.Markdown("""
+ ### Instructions
+ 1. Upload a single-channel image for classification, or select an example.
+ 2. Select a pre-trained 'Classification Model' from the dropdown menu.
+ 3. Click 'Classify Image' to view the prediction probabilities for each class.
+
+ **Note:** The models provided are ResNet50 trained on the 2D HeLa dataset.
+ """)
+ with gr.Row(variant="panel"):
+ with gr.Column(scale=1, min_width=350):
+ gr.Markdown("### 1. Inputs & Controls")
+ cls_input_image = gr.Image(type="numpy", label="Upload Image for Classification", image_mode="L")
+ cls_model_selector = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Select Classification Model")
+ cls_generate_button = gr.Button("Classify Image", variant="primary")
+ with gr.Column(scale=2):
+ gr.Markdown("### 2. Results")
+ cls_original_display = gr.Image(label="Input Image", type="pil", interactive=False)
+ cls_output_label = gr.Label(label="Classification Results", num_top_classes=len(CLS_CLASS_NAMES))
+ cls_gallery = gr.Gallery(value=cls_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
+
+
+ # --- Event Handlers ---
+ m2i_generate_button.click(fn=run_mask_to_image_generation, inputs=[m2i_input_file, m2i_cell_type_input, m2i_num_images_slider, m2i_steps_slider, m2i_seed_input], outputs=[m2i_input_display, m2i_output_gallery])
+ m2i_gallery.select(fn=select_example_input_file, outputs=m2i_input_file)
+
+ t2i_generate_button.click(fn=generate_t2i, inputs=[t2i_prompt_input, t2i_steps_slider], outputs=[t2i_generated_output])
+ t2i_gallery.select(fn=select_example_prompt, outputs=t2i_prompt_input)
+
+ sr_generate_button.click(fn=run_super_resolution, inputs=[sr_input_file, sr_model_selector, sr_prompt_input, sr_steps_slider, sr_seed_input], outputs=[sr_input_display, sr_output_image])
+ sr_gallery.select(fn=select_example_input_file, outputs=sr_input_file)
+
+ dn_generate_button.click(fn=run_denoising, inputs=[dn_input_image, dn_image_type_selector, dn_steps_slider, dn_seed_input], outputs=[dn_original_display, dn_output_image])
+ dn_gallery.select(fn=select_example_input_file, outputs=dn_input_image)
+
+ seg_generate_button.click(fn=run_segmentation, inputs=[seg_input_image, seg_model_selector, seg_diameter_input, seg_flow_slider, seg_cellprob_slider], outputs=[seg_original_display, seg_output_image])
+ seg_gallery.select(fn=select_example_input_file, outputs=seg_input_image)
+
+ cls_generate_button.click(fn=run_classification, inputs=[cls_input_image, cls_model_selector], outputs=[cls_original_display, cls_output_label])
+ cls_gallery.select(fn=select_example_input_file, outputs=cls_input_image)
+
+
+# --- 4. Launch Application ---
+if __name__ == "__main__":
+ print("Interface built. Launching server...")
+ demo.launch()
\ No newline at end of file
diff --git a/cache/00000001.tif b/cache/00000001.tif
new file mode 100644
index 0000000000000000000000000000000000000000..653f4cacdb2d1d85957fbb005ecad12406ecc6f3
--- /dev/null
+++ b/cache/00000001.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:436a899290607814ba5f5c0e9456e9d6551bb2cb986223aff7fc4ae396d7a86a
+size 1181232
diff --git a/cellpose/__init__.py b/cellpose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b9750249a2a5ab1f0c17bbfb0c26e3e19b88f13
--- /dev/null
+++ b/cellpose/__init__.py
@@ -0,0 +1 @@
+from cellpose.version import version, version_str
diff --git a/cellpose/__main__.py b/cellpose/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fea1d9c9b58cfef9831cbfaeb5b2e3fa85b2c9cb
--- /dev/null
+++ b/cellpose/__main__.py
@@ -0,0 +1,380 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import sys, os, glob, pathlib, time
+import numpy as np
+from natsort import natsorted
+from tqdm import tqdm
+from cellpose import utils, models, io, version_str, train, denoise
+from cellpose.cli import get_arg_parser
+
+try:
+ from cellpose.gui import gui3d, gui
+ GUI_ENABLED = True
+except ImportError as err:
+ GUI_ERROR = err
+ GUI_ENABLED = False
+ GUI_IMPORT = True
+except Exception as err:
+ GUI_ENABLED = False
+ GUI_ERROR = err
+ GUI_IMPORT = False
+ raise
+
+import logging
+
+
+# settings re-grouped a bit
+def main():
+ """ Run cellpose from command line
+ """
+
+ args = get_arg_parser().parse_args(
+ ) # this has to be in a separate file for autodoc to work
+
+ if args.version:
+ print(version_str)
+ return
+
+ if args.check_mkl:
+ mkl_enabled = models.check_mkl()
+ else:
+ mkl_enabled = True
+
+ if len(args.dir) == 0 and len(args.image_path) == 0:
+ if args.add_model:
+ io.add_model(args.add_model)
+ else:
+ if not GUI_ENABLED:
+ print("GUI ERROR: %s" % GUI_ERROR)
+ if GUI_IMPORT:
+ print(
+ "GUI FAILED: GUI dependencies may not be installed, to install, run"
+ )
+ print(" pip install 'cellpose[gui]'")
+ else:
+ if args.Zstack:
+ gui3d.run()
+ else:
+ gui.run()
+
+ else:
+ if args.verbose:
+ from .io import logger_setup
+ logger, log_file = logger_setup()
+ else:
+ print(
+ ">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
+ print("No --verbose => no progress or info printed")
+ logger = logging.getLogger(__name__)
+
+ use_gpu = False
+ channels = [args.chan, args.chan2]
+
+ # find images
+ if len(args.img_filter) > 0:
+ imf = args.img_filter
+ else:
+ imf = None
+
+ # Check with user if they REALLY mean to run without saving anything
+ if not (args.train or args.train_size):
+ saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt
+
+ device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu,
+ device=args.gpu_device)
+
+ if args.pretrained_model is None or args.pretrained_model == "None" or args.pretrained_model == "False" or args.pretrained_model == "0":
+ pretrained_model = False
+ else:
+ pretrained_model = args.pretrained_model
+
+ restore_type = args.restore_type
+ if restore_type is not None:
+ try:
+ denoise.model_path(restore_type)
+ except Exception as e:
+ raise ValueError("restore_type invalid")
+ if args.train or args.train_size:
+ raise ValueError("restore_type cannot be used with training on CLI yet")
+
+ if args.transformer and (restore_type is None):
+ default_model = "transformer_cp3"
+ backbone = "transformer"
+ elif args.transformer and restore_type is not None:
+ raise ValueError("no transformer based restoration")
+ else:
+ default_model = "cyto3"
+ backbone = "default"
+
+ if args.norm_percentile is not None:
+ value1, value2 = args.norm_percentile
+ normalize = {'percentile': (float(value1), float(value2))}
+ else:
+ normalize = (not args.no_norm)
+
+
+ model_type = None
+ if pretrained_model and not os.path.exists(pretrained_model):
+ model_type = pretrained_model if pretrained_model is not None else "cyto3"
+ model_strings = models.get_user_models()
+ all_models = models.MODEL_NAMES.copy()
+ all_models.extend(model_strings)
+ if ~np.any([model_type == s for s in all_models]):
+ model_type = default_model
+ logger.warning(
+ f"pretrained model has incorrect path, using {default_model}")
+ if model_type == "nuclei":
+ szmean = 17.
+ else:
+ szmean = 30.
+ builtin_size = (model_type == "cyto" or model_type == "cyto2" or
+ model_type == "nuclei" or model_type == "cyto3")
+
+ if len(args.image_path) > 0 and (args.train or args.train_size):
+ raise ValueError("ERROR: cannot train model with single image input")
+
+ if not args.train and not args.train_size:
+ tic = time.time()
+ if len(args.dir) > 0:
+ image_names = io.get_image_files(
+ args.dir, args.mask_filter, imf=imf,
+ look_one_level_down=args.look_one_level_down)
+ else:
+ if os.path.exists(args.image_path):
+ image_names = [args.image_path]
+ else:
+ raise ValueError(f"ERROR: no file found at {args.image_path}")
+ nimg = len(image_names)
+
+ if args.savedir:
+ if not os.path.exists(args.savedir):
+                raise FileExistsError(f"--savedir {args.savedir} does not exist")
+
+ cstr0 = ["GRAY", "RED", "GREEN", "BLUE"]
+ cstr1 = ["NONE", "RED", "GREEN", "BLUE"]
+ logger.info(
+ ">>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s"
+ % (nimg, cstr0[channels[0]], cstr1[channels[1]]))
+
+ # handle built-in model exceptions
+ if builtin_size and restore_type is None and not args.pretrained_model_ortho:
+ model = models.Cellpose(gpu=gpu, device=device, model_type=model_type,
+ backbone=backbone)
+ else:
+ builtin_size = False
+ if args.all_channels:
+ channels = None
+ img = io.imread(image_names[0])
+ if img.ndim == 3:
+ nchan = min(img.shape)
+ elif img.ndim == 2:
+ nchan = 1
+ channels = None
+ else:
+ nchan = 2
+
+ pretrained_model = None if model_type is not None else pretrained_model
+ if restore_type is None:
+ pretrained_model_ortho = None if args.pretrained_model_ortho is None else args.pretrained_model_ortho
+ model = models.CellposeModel(gpu=gpu, device=device,
+ pretrained_model=pretrained_model,
+ model_type=model_type,
+ nchan=nchan,
+ backbone=backbone,
+ pretrained_model_ortho=pretrained_model_ortho)
+ else:
+ model = denoise.CellposeDenoiseModel(
+ gpu=gpu, device=device, pretrained_model=pretrained_model,
+ model_type=model_type, restore_type=restore_type, nchan=nchan,
+ chan2_restore=args.chan2_restore)
+
+ # handle diameters
+ if args.diameter == 0:
+ if builtin_size:
+ diameter = None
+ logger.info(">>>> estimating diameter for each image")
+ else:
+ if restore_type is None:
+ logger.info(
+ ">>>> not using cyto3, cyto, cyto2, or nuclei model, cannot auto-estimate diameter"
+ )
+ else:
+ logger.info(
+ ">>>> cannot auto-estimate diameter for image restoration")
+ diameter = model.diam_labels
+ logger.info(">>>> using diameter %0.3f for all images" % diameter)
+ else:
+ diameter = args.diameter
+ logger.info(">>>> using diameter %0.3f for all images" % diameter)
+
+ tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)
+
+ for image_name in tqdm(image_names, file=tqdm_out):
+ image = io.imread(image_name)
+ out = model.eval(
+ image, channels=channels, diameter=diameter, do_3D=args.do_3D,
+ augment=args.augment, resample=(not args.no_resample),
+ flow_threshold=args.flow_threshold,
+ cellprob_threshold=args.cellprob_threshold,
+ stitch_threshold=args.stitch_threshold, min_size=args.min_size,
+ invert=args.invert, batch_size=args.batch_size,
+ interp=(not args.no_interp), normalize=normalize,
+ channel_axis=args.channel_axis, z_axis=args.z_axis,
+ anisotropy=args.anisotropy, niter=args.niter,
+ flow3D_smooth=args.flow3D_smooth)
+ masks, flows = out[:2]
+ if len(out) > 3 and restore_type is None:
+ diams = out[-1]
+ else:
+ diams = diameter
+ ratio = 1.
+ if restore_type is not None:
+ imgs_dn = out[-1]
+ ratio = diams / model.dn.diam_mean if "upsample" in restore_type else 1.
+ diams = model.dn.diam_mean if "upsample" in restore_type and model.dn.diam_mean > diams else diams
+ else:
+ imgs_dn = None
+ if args.exclude_on_edges:
+ masks = utils.remove_edge_masks(masks)
+ if not args.no_npy:
+ io.masks_flows_to_seg(image, masks, flows, image_name,
+ imgs_restore=imgs_dn, channels=channels,
+ diams=diams, restore_type=restore_type,
+ ratio=1.)
+ if saving_something:
+ suffix = "_cp_masks"
+ if args.output_name is not None:
+ # (1) If `savedir` is not defined, then must have a non-zero `suffix`
+ if args.savedir is None and len(args.output_name) > 0:
+ suffix = args.output_name
+ elif args.savedir is not None and not os.path.samefile(args.savedir, args.dir):
+ # (2) If `savedir` is defined, and different from `dir` then
+ # takes the value passed as a param. (which can be empty string)
+ suffix = args.output_name
+
+ io.save_masks(image, masks, flows, image_name,
+ suffix=suffix, png=args.save_png,
+ tif=args.save_tif, save_flows=args.save_flows,
+ save_outlines=args.save_outlines,
+ dir_above=args.dir_above, savedir=args.savedir,
+ save_txt=args.save_txt, in_folders=args.in_folders,
+ save_mpl=args.save_mpl)
+ if args.save_rois:
+ io.save_rois(masks, image_name)
+ logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))
+ else:
+ test_dir = None if len(args.test_dir) == 0 else args.test_dir
+ images, labels, image_names, train_probs = None, None, None, None
+ test_images, test_labels, image_names_test, test_probs = None, None, None, None
+ compute_flows = False
+ if len(args.file_list) > 0:
+ if os.path.exists(args.file_list):
+ dat = np.load(args.file_list, allow_pickle=True).item()
+ image_names = dat["train_files"]
+ image_names_test = dat.get("test_files", None)
+ train_probs = dat.get("train_probs", None)
+ test_probs = dat.get("test_probs", None)
+ compute_flows = dat.get("compute_flows", False)
+ load_files = False
+ else:
+ logger.critical(f"ERROR: {args.file_list} does not exist")
+ else:
+ output = io.load_train_test_data(args.dir, test_dir, imf,
+ args.mask_filter,
+ args.look_one_level_down)
+ images, labels, image_names, test_images, test_labels, image_names_test = output
+ load_files = True
+
+ # training with all channels
+ if args.all_channels:
+ img = images[0] if images is not None else io.imread(image_names[0])
+ if img.ndim == 3:
+ nchan = min(img.shape)
+ elif img.ndim == 2:
+ nchan = 1
+ channels = None
+ else:
+ nchan = 2
+
+ # model path
+ szmean = args.diam_mean
+ if not os.path.exists(pretrained_model) and model_type is None:
+ if not args.train:
+ error_message = "ERROR: model path missing or incorrect - cannot train size model"
+ logger.critical(error_message)
+ raise ValueError(error_message)
+ pretrained_model = False
+ logger.info(">>>> training from scratch")
+ if args.train:
+ logger.info(
+ ">>>> during training rescaling images to fixed diameter of %0.1f pixels"
+ % args.diam_mean)
+
+ # initialize model
+ model = models.CellposeModel(
+ device=device, model_type=model_type, diam_mean=szmean, nchan=nchan,
+ pretrained_model=pretrained_model if model_type is None else None,
+ backbone=backbone)
+
+ # train segmentation model
+ if args.train:
+ cpmodel_path = train.train_seg(
+ model.net, images, labels, train_files=image_names,
+ test_data=test_images, test_labels=test_labels,
+ test_files=image_names_test, train_probs=train_probs,
+ test_probs=test_probs, compute_flows=compute_flows,
+ load_files=load_files, normalize=normalize,
+ channels=channels, channel_axis=args.channel_axis, rgb=(nchan == 3),
+ learning_rate=args.learning_rate, weight_decay=args.weight_decay,
+ SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.batch_size,
+ min_train_masks=args.min_train_masks,
+ nimg_per_epoch=args.nimg_per_epoch,
+ nimg_test_per_epoch=args.nimg_test_per_epoch,
+ save_path=os.path.realpath(args.dir), save_every=args.save_every,
+ model_name=args.model_name_out)[0]
+ model.pretrained_model = cpmodel_path
+ logger.info(">>>> model trained and saved to %s" % cpmodel_path)
+
+ # train size model
+ if args.train_size:
+ sz_model = models.SizeModel(cp_model=model, device=device)
+ # data has already been normalized and reshaped
+ sz_model.params = train.train_size(
+ model.net, model.pretrained_model, images, labels,
+ train_files=image_names, test_data=test_images,
+ test_labels=test_labels, test_files=image_names_test,
+ train_probs=train_probs, test_probs=test_probs,
+ load_files=load_files, channels=channels,
+ min_train_masks=args.min_train_masks,
+ channel_axis=args.channel_axis, rgb=(nchan == 3),
+ nimg_per_epoch=args.nimg_per_epoch, normalize=normalize,
+ nimg_test_per_epoch=args.nimg_test_per_epoch,
+ batch_size=args.batch_size)
+ if test_images is not None:
+ test_masks = [lbl[0] for lbl in test_labels
+ ] if test_labels is not None else test_labels
+ predicted_diams, diams_style = sz_model.eval(
+ test_images, channels=channels)
+ ccs = np.corrcoef(
+ diams_style,
+ np.array([utils.diameters(lbl)[0] for lbl in test_masks]))[0, 1]
+ cc = np.corrcoef(
+ predicted_diams,
+ np.array([utils.diameters(lbl)[0] for lbl in test_masks]))[0, 1]
+ logger.info(
+ "style test correlation: %0.4f; final test correlation: %0.4f" %
+ (ccs, cc))
+ np.save(
+ os.path.join(
+ args.test_dir,
+ "%s_predicted_diams.npy" % os.path.split(cpmodel_path)[1]),
+ {
+ "predicted_diams": predicted_diams,
+ "diams_style": diams_style
+ })
+
+
+if __name__ == "__main__":
+ main()
diff --git a/cellpose/__pycache__/__init__.cpython-310.pyc b/cellpose/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4711a7c4a12e6f8aa4502c9f3d502f87d6d9610
Binary files /dev/null and b/cellpose/__pycache__/__init__.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/__init__.cpython-311.pyc b/cellpose/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ddeeae369d0e38bd8427c87ca5708112ff6ed3e8
Binary files /dev/null and b/cellpose/__pycache__/__init__.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/__init__.cpython-312.pyc b/cellpose/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13d99532b2ec3a9109f7cd12b63171cfde0be08c
Binary files /dev/null and b/cellpose/__pycache__/__init__.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/core.cpython-310.pyc b/cellpose/__pycache__/core.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82c745924bf4c29d4ac6d11dd6f81a8fbd22952f
Binary files /dev/null and b/cellpose/__pycache__/core.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/core.cpython-311.pyc b/cellpose/__pycache__/core.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a7d57f5bddc8f4d1b4878b43611d5c17561c039
Binary files /dev/null and b/cellpose/__pycache__/core.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/core.cpython-312.pyc b/cellpose/__pycache__/core.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e760a67005aab4954dbd3b3a4e1a646c5eb1cdbe
Binary files /dev/null and b/cellpose/__pycache__/core.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/dynamics.cpython-310.pyc b/cellpose/__pycache__/dynamics.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21dcc91e54563d7df02f963113c7c9f8ab948c0a
Binary files /dev/null and b/cellpose/__pycache__/dynamics.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/dynamics.cpython-311.pyc b/cellpose/__pycache__/dynamics.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b36365a8468f847a3bb8c0cab1e7b4be6064276
Binary files /dev/null and b/cellpose/__pycache__/dynamics.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/dynamics.cpython-312.pyc b/cellpose/__pycache__/dynamics.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b270ed4b6adb3d085c4ac268f8ec3735fa71362
Binary files /dev/null and b/cellpose/__pycache__/dynamics.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py310.1.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.1.nbc
new file mode 100644
index 0000000000000000000000000000000000000000..ed51d53bcd4f6c0cd8a0490d274e82ca7a6cf445
Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.1.nbc differ
diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py310.2.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.2.nbc
new file mode 100644
index 0000000000000000000000000000000000000000..c95b9a6259ad7a6987b7d0ac6241b0ea5bf3a4d9
Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.2.nbc differ
diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py310.nbi b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.nbi
new file mode 100644
index 0000000000000000000000000000000000000000..259b67b759402723411cfd4b214b242c3090c993
Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.nbi differ
diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py311.1.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.1.nbc
new file mode 100644
index 0000000000000000000000000000000000000000..9843e74d95f2fff5da4d7d6ccf0c29ca341ba374
Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.1.nbc differ
diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py311.2.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.2.nbc
new file mode 100644
index 0000000000000000000000000000000000000000..5416c7461c6c86d56c1b66b70e86ab624fe9211e
Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.2.nbc differ
diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py311.nbi b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.nbi
new file mode 100644
index 0000000000000000000000000000000000000000..9f7c8fd3edc8a1da3f66f1018a8f22699f8fcb38
Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.nbi differ
diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py312.1.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.1.nbc
new file mode 100644
index 0000000000000000000000000000000000000000..4969bd00d68697357f57bdda43414d52aeaaf8e3
Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.1.nbc differ
diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py312.2.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.2.nbc
new file mode 100644
index 0000000000000000000000000000000000000000..3d771a418ddc10de51f7a48efb324757f8179b92
Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.2.nbc differ
diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py312.nbi b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.nbi
new file mode 100644
index 0000000000000000000000000000000000000000..e6ee3eb368c8ee8273c2a70a5ee37a13449210d4
Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.nbi differ
diff --git a/cellpose/__pycache__/io.cpython-310.pyc b/cellpose/__pycache__/io.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ed96f8b731c319779100d8fa73bd40499947498
Binary files /dev/null and b/cellpose/__pycache__/io.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/io.cpython-311.pyc b/cellpose/__pycache__/io.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6142b4480f8621f72ba9b8acd7f2871683de0cff
Binary files /dev/null and b/cellpose/__pycache__/io.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/io.cpython-312.pyc b/cellpose/__pycache__/io.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..891a7298846644d03cf85f355f228443594bbe70
Binary files /dev/null and b/cellpose/__pycache__/io.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/metrics.cpython-310.pyc b/cellpose/__pycache__/metrics.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd2d6f9b9450218669b43cd37544e09357039d53
Binary files /dev/null and b/cellpose/__pycache__/metrics.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/metrics.cpython-311.pyc b/cellpose/__pycache__/metrics.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c53ac1a8dcbf08920123dae8abb128af34405f9
Binary files /dev/null and b/cellpose/__pycache__/metrics.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/metrics.cpython-312.pyc b/cellpose/__pycache__/metrics.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a861397f990129ee506a4b7b52344b62721a152
Binary files /dev/null and b/cellpose/__pycache__/metrics.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/models.cpython-310.pyc b/cellpose/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b16eb38aa02d62e6bbda76c6c5642d4137ab5aa5
Binary files /dev/null and b/cellpose/__pycache__/models.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/models.cpython-311.pyc b/cellpose/__pycache__/models.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ac605152416c732b961eccfe4387a6e814c62c0
Binary files /dev/null and b/cellpose/__pycache__/models.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/models.cpython-312.pyc b/cellpose/__pycache__/models.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c085d2d24f7ede30539e3cd15cd330966108bd24
Binary files /dev/null and b/cellpose/__pycache__/models.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/plot.cpython-310.pyc b/cellpose/__pycache__/plot.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bea04fb0793a5e924ec2fab9d8d5feb66b6d02b9
Binary files /dev/null and b/cellpose/__pycache__/plot.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/plot.cpython-311.pyc b/cellpose/__pycache__/plot.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a93aed81d324ac260ece773da03a7c6fa1511d8f
Binary files /dev/null and b/cellpose/__pycache__/plot.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/plot.cpython-312.pyc b/cellpose/__pycache__/plot.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3f38da5a63c1811a31e44a2699e4e4c5e159a4a
Binary files /dev/null and b/cellpose/__pycache__/plot.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/resnet_torch.cpython-310.pyc b/cellpose/__pycache__/resnet_torch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24fb000cc1ed4529bcaffcb983390ea09d566d04
Binary files /dev/null and b/cellpose/__pycache__/resnet_torch.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/resnet_torch.cpython-311.pyc b/cellpose/__pycache__/resnet_torch.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2758c2293455c7b70986aa6c8bb78c7b01c8cae
Binary files /dev/null and b/cellpose/__pycache__/resnet_torch.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/resnet_torch.cpython-312.pyc b/cellpose/__pycache__/resnet_torch.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35561f86b3574c71f7f818ddad2ed4a7d216ff28
Binary files /dev/null and b/cellpose/__pycache__/resnet_torch.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/train.cpython-310.pyc b/cellpose/__pycache__/train.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b6ce3df8c32e54ea01d887c24390955c1de2a3d
Binary files /dev/null and b/cellpose/__pycache__/train.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/train.cpython-312.pyc b/cellpose/__pycache__/train.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45c9a129f7e6baf921b9bbdd16ff191d8f5c176e
Binary files /dev/null and b/cellpose/__pycache__/train.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/transforms.cpython-310.pyc b/cellpose/__pycache__/transforms.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1731f5d9dfa4998fe91909fd42344823daabdff4
Binary files /dev/null and b/cellpose/__pycache__/transforms.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/transforms.cpython-311.pyc b/cellpose/__pycache__/transforms.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0fc835a45f7f66d886178f500052b34407a8556
Binary files /dev/null and b/cellpose/__pycache__/transforms.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/transforms.cpython-312.pyc b/cellpose/__pycache__/transforms.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f39e7e28f94f97148fde5e9d99f7bca85b9bbd83
Binary files /dev/null and b/cellpose/__pycache__/transforms.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/utils.cpython-310.pyc b/cellpose/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18ab946e07853c6263cf66097bf59a06510210d1
Binary files /dev/null and b/cellpose/__pycache__/utils.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/utils.cpython-311.pyc b/cellpose/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05bbb4980db6fa0026a3bebf2883b3beb3a11186
Binary files /dev/null and b/cellpose/__pycache__/utils.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/utils.cpython-312.pyc b/cellpose/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88a08e18a7343c4a14dc9e03b54903441ebe527b
Binary files /dev/null and b/cellpose/__pycache__/utils.cpython-312.pyc differ
diff --git a/cellpose/__pycache__/version.cpython-310.pyc b/cellpose/__pycache__/version.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2eaf07082f1d5e1f23571e30a1ba9ea8a06a3fb1
Binary files /dev/null and b/cellpose/__pycache__/version.cpython-310.pyc differ
diff --git a/cellpose/__pycache__/version.cpython-311.pyc b/cellpose/__pycache__/version.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51fd4cf427242d59d7f210b8d974859855079c02
Binary files /dev/null and b/cellpose/__pycache__/version.cpython-311.pyc differ
diff --git a/cellpose/__pycache__/version.cpython-312.pyc b/cellpose/__pycache__/version.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f6544dc4596e8b2c63931d06fdadc0bcfa4b3fe
Binary files /dev/null and b/cellpose/__pycache__/version.cpython-312.pyc differ
diff --git a/cellpose/cli.py b/cellpose/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..be31d06dbfd2970adac0decb050c4bdee4e91a81
--- /dev/null
+++ b/cellpose/cli.py
@@ -0,0 +1,231 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden.
+"""
+
+import argparse
+
+
+def get_arg_parser():
+ """ Parses command line arguments for cellpose main function
+
+ Note: this function has to be in a separate file to allow autodoc to work for CLI.
+ The autodoc_mock_imports in conf.py does not work for sphinx-argparse sometimes,
+ see https://github.com/ashb/sphinx-argparse/issues/9#issue-1097057823
+ """
+
+ parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters")
+
+ # misc settings
+ parser.add_argument("--version", action="store_true",
+ help="show cellpose version info")
+ parser.add_argument(
+ "--verbose", action="store_true",
+ help="show information about running and settings and save to log")
+ parser.add_argument("--Zstack", action="store_true", help="run GUI in 3D mode")
+
+ # settings for CPU vs GPU
+ hardware_args = parser.add_argument_group("Hardware Arguments")
+ hardware_args.add_argument("--use_gpu", action="store_true",
+ help="use gpu if torch with cuda installed")
+ hardware_args.add_argument(
+ "--gpu_device", required=False, default="0", type=str,
+ help="which gpu device to use, use an integer for torch, or mps for M1")
+ hardware_args.add_argument("--check_mkl", action="store_true",
+ help="check if mkl working")
+
+ # settings for locating and formatting images
+ input_img_args = parser.add_argument_group("Input Image Arguments")
+ input_img_args.add_argument("--dir", default=[], type=str,
+ help="folder containing data to run or train on.")
+ input_img_args.add_argument(
+ "--image_path", default=[], type=str, help=
+ "if given and --dir not given, run on single image instead of folder (cannot train with this option)"
+ )
+ input_img_args.add_argument(
+ "--look_one_level_down", action="store_true",
+ help="run processing on all subdirectories of current folder")
+ input_img_args.add_argument("--img_filter", default=[], type=str,
+ help="end string for images to run on")
+ input_img_args.add_argument(
+ "--channel_axis", default=None, type=int,
+ help="axis of image which corresponds to image channels")
+ input_img_args.add_argument("--z_axis", default=None, type=int,
+ help="axis of image which corresponds to Z dimension")
+ input_img_args.add_argument(
+ "--chan", default=0, type=int, help=
+ "channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s")
+ input_img_args.add_argument(
+ "--chan2", default=0, type=int, help=
+ "nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s"
+ )
+ input_img_args.add_argument("--invert", action="store_true",
+ help="invert grayscale channel")
+ input_img_args.add_argument(
+ "--all_channels", action="store_true", help=
+ "use all channels in image if using own model and images with special channels")
+
+ # model settings
+ model_args = parser.add_argument_group("Model Arguments")
+ model_args.add_argument("--pretrained_model", required=False, default="cyto3",
+ type=str,
+ help="model to use for running or starting training")
+ model_args.add_argument("--restore_type", required=False, default=None, type=str,
+ help="model to use for image restoration")
+ model_args.add_argument("--chan2_restore", action="store_true",
+ help="use nuclei restore model for second channel")
+ model_args.add_argument(
+ "--add_model", required=False, default=None, type=str,
+ help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
+ model_args.add_argument(
+ "--transformer", action="store_true", help=
+ "use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)")
+ model_args.add_argument("--pretrained_model_ortho", required=False, default=None,
+ type=str,
+ help="model to use for running 3D ortho views (ZY and ZX)")
+
+ # algorithm settings
+ algorithm_args = parser.add_argument_group("Algorithm Arguments")
+ algorithm_args.add_argument(
+ "--no_resample", action="store_true", help=
+ "disable dynamics on full image (makes algorithm faster for images with large diameters)"
+ )
+ algorithm_args.add_argument(
+ "--no_interp", action="store_true",
+ help="do not interpolate when running dynamics (was default)")
+ algorithm_args.add_argument("--no_norm", action="store_true",
+ help="do not normalize images (normalize=False)")
+    algorithm_args.add_argument(
+        '--norm_percentile',
+        nargs=2,  # require exactly two values
+        type=float,
+        metavar=('VALUE1', 'VALUE2'),
+        help="two float values setting the normalization percentiles (e.g., --norm_percentile 1 99)"
+    )
+ algorithm_args.add_argument(
+ "--do_3D", action="store_true",
+        help="process images as 3D stacks of images (nplanes x nchan x Ly x Lx)")
+ algorithm_args.add_argument(
+ "--diameter", required=False, default=30., type=float, help=
+ "cell diameter, if 0 will use the diameter of the training labels used in the model, or with built-in model will estimate diameter for each image"
+ )
+ algorithm_args.add_argument(
+ "--stitch_threshold", required=False, default=0.0, type=float,
+        help="compute masks in 2D then stitch masks across planes when their IoU exceeds stitch_threshold"
+ )
+ algorithm_args.add_argument(
+ "--min_size", required=False, default=15, type=int,
+ help="minimum number of pixels per mask, can turn off with -1")
+ algorithm_args.add_argument(
+ "--flow3D_smooth", required=False, default=0, type=float,
+ help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing")
+
+ algorithm_args.add_argument(
+ "--flow_threshold", default=0.4, type=float, help=
+ "flow error threshold, 0 turns off this optional QC step. Default: %(default)s")
+ algorithm_args.add_argument(
+ "--cellprob_threshold", default=0, type=float,
+ help="cellprob threshold, default is 0, decrease to find more and larger masks")
+ algorithm_args.add_argument(
+ "--niter", default=0, type=int, help=
+ "niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs"
+ )
+
+ algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
+ help="anisotropy of volume in 3D")
+ algorithm_args.add_argument("--exclude_on_edges", action="store_true",
+ help="discard masks which touch edges of image")
+ algorithm_args.add_argument(
+ "--augment", action="store_true",
+ help="tiles image with overlapping tiles and flips overlapped regions to augment"
+ )
+
+ # output settings
+ output_args = parser.add_argument_group("Output Arguments")
+ output_args.add_argument(
+ "--save_png", action="store_true",
+ help="save masks as png")
+ output_args.add_argument(
+ "--save_tif", action="store_true",
+ help="save masks as tif")
+ output_args.add_argument(
+ "--output_name", default=None, type=str,
+        help="suffix for saved masks, default is _cp_masks, can be empty if `savedir` is used and differs from `dir`")
+ output_args.add_argument("--no_npy", action="store_true",
+ help="suppress saving of npy")
+ output_args.add_argument(
+ "--savedir", default=None, type=str, help=
+ "folder to which segmentation results will be saved (defaults to input image directory)"
+ )
+ output_args.add_argument(
+ "--dir_above", action="store_true", help=
+ "save output folders adjacent to image folder instead of inside it (off by default)"
+ )
+ output_args.add_argument("--in_folders", action="store_true",
+ help="flag to save output in folders (off by default)")
+ output_args.add_argument(
+ "--save_flows", action="store_true", help=
+ "whether or not to save RGB images of flows when masks are saved (disabled by default)"
+ )
+ output_args.add_argument(
+ "--save_outlines", action="store_true", help=
+ "whether or not to save RGB outline images when masks are saved (disabled by default)"
+ )
+ output_args.add_argument(
+ "--save_rois", action="store_true",
+ help="whether or not to save ImageJ compatible ROI archive (disabled by default)"
+ )
+ output_args.add_argument(
+ "--save_txt", action="store_true",
+ help="flag to enable txt outlines for ImageJ (disabled by default)")
+ output_args.add_argument(
+ "--save_mpl", action="store_true",
+ help="save a figure of image/mask/flows using matplotlib (disabled by default). "
+ "This is slow, especially with large images.")
+
+ # training settings
+ training_args = parser.add_argument_group("Training Arguments")
+ training_args.add_argument("--train", action="store_true",
+ help="train network using images in dir")
+ training_args.add_argument("--train_size", action="store_true",
+ help="train size network at end of training")
+ training_args.add_argument("--test_dir", default=[], type=str,
+ help="folder containing test data (optional)")
+ training_args.add_argument(
+ "--file_list", default=[], type=str, help=
+ "path to list of files for training and testing and probabilities for each image (optional)"
+ )
+ training_args.add_argument(
+ "--mask_filter", default="_masks", type=str, help=
+ "end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s"
+ )
+ training_args.add_argument(
+ "--diam_mean", default=30., type=float, help=
+ "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0"
+ )
+ training_args.add_argument("--learning_rate", default=0.2, type=float,
+ help="learning rate. Default: %(default)s")
+ training_args.add_argument("--weight_decay", default=0.00001, type=float,
+ help="weight decay. Default: %(default)s")
+ training_args.add_argument("--n_epochs", default=500, type=int,
+ help="number of epochs. Default: %(default)s")
+ training_args.add_argument("--batch_size", default=8, type=int,
+ help="batch size. Default: %(default)s")
+ training_args.add_argument(
+ "--nimg_per_epoch", default=None, type=int,
+ help="number of train images per epoch. Default is to use all train images.")
+ training_args.add_argument(
+ "--nimg_test_per_epoch", default=None, type=int,
+ help="number of test images per epoch. Default is to use all test images.")
+ training_args.add_argument(
+ "--min_train_masks", default=5, type=int, help=
+ "minimum number of masks a training image must have to be used. Default: %(default)s"
+ )
+ training_args.add_argument("--SGD", default=1, type=int, help="use SGD")
+ training_args.add_argument(
+ "--save_every", default=100, type=int,
+ help="number of epochs to skip between saves. Default: %(default)s")
+ training_args.add_argument(
+ "--model_name_out", default=None, type=str,
+ help="Name of model to save as, defaults to name describing model architecture. "
+ "Model is saved in the folder specified by --dir in models subfolder.")
+
+ return parser
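
For orientation, here is a minimal sketch of exercising this parser outside the usual Cellpose entry point; the image path and option values below are illustrative placeholders, not defaults taken from this repository.

```python
# Illustrative sketch: build the parser defined above and parse a sample
# command line. "example.tif" and the option values are hypothetical.
from cellpose.cli import get_arg_parser

parser = get_arg_parser()
args = parser.parse_args([
    "--image_path", "example.tif",
    "--pretrained_model", "cyto3",
    "--diameter", "30",
    "--save_tif",
])
print(args.pretrained_model, args.diameter, args.save_tif)
```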
diff --git a/cellpose/contrib/distributed_segmentation.py b/cellpose/contrib/distributed_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..42a21e3ad662622d2ee1e64e0a7a628a7011cc13
--- /dev/null
+++ b/cellpose/contrib/distributed_segmentation.py
@@ -0,0 +1,924 @@
+# stdlib imports
+import os, getpass, datetime, pathlib, tempfile, functools, glob
+
+# non-stdlib core dependencies
+import numpy as np
+import scipy
+import cellpose.io
+import cellpose.models
+import tifffile
+import imagecodecs
+
+# distributed dependencies
+import dask
+import distributed
+import dask_image.ndmeasure
+import yaml
+import zarr
+import dask_jobqueue
+
+
+
+
+######################## File format functions ################################
+def numpy_array_to_zarr(write_path, array, chunks):
+ """
+ Store an in memory numpy array to disk as a chunked Zarr array
+
+ Parameters
+ ----------
+ write_path : string
+ Filepath where Zarr array will be created
+
+ array : numpy.ndarray
+ The already loaded in-memory numpy array to store as zarr
+
+ chunks : tuple, must be array.ndim length
+ How the array will be chunked in the Zarr array
+
+ Returns
+ -------
+ zarr.core.Array
+ A read+write reference to the zarr array on disk
+ """
+
+ zarr_array = zarr.open(
+ write_path,
+ mode='w',
+ shape=array.shape,
+ chunks=chunks,
+ dtype=array.dtype,
+ )
+ zarr_array[...] = array
+ return zarr_array
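
As a quick illustration of the helper above, a hedged sketch that writes a small random volume as a chunked Zarr array; the path, shape, and chunking are arbitrary examples.

```python
# Illustrative only: store an in-memory volume as a chunked zarr array on disk.
import numpy as np
from cellpose.contrib.distributed_segmentation import numpy_array_to_zarr

volume = np.random.rand(64, 256, 256).astype(np.float32)   # placeholder data
zarr_volume = numpy_array_to_zarr("volume.zarr", volume, chunks=(32, 128, 128))
print(zarr_volume.shape, zarr_volume.chunks)
```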
+
+
+def wrap_folder_of_tiffs(
+ filename_pattern,
+ block_index_pattern=r'_(Z)(\d+)(Y)(\d+)(X)(\d+)',
+):
+ """
+ Wrap a folder of tiff files with a zarr array without duplicating data.
+ Tiff files must all contain images with the same shape and data type.
+ Tiff file names must contain a pattern indicating where individual files
+ lie in the block grid.
+
+ Distributed computing requires parallel access to small regions of your
+ image from different processes. This is best accomplished with chunked
+ file formats like Zarr and N5. This function can accommodate a folder of
+    tiff files, but it is not equivalent to reformatting your data as Zarr or N5.
+ If your individual tiff files/tiles are huge, distributed performance will
+ be poor or not work at all.
+
+ It does not make sense to use this function if you have only one tiff file.
+ That tiff file will become the only chunk in the zarr array, which means all
+ workers will have to load the entire image to fetch their crop of data anyway.
+ If you have a single tiff image, you should just reformat it with the
+ numpy_array_to_zarr function. Single tiff files too large to fit into system
+    memory are not supported.
+
+ Parameters
+ ----------
+ filename_pattern : string
+ A glob pattern that will match all needed tif files
+
+ block_index_pattern : regular expression string (default: r'_(Z)(\d+)(Y)(\d+)(X)(\d+)')
+ A regular expression pattern that indicates how to parse tiff filenames
+ to determine where each tiff file lies in the overall block grid
+ The default pattern assumes filenames like the following:
+ {any_prefix}_Z000Y000X000{any_suffix}
+ {any_prefix}_Z000Y000X001{any_suffix}
+ ... and so on
+
+ Returns
+ -------
+ zarr.core.Array
+ """
+
+ # define function to read individual files
+ def imread(fname):
+ with open(fname, 'rb') as fh:
+ return imagecodecs.tiff_decode(fh.read(), index=None)
+
+ # create zarr store, open it as zarr array and return
+ store = tifffile.imread(
+ filename_pattern,
+ aszarr=True,
+ imread=imread,
+ pattern=block_index_pattern,
+ axestiled={x:x for x in range(3)},
+ )
+ return zarr.open(store=store)
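
A sketch of how this wrapper might be called, assuming tile filenames follow the default `_Z###Y###X###` convention; the glob pattern below is a placeholder.

```python
# Illustrative only: expose a folder of grid-named tiff tiles as one zarr array.
from cellpose.contrib.distributed_segmentation import wrap_folder_of_tiffs

tiles = wrap_folder_of_tiffs(
    "/data/tiles/tile_Z*Y*X*.tif",                       # hypothetical path
    block_index_pattern=r'_(Z)(\d+)(Y)(\d+)(X)(\d+)',    # default grid pattern
)
print(tiles.shape, tiles.chunks)
```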
+
+
+
+
+######################## Cluster related functions ############################
+
+#----------------------- config stuff ----------------------------------------#
+DEFAULT_CONFIG_FILENAME = 'distributed_cellpose_dask_config.yaml'
+
+def _config_path(config_name):
+ """Add config directory path to config filename"""
+ return str(pathlib.Path.home()) + '/.config/dask/' + config_name
+
+
+def _modify_dask_config(
+ config,
+ config_name=DEFAULT_CONFIG_FILENAME,
+):
+ """
+ Modifies dask config dictionary, but also dumps modified
+ config to disk as a yaml file in ~/.config/dask/. This
+ ensures that workers inherit config options.
+ """
+ dask.config.set(config)
+ with open(_config_path(config_name), 'w') as f:
+ yaml.dump(dask.config.config, f, default_flow_style=False)
+
+
+def _remove_config_file(
+ config_name=DEFAULT_CONFIG_FILENAME,
+):
+ """Removes a config file from disk"""
+ config_path = _config_path(config_name)
+ if os.path.exists(config_path): os.remove(config_path)
+
+
+#----------------------- clusters --------------------------------------------#
+class myLocalCluster(distributed.LocalCluster):
+ """
+ This is a thin wrapper extending dask.distributed.LocalCluster to set
+ configs before the cluster or workers are initialized.
+
+ For a list of full arguments (how to specify your worker resources) see:
+ https://distributed.dask.org/en/latest/api.html#distributed.LocalCluster
+ You need to know how many cpu cores and how much RAM your machine has.
+
+ Most users will only need to specify:
+ n_workers
+ ncpus (number of physical cpu cores per worker)
+ memory_limit (which is the limit per worker, should be a string like '16GB')
+ threads_per_worker (for most workflows this should be 1)
+
+ You can also modify any dask configuration option through the
+ config argument.
+
+ If your workstation has a GPU, one of the workers will have exclusive
+ access to it by default. That worker will be much faster than the others.
+ You may want to consider creating only one worker (which will have access
+ to the GPU) and letting that worker process all blocks serially.
+ """
+
+ def __init__(
+ self,
+ ncpus,
+ config={},
+ config_name=DEFAULT_CONFIG_FILENAME,
+ persist_config=False,
+ **kwargs,
+ ):
+ # config
+ self.config_name = config_name
+ self.persist_config = persist_config
+ scratch_dir = f"{os.getcwd()}/"
+ scratch_dir += f".{getpass.getuser()}_distributed_cellpose/"
+ config_defaults = {'temporary-directory':scratch_dir}
+ config = {**config_defaults, **config}
+ _modify_dask_config(config, config_name)
+
+ # construct
+ if "host" not in kwargs: kwargs["host"] = ""
+ super().__init__(**kwargs)
+ self.client = distributed.Client(self)
+
+ # set environment variables for workers (threading)
+ environment_vars = {
+ 'MKL_NUM_THREADS':str(2*ncpus),
+ 'NUM_MKL_THREADS':str(2*ncpus),
+ 'OPENBLAS_NUM_THREADS':str(2*ncpus),
+ 'OPENMP_NUM_THREADS':str(2*ncpus),
+ 'OMP_NUM_THREADS':str(2*ncpus),
+ }
+ def set_environment_vars():
+ for k, v in environment_vars.items():
+ os.environ[k] = v
+ self.client.run(set_environment_vars)
+
+ print("Cluster dashboard link: ", self.dashboard_link)
+
+ def __enter__(self): return self
+ def __exit__(self, exc_type, exc_value, traceback):
+ if not self.persist_config:
+ _remove_config_file(self.config_name)
+ self.client.close()
+ super().__exit__(exc_type, exc_value, traceback)
+
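
To make the workstation case concrete, a hedged sketch of using this cluster as a context manager; the worker count, core count, and memory limit below are example values, not recommendations.

```python
# Illustrative only: start a local dask cluster with example resource limits.
from cellpose.contrib.distributed_segmentation import myLocalCluster

with myLocalCluster(
    ncpus=4,                  # physical cores per worker (example value)
    n_workers=2,
    memory_limit="16GB",      # per-worker limit (example value)
    threads_per_worker=1,
) as cluster:
    print(cluster.client)     # the dashboard link is printed by the constructor
```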
+
+class janeliaLSFCluster(dask_jobqueue.LSFCluster):
+ """
+ This is a thin wrapper extending dask_jobqueue.LSFCluster,
+ which in turn extends dask.distributed.SpecCluster. This wrapper
+ sets configs before the cluster or workers are initialized. This is
+ an adaptive cluster and will scale the number of workers, between user
+ specified limits, based on the number of pending tasks. This wrapper
+ also enforces conventions specific to the Janelia LSF cluster.
+
+ For a full list of arguments see
+ https://jobqueue.dask.org/en/latest/generated/dask_jobqueue.LSFCluster.html
+
+ Most users will only need to specify:
+ ncpus (the number of cpu cores per worker)
+ min_workers
+ max_workers
+ """
+
+ def __init__(
+ self,
+ ncpus,
+ min_workers,
+ max_workers,
+ config={},
+ config_name=DEFAULT_CONFIG_FILENAME,
+ persist_config=False,
+ **kwargs
+ ):
+
+ # store all args in case needed later
+ self.locals_store = {**locals()}
+
+ # config
+ self.config_name = config_name
+ self.persist_config = persist_config
+ scratch_dir = f"/scratch/{getpass.getuser()}/"
+ config_defaults = {
+ 'temporary-directory':scratch_dir,
+ 'distributed.comm.timeouts.connect':'180s',
+ 'distributed.comm.timeouts.tcp':'360s',
+ }
+ config = {**config_defaults, **config}
+ _modify_dask_config(config, config_name)
+
+ # threading is best in low level libraries
+ job_script_prologue = [
+ f"export MKL_NUM_THREADS={2*ncpus}",
+ f"export NUM_MKL_THREADS={2*ncpus}",
+ f"export OPENBLAS_NUM_THREADS={2*ncpus}",
+ f"export OPENMP_NUM_THREADS={2*ncpus}",
+ f"export OMP_NUM_THREADS={2*ncpus}",
+ ]
+
+ # set scratch and log directories
+ if "local_directory" not in kwargs:
+ kwargs["local_directory"] = scratch_dir
+ if "log_directory" not in kwargs:
+ log_dir = f"{os.getcwd()}/dask_worker_logs_{os.getpid()}/"
+ pathlib.Path(log_dir).mkdir(parents=False, exist_ok=True)
+ kwargs["log_directory"] = log_dir
+
+ # graceful exit for lsf jobs (adds -d flag)
+ class quietLSFJob(dask_jobqueue.lsf.LSFJob):
+ cancel_command = "bkill -d"
+
+ # construct
+ super().__init__(
+ ncpus=ncpus,
+ processes=1,
+ cores=1,
+ memory=str(15*ncpus)+'GB',
+ mem=int(15e9*ncpus),
+ job_script_prologue=job_script_prologue,
+ job_cls=quietLSFJob,
+ **kwargs,
+ )
+ self.client = distributed.Client(self)
+ print("Cluster dashboard link: ", self.dashboard_link)
+
+ # set adaptive cluster bounds
+ self.adapt_cluster(min_workers, max_workers)
+
+
+ def __enter__(self): return self
+ def __exit__(self, exc_type, exc_value, traceback):
+ if not self.persist_config:
+ _remove_config_file(self.config_name)
+ self.client.close()
+ super().__exit__(exc_type, exc_value, traceback)
+
+
+ def adapt_cluster(self, min_workers, max_workers):
+ _ = self.adapt(
+ minimum_jobs=min_workers,
+ maximum_jobs=max_workers,
+ interval='10s',
+ wait_count=6,
+ )
+
+
+ def change_worker_attributes(
+ self,
+ min_workers,
+ max_workers,
+ **kwargs,
+ ):
+ """WARNING: this function is dangerous if you don't know what
+ you're doing. Don't call this unless you know exactly what
+ this does."""
+ self.scale(0)
+ for k, v in kwargs.items():
+ self.new_spec['options'][k] = v
+ self.adapt_cluster(min_workers, max_workers)
+
+
+#----------------------- decorator -------------------------------------------#
+def cluster(func):
+ """
+ This decorator ensures a function will run inside a cluster
+ as a context manager. The decorated function, "func", must
+ accept "cluster" and "cluster_kwargs" as parameters. If
+ "cluster" is not None then the user has provided an existing
+ cluster and we just run func. If "cluster" is None then
+ "cluster_kwargs" are used to construct a new cluster, and
+ the function is run inside that cluster context.
+ """
+ @functools.wraps(func)
+ def create_or_pass_cluster(*args, **kwargs):
+ # TODO: this only checks if args are explicitly present in function call
+ # it does not check if they are set correctly in any way
+ assert 'cluster' in kwargs or 'cluster_kwargs' in kwargs, \
+ "Either cluster or cluster_kwargs must be defined"
+ if not 'cluster' in kwargs:
+ cluster_constructor = myLocalCluster
+ F = lambda x: x in kwargs['cluster_kwargs']
+ if F('ncpus') and F('min_workers') and F('max_workers'):
+ cluster_constructor = janeliaLSFCluster
+ with cluster_constructor(**kwargs['cluster_kwargs']) as cluster:
+ kwargs['cluster'] = cluster
+ return func(*args, **kwargs)
+ return func(*args, **kwargs)
+ return create_or_pass_cluster
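
A hypothetical sketch of the calling convention this decorator enforces: the decorated function must accept `cluster` and `cluster_kwargs`, and the caller supplies one of the two. The function and resource values below are made up for illustration.

```python
# Hypothetical example of a function that satisfies the @cluster contract.
from cellpose.contrib.distributed_segmentation import cluster

@cluster
def crop_shapes(crops, cluster=None, cluster_kwargs={}):
    futures = cluster.client.map(lambda c: tuple(s.stop - s.start for s in c), crops)
    return cluster.client.gather(futures)

# Either pass an existing cluster object, or let the decorator build one:
# crop_shapes([(slice(0, 64), slice(0, 128))],
#             cluster_kwargs={"ncpus": 1, "n_workers": 1,
#                             "memory_limit": "4GB", "threads_per_worker": 1})
```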
+
+
+
+
+######################## the function to run on each block ####################
+
+#----------------------- The main function -----------------------------------#
+def process_block(
+ block_index,
+ crop,
+ input_zarr,
+ model_kwargs,
+ eval_kwargs,
+ blocksize,
+ overlap,
+ output_zarr,
+ preprocessing_steps=[],
+ worker_logs_directory=None,
+ test_mode=False,
+):
+ """
+ Preprocess and segment one block, of many, with eventual merger
+ of all blocks in mind. The block is processed as follows:
+
+ (1) Read block from disk, preprocess, and segment.
+ (2) Remove overlaps.
+ (3) Get bounding boxes for every segment.
+ (4) Remap segment IDs to globally unique values.
+ (5) Write segments to disk.
+ (6) Get segmented block faces.
+
+ A user may want to test this function on one block before running
+ the distributed function. When test_mode=True, steps (5) and (6)
+ are omitted and replaced with:
+
+ (5) return remapped segments as a numpy array, boxes, and box_ids
+
+ Parameters
+ ----------
+ block_index : tuple
+ The (i, j, k, ...) index of the block in the overall block grid
+
+ crop : tuple of slice objects
+ The bounding box of the data to read from the input_zarr array
+
+ input_zarr : zarr.core.Array
+ The image data we want to segment
+
+ preprocessing_steps : list of tuples (default: the empty list)
+ Optionally apply an arbitrary pipeline of preprocessing steps
+ to the image block before running cellpose.
+
+ Must be in the following format:
+ [(f, {'arg1':val1, ...}), ...]
+ That is, each tuple must contain only two elements, a function
+ and a dictionary. The function must have the following signature:
+ def F(image, ..., crop=None)
+ That is, the first argument must be a numpy array, which will later
+ be populated by the image data. The function must also take a keyword
+ argument called crop, even if it is not used in the function itself.
+ All other arguments to the function are passed using the dictionary.
+ Here is an example:
+
+ def F(image, sigma, crop=None):
+ return gaussian_filter(image, sigma)
+ def G(image, radius, crop=None):
+ return median_filter(image, radius)
+ preprocessing_steps = [(F, {'sigma':2.0}), (G, {'radius':4})]
+
+ model_kwargs : dict
+ Arguments passed to cellpose.models.Cellpose
+ This is how you select and parameterize a model.
+
+ eval_kwargs : dict
+ Arguments passed to the eval function of the Cellpose model
+ This is how you parameterize model evaluation.
+
+ blocksize : iterable (list, tuple, np.ndarray)
+ The number of voxels (the shape) of blocks without overlaps
+
+ overlap : int
+ The number of voxels added to the blocksize to provide context
+ at the edges
+
+ output_zarr : zarr.core.Array
+ A location where segments can be stored temporarily before
+ merger is complete
+
+ worker_logs_directory : string (default: None)
+ A directory path where log files for each worker can be created
+ The directory must exist
+
+ test_mode : bool (default: False)
+ The primary use case of this function is to be called by
+ distributed_eval (defined later in this same module). However
+ you may want to call this function manually to test what
+ happens to an individual block; this is a good idea before
+ ramping up to process big data and also useful for debugging.
+
+ When test_mode is False (default) this function stores
+ the segments and returns objects needed for merging between
+ blocks.
+
+ When test_mode is True this function does not store the
+ segments, and instead returns them to the caller as a numpy
+ array. The boxes and box IDs are also returned. When test_mode
+ is True, you can supply dummy values for many of the inputs,
+ such as:
+
+ block_index = (0, 0, 0)
+ output_zarr=None
+
+ Returns
+ -------
+ If test_mode == False (the default), three things are returned:
+ faces : a list of numpy arrays - the faces of the block segments
+ boxes : a list of crops (tuples of slices), bounding boxes of segments
+ box_ids : 1D numpy array, parallel to boxes, the segment IDs of the
+ boxes
+
+ If test_mode == True, three things are returned:
+ segments : np.ndarray containing the segments with globally unique IDs
+ boxes : a list of crops (tuples of slices), bounding boxes of segments
+ box_ids : 1D numpy array, parallel to boxes, the segment IDs of the
+ boxes
+ """
+ print('RUNNING BLOCK: ', block_index, '\tREGION: ', crop, flush=True)
+ segmentation = read_preprocess_and_segment(
+ input_zarr, crop, preprocessing_steps, model_kwargs, eval_kwargs,
+ worker_logs_directory,
+ )
+ segmentation, crop = remove_overlaps(
+ segmentation, crop, overlap, blocksize,
+ )
+ boxes = bounding_boxes_in_global_coordinates(segmentation, crop)
+ nblocks = get_nblocks(input_zarr.shape, blocksize)
+ segmentation, remap = global_segment_ids(segmentation, block_index, nblocks)
+ if remap[0] == 0: remap = remap[1:]
+
+ if test_mode: return segmentation, boxes, remap
+ output_zarr[tuple(crop)] = segmentation
+ faces = block_faces(segmentation)
+ return faces, boxes, remap
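
Following the suggestion above to try a single block before going distributed, here is a hedged sketch of such a test call; the zarr path, crop, model choice, and evaluation parameters are placeholders.

```python
# Illustrative only: run one block in test mode; nothing is written to disk.
import zarr
from cellpose.contrib.distributed_segmentation import process_block

input_zarr = zarr.open("image.zarr", mode="r")            # hypothetical input
crop = (slice(0, 256), slice(0, 256), slice(0, 256))      # first block, no offset
segments, boxes, box_ids = process_block(
    block_index=(0, 0, 0),
    crop=crop,
    input_zarr=input_zarr,
    model_kwargs={"gpu": False, "pretrained_model": "cyto3"},
    eval_kwargs={"diameter": 30, "do_3D": True},
    blocksize=(256, 256, 256),
    overlap=0,
    output_zarr=None,                                     # allowed in test mode
    test_mode=True,
)
print(segments.shape, len(boxes), box_ids[:5])
```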
+
+
+#----------------------- component functions ---------------------------------#
+def read_preprocess_and_segment(
+ input_zarr,
+ crop,
+ preprocessing_steps,
+ model_kwargs,
+ eval_kwargs,
+ worker_logs_directory,
+):
+ """Read block from zarr array, run all preprocessing steps, run cellpose"""
+ image = input_zarr[crop]
+ for pp_step in preprocessing_steps:
+ pp_step[1]['crop'] = crop
+ image = pp_step[0](image, **pp_step[1])
+ log_file=None
+ if worker_logs_directory is not None:
+ log_file = f'dask_worker_{distributed.get_worker().name}.log'
+ log_file = pathlib.Path(worker_logs_directory).joinpath(log_file)
+ cellpose.io.logger_setup(stdout_file_replacement=log_file)
+ model = cellpose.models.CellposeModel(**model_kwargs)
+ return model.eval(image, **eval_kwargs)[0].astype(np.uint32)
+
+
+def remove_overlaps(array, crop, overlap, blocksize):
+ """overlaps only there to provide context for boundary voxels
+ and can be removed after segmentation is complete
+ reslice array to remove the overlaps"""
+ crop_trimmed = list(crop)
+ for axis in range(array.ndim):
+ if crop[axis].start != 0:
+ slc = [slice(None),]*array.ndim
+ slc[axis] = slice(overlap, None)
+ array = array[tuple(slc)]
+ a, b = crop[axis].start, crop[axis].stop
+ crop_trimmed[axis] = slice(a + overlap, b)
+ if array.shape[axis] > blocksize[axis]:
+ slc = [slice(None),]*array.ndim
+ slc[axis] = slice(None, blocksize[axis])
+ array = array[tuple(slc)]
+ a = crop_trimmed[axis].start
+ crop_trimmed[axis] = slice(a, a + blocksize[axis])
+ return array, crop_trimmed
+
+
+def bounding_boxes_in_global_coordinates(segmentation, crop):
+ """bounding boxes (tuples of slices) are super useful later
+ best to compute them now while things are distributed"""
+ boxes = scipy.ndimage.find_objects(segmentation)
+ boxes = [b for b in boxes if b is not None]
+ translate = lambda a, b: slice(a.start+b.start, a.start+b.stop)
+ for iii, box in enumerate(boxes):
+ boxes[iii] = tuple(translate(a, b) for a, b in zip(crop, box))
+ return boxes
+
+
+def get_nblocks(shape, blocksize):
+ """Given a shape and blocksize determine the number of blocks per axis"""
+ return np.ceil(np.array(shape) / blocksize).astype(int)
+
+
+def global_segment_ids(segmentation, block_index, nblocks):
+ """pack the block index into the segment IDs so they are
+ globally unique. Everything gets remapped to [1..N] later.
+ A uint32 is split into 5 digits on left and 5 digits on right.
+ This creates limits: 42950 maximum number of blocks and
+ 99999 maximum number of segments per block"""
+ unique, unique_inverse = np.unique(segmentation, return_inverse=True)
+ p = str(np.ravel_multi_index(block_index, nblocks))
+ remap = [np.uint32(p+str(x).zfill(5)) for x in unique]
+ if unique[0] == 0: remap[0] = np.uint32(0) # 0 should just always be 0
+ segmentation = np.array(remap)[unique_inverse.reshape(segmentation.shape)]
+ return segmentation, remap
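
A small worked check of the packing scheme described above: block index (0, 1, 2) in a 2x3x4 grid ravels to 6, so local segment 37 becomes the globally unique ID 600037.

```python
# Worked example of the ID packing above (not part of the pipeline itself).
import numpy as np

block_index, nblocks = (0, 1, 2), (2, 3, 4)
prefix = str(np.ravel_multi_index(block_index, nblocks))     # '6'
global_id = np.uint32(prefix + str(37).zfill(5))             # '6' + '00037'
assert global_id == 600037
```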
+
+
+def block_faces(segmentation):
+ """slice faces along every axis"""
+ faces = []
+ for iii in range(segmentation.ndim):
+ a = [slice(None),] * segmentation.ndim
+ a[iii] = slice(0, 1)
+ faces.append(segmentation[tuple(a)])
+ a = [slice(None),] * segmentation.ndim
+ a[iii] = slice(-1, None)
+ faces.append(segmentation[tuple(a)])
+ return faces
+
+
+
+
+######################## Distributed Cellpose #################################
+
+#----------------------- The main function -----------------------------------#
+@cluster
+def distributed_eval(
+ input_zarr,
+ blocksize,
+ write_path,
+ mask=None,
+ preprocessing_steps=[],
+ model_kwargs={},
+ eval_kwargs={},
+ cluster=None,
+ cluster_kwargs={},
+ temporary_directory=None,
+):
+ """
+ Evaluate a cellpose model on overlapping blocks of a big image.
+ Distributed over workstation or cluster resources with Dask.
+ Optionally run preprocessing steps on the blocks before running cellpose.
+ Optionally use a mask to ignore background regions in image.
+ Either cluster or cluster_kwargs parameter must be set to a
+ non-default value; please read these parameter descriptions below.
+ If using cluster_kwargs, the workstation and Janelia LSF cluster cases
+ are distinguished by the arguments present in the dictionary.
+
+ PC/Mac/Linux workstations and the Janelia LSF cluster are supported;
+ running on a different institute cluster will require implementing your
+    own dask cluster class. Look at the janeliaLSFCluster class in this
+ module as an example, also look at the dask_jobqueue library. A PR with
+ a solid start is the right way to get help running this on your own
+ institute cluster.
+
+ If running on a workstation, please read the docstring for the
+    myLocalCluster class defined in this module. That will tell you what to
+ put in the cluster_kwargs dictionary. If using the Janelia cluster,
+    please read the docstring for the janeliaLSFCluster class.
+
+ Parameters
+ ----------
+ input_zarr : zarr.core.Array
+ A zarr.core.Array instance containing the image data you want to
+ segment.
+
+ blocksize : iterable
+ The size of blocks in voxels. E.g. [128, 256, 256]
+
+ write_path : string
+ The location of a zarr file on disk where you'd like to write your results
+
+ mask : numpy.ndarray (default: None)
+ A foreground mask for the image data; may be at a different resolution
+ (e.g. lower) than the image data. If given, only blocks that contain
+ foreground will be processed. This can save considerable time and
+ expense. It is assumed that the domain of the input_zarr image data
+ and the mask is the same in physical units, but they may be on
+ different sampling/voxel grids.
+
+ preprocessing_steps : list of tuples (default: the empty list)
+ Optionally apply an arbitrary pipeline of preprocessing steps
+ to the image blocks before running cellpose.
+
+ Must be in the following format:
+ [(f, {'arg1':val1, ...}), ...]
+ That is, each tuple must contain only two elements, a function
+ and a dictionary. The function must have the following signature:
+ def F(image, ..., crop=None)
+ That is, the first argument must be a numpy array, which will later
+ be populated by the image data. The function must also take a keyword
+ argument called crop, even if it is not used in the function itself.
+ All other arguments to the function are passed using the dictionary.
+ Here is an example:
+
+ def F(image, sigma, crop=None):
+ return gaussian_filter(image, sigma)
+ def G(image, radius, crop=None):
+ return median_filter(image, radius)
+ preprocessing_steps = [(F, {'sigma':2.0}), (G, {'radius':4})]
+
+ model_kwargs : dict (default: {})
+ Arguments passed to cellpose.models.Cellpose
+
+ eval_kwargs : dict (default: {})
+ Arguments passed to cellpose.models.Cellpose.eval
+
+ cluster : A dask cluster object (default: None)
+ Only set if you have constructed your own static cluster. The default
+ behavior is to construct a dask cluster for the duration of this function,
+ then close it when the function is finished.
+
+ cluster_kwargs : dict (default: {})
+ Arguments used to parameterize your cluster.
+ If you are running locally, see the docstring for the myLocalCluster
+ class in this module. If you are running on the Janelia LSF cluster, see
+ the docstring for the janeliaLSFCluster class in this module. If you are
+ running on a different institute cluster, you may need to implement
+ a dask cluster object that conforms to the requirements of your cluster.
+
+ temporary_directory : string (default: None)
+ Temporary files are created during segmentation. The temporary files
+ will be in their own folder within the temporary_directory. The default
+ is the current directory. Temporary files are removed if the function
+ completes successfully.
+
+ Returns
+ -------
+ Two values are returned:
+ (1) A reference to the zarr array on disk containing the stitched cellpose
+ segments for your entire image
+ (2) Bounding boxes for every segment. This is a list of tuples of slices:
+ [(slice(z1, z2), slice(y1, y2), slice(x1, x2)), ...]
+ The list is sorted according to segment ID. That is the smallest segment
+ ID is the first tuple in the list, the largest segment ID is the last
+ tuple in the list.
+ """
+
+ timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
+ worker_logs_dirname = f'dask_worker_logs_{timestamp}'
+ worker_logs_dir = pathlib.Path().absolute().joinpath(worker_logs_dirname)
+ worker_logs_dir.mkdir()
+
+ if 'diameter' not in eval_kwargs.keys():
+ eval_kwargs['diameter'] = 30
+ overlap = eval_kwargs['diameter'] * 2
+ block_indices, block_crops = get_block_crops(
+ input_zarr.shape, blocksize, overlap, mask,
+ )
+
+ # I hate indenting all that code just for the tempdir
+ # but context manager is the only way to really guarantee that
+ # the tempdir gets cleaned up even after unhandled exceptions
+ with tempfile.TemporaryDirectory(
+ prefix='.', suffix='_distributed_cellpose_tempdir',
+ dir=temporary_directory or os.getcwd(),
+ ) as temporary_directory:
+
+ temp_zarr_path = temporary_directory + '/segmentation_unstitched.zarr'
+ temp_zarr = zarr.open(
+ temp_zarr_path, 'w',
+ shape=input_zarr.shape,
+ chunks=blocksize,
+ dtype=np.uint32,
+ )
+
+ futures = cluster.client.map(
+ process_block,
+ block_indices,
+ block_crops,
+ input_zarr=input_zarr,
+ preprocessing_steps=preprocessing_steps,
+ model_kwargs=model_kwargs,
+ eval_kwargs=eval_kwargs,
+ blocksize=blocksize,
+ overlap=overlap,
+ output_zarr=temp_zarr,
+ worker_logs_directory=str(worker_logs_dir),
+ )
+ results = cluster.client.gather(futures)
+ if isinstance(cluster, dask_jobqueue.core.JobQueueCluster):
+ cluster.scale(0)
+
+ faces, boxes_, box_ids_ = list(zip(*results))
+ boxes = [box for sublist in boxes_ for box in sublist]
+ box_ids = np.concatenate(box_ids_).astype(int) # unsure how but without cast these are float64
+ new_labeling = determine_merge_relabeling(block_indices, faces, box_ids)
+ debug_unique = np.unique(new_labeling)
+ new_labeling_path = temporary_directory + '/new_labeling.npy'
+ np.save(new_labeling_path, new_labeling)
+
+ # stitching step is cheap, we should release gpus and use small workers
+ if isinstance(cluster, dask_jobqueue.core.JobQueueCluster):
+ cluster.change_worker_attributes(
+ min_workers=cluster.locals_store['min_workers'],
+ max_workers=cluster.locals_store['max_workers'],
+ ncpus=1,
+ memory="15GB",
+ mem=int(15e9),
+ queue=None,
+ job_extra_directives=[],
+ )
+
+ segmentation_da = dask.array.from_zarr(temp_zarr)
+ relabeled = dask.array.map_blocks(
+ lambda block: np.load(new_labeling_path)[block],
+ segmentation_da,
+ dtype=np.uint32,
+ chunks=segmentation_da.chunks,
+ )
+ dask.array.to_zarr(relabeled, write_path, overwrite=True)
+ merged_boxes = merge_all_boxes(boxes, new_labeling[box_ids])
+ return zarr.open(write_path, mode='r'), merged_boxes
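
Pulling the pieces together, a hedged end-to-end sketch for a workstation run; the input path, block size, model choice, and cluster resources are placeholders rather than tuned values.

```python
# Illustrative only: distributed segmentation of a large zarr image.
import zarr
from cellpose.contrib.distributed_segmentation import distributed_eval

image = zarr.open("image.zarr", mode="r")        # hypothetical chunked input
segments, boxes = distributed_eval(
    input_zarr=image,
    blocksize=(128, 256, 256),
    write_path="segments.zarr",
    model_kwargs={"gpu": True, "pretrained_model": "cyto3"},
    eval_kwargs={"diameter": 30, "do_3D": True},
    cluster_kwargs={"ncpus": 4, "n_workers": 1,
                    "memory_limit": "16GB", "threads_per_worker": 1},
)
print(segments.shape, len(boxes))
```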
+
+
+#----------------------- component functions ---------------------------------#
+def get_block_crops(shape, blocksize, overlap, mask):
+ """Given a voxel grid shape, blocksize, and overlap size, construct
+ tuples of slices for every block; optionally only include blocks
+ that contain foreground in the mask. Returns parallel lists,
+ the block indices and the slice tuples."""
+ blocksize = np.array(blocksize)
+ if mask is not None:
+ ratio = np.array(mask.shape) / shape
+ mask_blocksize = np.round(ratio * blocksize).astype(int)
+
+ indices, crops = [], []
+ nblocks = get_nblocks(shape, blocksize)
+ for index in np.ndindex(*nblocks):
+ start = blocksize * index - overlap
+ stop = start + blocksize + 2 * overlap
+ start = np.maximum(0, start)
+ stop = np.minimum(shape, stop)
+ crop = tuple(slice(x, y) for x, y in zip(start, stop))
+
+ foreground = True
+ if mask is not None:
+ start = mask_blocksize * index
+ stop = start + mask_blocksize
+ stop = np.minimum(mask.shape, stop)
+ mask_crop = tuple(slice(x, y) for x, y in zip(start, stop))
+ if not np.any(mask[mask_crop]): foreground = False
+ if foreground:
+ indices.append(index)
+ crops.append(crop)
+ return indices, crops
+
+
+def determine_merge_relabeling(block_indices, faces, used_labels):
+ """Determine boundary segment mergers, remap all label IDs to merge
+ and put all label IDs in range [1..N] for N global segments found"""
+ faces = adjacent_faces(block_indices, faces)
+    # ensure integer labels; box IDs can arrive as float64 after concatenation
+    used_labels = used_labels.astype(int)
+ label_range = int(np.max(used_labels))
+
+ label_groups = block_face_adjacency_graph(faces, label_range)
+ new_labeling = scipy.sparse.csgraph.connected_components(
+ label_groups, directed=False)[1]
+ # XXX: new_labeling is returned as int32. Loses half range. Potentially a problem.
+ unused_labels = np.ones(label_range + 1, dtype=bool)
+ unused_labels[used_labels] = 0
+ new_labeling[unused_labels] = 0
+ unique, unique_inverse = np.unique(new_labeling, return_inverse=True)
+ new_labeling = np.arange(len(unique), dtype=np.uint32)[unique_inverse]
+ return new_labeling
+
+
+def adjacent_faces(block_indices, faces):
+ """Find faces which touch and pair them together in new data structure"""
+ face_pairs = []
+ faces_index_lookup = {a:b for a, b in zip(block_indices, faces)}
+ for block_index in block_indices:
+ for ax in range(len(block_index)):
+ neighbor_index = np.array(block_index)
+ neighbor_index[ax] += 1
+ neighbor_index = tuple(neighbor_index)
+ try:
+ a = faces_index_lookup[block_index][2*ax + 1]
+ b = faces_index_lookup[neighbor_index][2*ax]
+ face_pairs.append( np.concatenate((a, b), axis=ax) )
+ except KeyError:
+ continue
+ return face_pairs
+
+
+def block_face_adjacency_graph(faces, nlabels):
+ """Shrink labels in face plane, then find which labels touch across the
+ face boundary"""
+    # ensure nlabels is an integer before building the sparse adjacency matrix
+    nlabels = int(nlabels)
+
+ all_mappings = []
+ structure = scipy.ndimage.generate_binary_structure(3, 1)
+ for face in faces:
+ sl0 = tuple(slice(0, 1) if d==2 else slice(None) for d in face.shape)
+ sl1 = tuple(slice(1, 2) if d==2 else slice(None) for d in face.shape)
+ a = shrink_labels(face[sl0], 1.0)
+ b = shrink_labels(face[sl1], 1.0)
+ face = np.concatenate((a, b), axis=np.argmin(a.shape))
+ mapped = dask_image.ndmeasure._utils._label._across_block_label_grouping(face, structure)
+ all_mappings.append(mapped)
+ i, j = np.concatenate(all_mappings, axis=1)
+ v = np.ones_like(i)
+ return scipy.sparse.coo_matrix((v, (i, j)), shape=(nlabels+1, nlabels+1)).tocsr()
+
+
+def shrink_labels(plane, threshold):
+ """Shrink labels in plane by some distance from their boundary"""
+ gradmag = np.linalg.norm(np.gradient(plane.squeeze()), axis=0)
+ shrunk_labels = np.copy(plane.squeeze())
+ shrunk_labels[gradmag > 0] = 0
+ distances = scipy.ndimage.distance_transform_edt(shrunk_labels)
+ shrunk_labels[distances <= threshold] = 0
+ return shrunk_labels.reshape(plane.shape)
+
+
+def merge_all_boxes(boxes, box_ids):
+ """Merge all boxes that map to the same box_ids"""
+ merged_boxes = []
+ boxes_array = np.array(boxes, dtype=object)
+    # ensure integer box IDs before grouping
+    box_ids = box_ids.astype(int)
+
+ for iii in np.unique(box_ids):
+ merge_indices = np.argwhere(box_ids == iii).squeeze()
+ if merge_indices.shape:
+ merged_box = merge_boxes(boxes_array[merge_indices])
+ else:
+ merged_box = boxes_array[merge_indices]
+ merged_boxes.append(merged_box)
+ return merged_boxes
+
+
+def merge_boxes(boxes):
+ """Take union of two or more parallelpipeds"""
+ box_union = boxes[0]
+ for iii in range(1, len(boxes)):
+ local_union = []
+ for s1, s2 in zip(box_union, boxes[iii]):
+ start = min(s1.start, s2.start)
+ stop = max(s1.stop, s2.stop)
+ local_union.append(slice(start, stop))
+ box_union = tuple(local_union)
+ return box_union
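
A tiny worked example of the union above, with arbitrary boxes:

```python
# Worked example (arbitrary values): union of two overlapping bounding boxes.
from cellpose.contrib.distributed_segmentation import merge_boxes

a = (slice(0, 10), slice(5, 20))
b = (slice(4, 16), slice(0, 12))
assert merge_boxes([a, b]) == (slice(0, 16), slice(0, 20))
```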
+
+
diff --git a/cellpose/core.py b/cellpose/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4297ee727588699b5075aee6631ce07bc942f5f
--- /dev/null
+++ b/cellpose/core.py
@@ -0,0 +1,331 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import os, sys, time, shutil, tempfile, datetime, pathlib, subprocess
+import logging
+import numpy as np
+from tqdm import trange, tqdm
+from urllib.parse import urlparse
+import tempfile
+import cv2
+from scipy.stats import mode
+import fastremap
+from . import transforms, dynamics, utils, plot, metrics, resnet_torch
+
+import torch
+from torch import nn
+from torch.utils import mkldnn as mkldnn_utils
+
+TORCH_ENABLED = True
+
+core_logger = logging.getLogger(__name__)
+tqdm_out = utils.TqdmToLogger(core_logger, level=logging.INFO)
+
+
+def use_gpu(gpu_number=0, use_torch=True):
+ """
+ Check if GPU is available for use.
+
+ Args:
+ gpu_number (int): The index of the GPU to be used. Default is 0.
+ use_torch (bool): Whether to use PyTorch for GPU check. Default is True.
+
+ Returns:
+ bool: True if GPU is available, False otherwise.
+
+ Raises:
+ ValueError: If use_torch is False, as cellpose only runs with PyTorch now.
+ """
+ if use_torch:
+ return _use_gpu_torch(gpu_number)
+ else:
+ raise ValueError("cellpose only runs with PyTorch now")
+
+
+def _use_gpu_torch(gpu_number=0):
+ """
+ Checks if CUDA or MPS is available and working with PyTorch.
+
+ Args:
+ gpu_number (int): The GPU device number to use (default is 0).
+
+ Returns:
+ bool: True if CUDA or MPS is available and working, False otherwise.
+ """
+ try:
+ device = torch.device("cuda:" + str(gpu_number))
+ _ = torch.zeros((1,1)).to(device)
+ core_logger.info("** TORCH CUDA version installed and working. **")
+ return True
+ except:
+ pass
+ try:
+ device = torch.device('mps:' + str(gpu_number))
+ _ = torch.zeros((1,1)).to(device)
+ core_logger.info('** TORCH MPS version installed and working. **')
+ return True
+ except:
+        core_logger.info('Neither TORCH CUDA nor MPS version installed/working.')
+ return False
+
+
+def assign_device(use_torch=True, gpu=False, device=0):
+ """
+ Assigns the device (CPU or GPU or mps) to be used for computation.
+
+ Args:
+ use_torch (bool, optional): Whether to use torch for GPU detection. Defaults to True.
+ gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
+ device (int or str, optional): The device index or name to be used. Defaults to 0.
+
+ Returns:
+ torch.device, bool (True if GPU is used, False otherwise)
+ """
+
+ if isinstance(device, str):
+ if device != "mps" or not(gpu and torch.backends.mps.is_available()):
+ device = int(device)
+ if gpu and use_gpu(use_torch=True):
+ try:
+ if torch.cuda.is_available():
+ device = torch.device(f'cuda:{device}')
+ core_logger.info(">>>> using GPU (CUDA)")
+ gpu = True
+ cpu = False
+ except:
+ gpu = False
+ cpu = True
+ try:
+ if torch.backends.mps.is_available():
+ device = torch.device('mps')
+ core_logger.info(">>>> using GPU (MPS)")
+ gpu = True
+ cpu = False
+ except:
+ gpu = False
+ cpu = True
+ else:
+ device = torch.device('cpu')
+ core_logger.info('>>>> using CPU')
+ gpu = False
+ cpu = True
+
+ if cpu:
+ device = torch.device("cpu")
+ core_logger.info(">>>> using CPU")
+ gpu = False
+ return device, gpu
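
A hedged example of calling this helper: the GPU flag only expresses a preference, and the function falls back to the CPU when no usable device is found.

```python
# Illustrative only: ask for a GPU if available, otherwise fall back to CPU.
from cellpose.core import assign_device

device, gpu = assign_device(use_torch=True, gpu=True, device=0)
print(device, gpu)   # e.g. cuda:0 and True, or cpu and False
```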
+
+
+def check_mkl(use_torch=True):
+ """
+ Checks if MKL-DNN is enabled and working.
+
+ Args:
+ use_torch (bool, optional): Whether to use torch. Defaults to True.
+
+ Returns:
+ bool: True if MKL-DNN is enabled, False otherwise.
+ """
+ mkl_enabled = torch.backends.mkldnn.is_available()
+ if mkl_enabled:
+ mkl_enabled = True
+ else:
+ core_logger.info(
+ "WARNING: MKL version on torch not working/installed - CPU version will be slightly slower."
+ )
+ core_logger.info(
+ "see https://pytorch.org/docs/stable/backends.html?highlight=mkl")
+ return mkl_enabled
+
+
+def _to_device(x, device):
+ """
+ Converts the input tensor or numpy array to the specified device.
+
+ Args:
+ x (torch.Tensor or numpy.ndarray): The input tensor or numpy array.
+ device (torch.device): The target device.
+
+ Returns:
+ torch.Tensor: The converted tensor on the specified device.
+ """
+ if not isinstance(x, torch.Tensor):
+ X = torch.from_numpy(x).to(device, dtype=torch.float32)
+ return X
+ else:
+ return x
+
+
+def _from_device(X):
+ """
+ Converts a PyTorch tensor from the device to a NumPy array on the CPU.
+
+ Args:
+ X (torch.Tensor): The input PyTorch tensor.
+
+ Returns:
+ numpy.ndarray: The converted NumPy array.
+ """
+ x = X.detach().cpu().numpy()
+ return x
+
+
+def _forward(net, x):
+ """Converts images to torch tensors, runs the network model, and returns numpy arrays.
+
+ Args:
+ net (torch.nn.Module): The network model.
+ x (numpy.ndarray): The input images.
+
+ Returns:
+ Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features.
+ """
+ X = _to_device(x, net.device)
+ net.eval()
+ if net.mkldnn:
+ net = mkldnn_utils.to_mkldnn(net)
+ with torch.no_grad():
+ y, style = net(X)[:2]
+ del X
+ y = _from_device(y)
+ style = _from_device(style)
+ return y, style
+
+
+def run_net(net, imgi, batch_size=8, augment=False, tile_overlap=0.1, bsize=224,
+ rsz=None):
+ """
+ Run network on stack of images.
+
+ (faster if augment is False)
+
+ Args:
+ net (class): cellpose network (model.net)
+ imgi (np.ndarray): The input image or stack of images of size [Lz x Ly x Lx x nchan].
+ batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
+        rsz (float, optional): Resize coefficient(s) for image. Defaults to None.
+ augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
+ tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
+ bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
+
+ Returns:
+ Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
+ y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
+ style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
+ """
+ # run network
+ nout = net.nout
+ Lz, Ly0, Lx0, nchan = imgi.shape
+ if rsz is not None:
+ if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
+ rsz = [rsz, rsz]
+ Lyr, Lxr = int(Ly0 * rsz[0]), int(Lx0 * rsz[1])
+ else:
+ Lyr, Lxr = Ly0, Lx0
+ ypad1, ypad2, xpad1, xpad2 = transforms.get_pad_yx(Lyr, Lxr)
+ pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])
+ Ly, Lx = Lyr + ypad1 + ypad2, Lxr + xpad1 + xpad2
+ if augment:
+ ny = max(2, int(np.ceil(2. * Ly / bsize)))
+ nx = max(2, int(np.ceil(2. * Lx / bsize)))
+ ly, lx = bsize, bsize
+ else:
+ ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
+ nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))
+ ly, lx = min(bsize, Ly), min(bsize, Lx)
+ yf = np.zeros((Lz, nout, Ly, Lx), "float32")
+ styles = np.zeros((Lz, 256), "float32")
+
+ # run multiple slices at the same time
+ ntiles = ny * nx
+ nimgs = max(1, batch_size // ntiles) # number of imgs to run in the same batch
+ niter = int(np.ceil(Lz / nimgs))
+ ziterator = (trange(niter, file=tqdm_out, mininterval=30)
+ if niter > 10 or Lz > 1 else range(niter))
+ for k in ziterator:
+ inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs))
+ IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")
+ for i, b in enumerate(inds):
+ # pad image for net so Ly and Lx are divisible by 4
+ imgb = transforms.resize_image(imgi[b], rsz=rsz) if rsz is not None else imgi[b].copy()
+ imgb = np.pad(imgb.transpose(2,0,1), pads, mode="constant")
+ IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(
+ imgb, bsize=bsize, augment=augment,
+ tile_overlap=tile_overlap)
+ IMGa[i * ntiles : (i+1) * ntiles] = np.reshape(IMG,
+ (ny * nx, nchan, ly, lx))
+
+ ya = np.zeros((IMGa.shape[0], nout, ly, lx), "float32")
+ stylea = np.zeros((IMGa.shape[0], 256), "float32")
+ for j in range(0, IMGa.shape[0], batch_size):
+ bslc = slice(j, min(j + batch_size, IMGa.shape[0]))
+ ya[bslc], stylea[bslc] = _forward(net, IMGa[bslc])
+ for i, b in enumerate(inds):
+ y = ya[i * ntiles : (i + 1) * ntiles]
+ if augment:
+ y = np.reshape(y, (ny, nx, 3, ly, lx))
+ y = transforms.unaugment_tiles(y)
+ y = np.reshape(y, (-1, 3, ly, lx))
+ yfi = transforms.average_tiles(y, ysub, xsub, Ly, Lx)
+ yf[b] = yfi[:, :imgb.shape[-2], :imgb.shape[-1]]
+ stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0)
+ stylei /= (stylei**2).sum()**0.5
+ styles[b] = stylei
+ # slices from padding
+ yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2]
+ yf = yf.transpose(0,2,3,1)
+ return yf, np.array(styles)
+
+
+def run_3D(net, imgs, batch_size=8, augment=False,
+ tile_overlap=0.1, bsize=224, net_ortho=None,
+ progress=None):
+ """
+ Run network on image z-stack.
+
+ (faster if augment is False)
+
+ Args:
+ imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan].
+ batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
+ augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
+ tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
+ bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
+ net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes. Defaults to None.
+ progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
+
+ Returns:
+ Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
+ y[...,0] is Z flow; y[...,1] is Y flow; y[...,2] is X flow; y[...,3] is cell probability.
+ style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
+ """
+ sstr = ["YX", "ZY", "ZX"]
+ pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]
+ ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]
+ cp = [(1, 2), (0, 2), (0, 1)]
+ cpy = [(0, 1), (0, 1), (0, 1)]
+ shape = imgs.shape[:-1]
+ #cellprob = np.zeros(shape, "float32")
+ yf = np.zeros((*shape, 4), "float32")
+ for p in range(3):
+ xsl = imgs.transpose(pm[p])
+ # per image
+ core_logger.info("running %s: %d planes of size (%d, %d)" %
+ (sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]]))
+ y, style = run_net(net if p==0 or net_ortho is None else net_ortho,
+ xsl, batch_size=batch_size, augment=augment,
+ bsize=bsize, tile_overlap=tile_overlap,
+ rsz=None)
+ yf[..., -1] += y[..., -1].transpose(ipm[p])
+ for j in range(2):
+ yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])
+ y = None; del y
+
+ if progress is not None:
+ progress.setValue(25 + 15 * p)
+
+ return yf, style
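+
+# Illustrative sketch (not part of cellpose): one way run_3D could be called
+# on a z-stack. `net` is assumed to be a loaded CPnet and `stack` a float32
+# array of shape (Lz, Ly, Lx, nchan); the helper below is never called.
+def _example_run_3D(net, stack):
+    # flows averaged over the YX, ZY and ZX passes, split back into components
+    yf, style = run_3D(net, stack, batch_size=8, augment=False, bsize=224)
+    dZ, dY, dX, cellprob = yf[..., 0], yf[..., 1], yf[..., 2], yf[..., 3]
+    return dZ, dY, dX, cellprob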
diff --git a/cellpose/denoise.py b/cellpose/denoise.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b3ef5d54bbca9be70a9b416151df17d407a0f5b
--- /dev/null
+++ b/cellpose/denoise.py
@@ -0,0 +1,1484 @@
+"""
+Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import os, time, datetime
+import numpy as np
+from scipy.stats import mode
+import cv2
+import torch
+from torch import nn
+from torch.nn.functional import conv2d, interpolate
+from tqdm import trange
+from pathlib import Path
+
+import logging
+
+denoise_logger = logging.getLogger(__name__)
+
+from cellpose import transforms, resnet_torch, utils, io
+from cellpose.core import run_net
+from cellpose.resnet_torch import CPnet
+from cellpose.models import CellposeModel, model_path, normalize_default, assign_device, check_mkl
+
+MODEL_NAMES = []
+for ctype in ["cyto3", "cyto2", "nuclei"]:
+ for ntype in ["denoise", "deblur", "upsample", "oneclick"]:
+ MODEL_NAMES.append(f"{ntype}_{ctype}")
+ if ctype != "cyto3":
+ for ltype in ["per", "seg", "rec"]:
+ MODEL_NAMES.append(f"{ntype}_{ltype}_{ctype}")
+ if ctype != "cyto3":
+ MODEL_NAMES.append(f"aniso_{ctype}")
+
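+# The loop above expands to names such as "denoise_cyto3", "deblur_cyto3",
+# "upsample_cyto3", "denoise_cyto2", "denoise_per_cyto2", "denoise_seg_cyto2",
+# "denoise_rec_cyto2", "aniso_cyto2" and "aniso_nuclei"; the per/seg/rec and
+# aniso variants exist only for the cyto2 and nuclei model families.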
+criterion = nn.MSELoss(reduction="mean")
+criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
+
+
+def deterministic(seed=0):
+ """ set random seeds to create test data """
+ import random
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
+ np.random.seed(seed) # Numpy module.
+ random.seed(seed) # Python random module.
+ torch.manual_seed(seed)
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+
+
+def loss_fn_rec(lbl, y):
+ """ loss function between true labels lbl and prediction y """
+ loss = 80. * criterion(y, lbl)
+ return loss
+
+
+def loss_fn_seg(lbl, y):
+ """ loss function between true labels lbl and prediction y """
+ veci = 5. * lbl[:, 1:]
+ lbl = (lbl[:, 0] > .5).float()
+ loss = criterion(y[:, :2], veci)
+ loss /= 2.
+ loss2 = criterion2(y[:, 2], lbl)
+ loss = loss + loss2
+ return loss
+
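+# Note on the label layout assumed by loss_fn_seg (matching its use below):
+# lbl[:, 0] is the cell-probability target and lbl[:, 1:] are the Y/X flow
+# targets, scaled by 5 before the MSE term; the predicted cell probability
+# y[:, 2] is trained with a BCE-with-logits term.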
+
+def get_sigma(Tdown):
+ """ Calculates the correlation matrices across channels for the perceptual loss.
+
+ Args:
+ Tdown (list): List of tensors output by each downsampling block of network.
+
+ Returns:
+ list: List of correlations for each input tensor.
+ """
+ Tnorm = [x - x.mean((-2, -1), keepdim=True) for x in Tdown]
+ Tnorm = [x / x.std((-2, -1), keepdim=True) for x in Tnorm]
+ Sigma = [
+ torch.einsum("bnxy, bmxy -> bnm", x, x) / (x.shape[-2] * x.shape[-1])
+ for x in Tnorm
+ ]
+ return Sigma
+
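+# For a block output of shape (batch, C_k, H_k, W_k), the corresponding
+# Sigma[k] above is a (batch, C_k, C_k) channel-correlation matrix; these
+# matrices are what the perceptual loss (loss_fn_per below) compares between
+# the clean image and the restored image.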
+
+def imstats(X, net1):
+ """
+ Calculates the image correlation matrices for the perceptual loss.
+
+ Args:
+ X (torch.Tensor): Input image tensor.
+ net1: Cellpose net.
+
+ Returns:
+ list: A list of tensors of correlation matrices.
+ """
+ _, _, Tdown = net1(X)
+ Sigma = get_sigma(Tdown)
+ Sigma = [x.detach() for x in Sigma]
+ return Sigma
+
+
+def loss_fn_per(img, net1, yl):
+ """
+ Calculates the perceptual loss function for image restoration.
+
+ Args:
+ img (torch.Tensor): Input image tensor (noisy/blurry/downsampled).
+ net1 (torch.nn.Module): Perceptual loss net (Cellpose segmentation net).
+ yl (torch.Tensor): Clean image tensor.
+
+ Returns:
+ torch.Tensor: Mean perceptual loss.
+ """
+ Sigma = imstats(img, net1)
+ sd = [x.std((1, 2)) + 1e-6 for x in Sigma]
+ Sigma_test = get_sigma(yl)
+ losses = torch.zeros(len(Sigma[0]), device=img.device)
+ for k in range(len(Sigma)):
+ losses = losses + (((Sigma_test[k] - Sigma[k])**2).mean((1, 2)) / sd[k]**2)
+ return losses.mean()
+
+
+def test_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
+ """
+ Calculates the test loss for image restoration tasks.
+
+ Args:
+ net0 (torch.nn.Module): The image restoration network.
+ X (torch.Tensor): The input image tensor.
+ net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
+ img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
+ lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
+ lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].
+
+ Returns:
+ tuple: A tuple containing the total loss and the perceptual loss.
+ """
+ net0.eval()
+ if net1 is not None:
+ net1.eval()
+ loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)
+
+ with torch.no_grad():
+ img_dn = net0(X)[0]
+ if lam[2] > 0.:
+ loss += lam[2] * loss_fn_rec(img, img_dn)
+ if lam[1] > 0. or lam[0] > 0.:
+ y, _, ydown = net1(img_dn)
+ if lam[1] > 0.:
+ loss += lam[1] * loss_fn_seg(lbl, y)
+ if lam[0] > 0.:
+ loss_per = loss_fn_per(img, net1, ydown)
+ loss += lam[0] * loss_per
+ return loss, loss_per
+
+
+def train_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
+ """
+ Calculates the train loss for image restoration tasks.
+
+ Args:
+ net0 (torch.nn.Module): The image restoration network.
+ X (torch.Tensor): The input image tensor.
+ net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
+ img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
+ lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
+ lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].
+
+ Returns:
+ tuple: A tuple containing the total loss and the perceptual loss.
+ """
+ net0.train()
+ if net1 is not None:
+ net1.eval()
+ loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)
+
+ img_dn = net0(X)[0]
+ if lam[2] > 0.:
+ loss += lam[2] * loss_fn_rec(img, img_dn)
+ if lam[1] > 0. or lam[0] > 0.:
+ y, _, ydown = net1(img_dn)
+ if lam[1] > 0.:
+ loss += lam[1] * loss_fn_seg(lbl, y)
+ if lam[0] > 0.:
+ loss_per = loss_fn_per(img, net1, ydown)
+ loss += lam[0] * loss_per
+ return loss, loss_per
+
+
+def img_norm(imgi):
+ """
+ Normalizes the input image by subtracting the 1st percentile and dividing by the difference between the 99th and 1st percentiles.
+
+ Args:
+ imgi (torch.Tensor): Input image tensor.
+
+ Returns:
+ torch.Tensor: Normalized image tensor.
+ """
+ shape = imgi.shape
+ imgi = imgi.reshape(imgi.shape[0], imgi.shape[1], -1)
+ perc = torch.quantile(imgi, torch.tensor([0.01, 0.99], device=imgi.device), dim=-1,
+ keepdim=True)
+ for k in range(imgi.shape[1]):
+ hask = (perc[1, :, k, 0] - perc[0, :, k, 0]) > 1e-3
+ imgi[hask, k] -= perc[0, hask, k]
+ imgi[hask, k] /= (perc[1, hask, k] - perc[0, hask, k])
+ imgi = imgi.reshape(shape)
+ return imgi
+
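+# Illustrative example (assumed shapes, not part of the library): img_norm
+# expects a (nimg, nchan, Ly, Lx) tensor and rescales each image/channel so
+# that its 1st percentile maps to roughly 0 and its 99th percentile to
+# roughly 1, e.g.
+#     x = img_norm(torch.rand(2, 1, 224, 224) * 500.)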
+
+def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7,
+ ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None,
+ ds=None, uniform_blur=False, partial_blur=False):
+ """Adds noise to the input image.
+
+ Args:
+ lbl (torch.Tensor): The input image tensor of shape (nimg, nchan, Ly, Lx).
+ alpha (float, optional): The shape parameter of the gamma distribution used for generating poisson noise. Defaults to 4.
+ beta (float, optional): The rate parameter of the gamma distribution used for generating poisson noise. Defaults to 0.7.
+ poisson (float, optional): The probability of adding poisson noise to the image. Defaults to 0.7.
+ blur (float, optional): The probability of adding gaussian blur to the image. Defaults to 0.7.
+ gblur (float, optional): The scale factor for the gaussian blur. Defaults to 1.0.
+ downsample (float, optional): The probability of downsampling the image. Defaults to 0.7.
+ ds_max (int, optional): The maximum downsampling factor. Defaults to 7.
+ diams (torch.Tensor, optional): The diameter of the objects in the image. Defaults to None.
+ pscale (torch.Tensor, optional): The scale factor for the poisson noise, instead of sampling. Defaults to None.
+ iso (bool, optional): Whether to use isotropic gaussian blur. Defaults to True.
+ sigma0 (torch.Tensor, optional): The standard deviation of the gaussian filter for the Y axis, instead of sampling. Defaults to None.
+ sigma1 (torch.Tensor, optional): The standard deviation of the gaussian filter for the X axis, instead of sampling. Defaults to None.
+        ds (torch.Tensor, optional): The downsampling factor for each image, instead of sampling. Defaults to None.
+        uniform_blur (bool, optional): Whether to sample the blur width uniformly (up to gblur) instead of from an exponential distribution. Defaults to False.
+        partial_blur (bool, optional): Whether to blur only part of the image, blending the blurred and original regions. Defaults to False.
+
+ Returns:
+ torch.Tensor: The noisy image tensor of the same shape as the input image.
+ """
+ device = lbl.device
+ imgi = torch.zeros_like(lbl)
+ Ly, Lx = lbl.shape[-2:]
+
+ diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device)
+ #ds0 = 1 if ds is None else ds.item()
+ ds = ds * torch.ones(
+ (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds
+
+ # downsample
+ ii = []
+ idownsample = np.random.rand(len(lbl)) < downsample
+ if (ds is None and idownsample.sum() > 0.) or not iso:
+ ds = torch.ones(len(lbl), dtype=torch.long, device=device)
+ ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),),
+ device=device)
+ ii = torch.nonzero(ds > 1).flatten()
+ elif ds is not None and (ds > 1).sum():
+ ii = torch.nonzero(ds > 1).flatten()
+
+ # add gaussian blur
+ iblur = torch.rand(len(lbl), device=device) < blur
+ iblur[ii] = True
+ if iblur.sum() > 0:
+ if sigma0 is None:
+ if uniform_blur and iso:
+ xr = torch.rand(len(lbl), device=device)
+ if len(ii) > 0:
+ xr[ii] = ds[ii].float() / 2. / gblur
+ sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur])
+ sigma1 = sigma0.clone()
+ elif not iso:
+ xr = torch.rand(len(lbl), device=device)
+ if len(ii) > 0:
+ xr[ii] = (ds[ii].float()) / gblur
+ xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35
+ xr[ii] = torch.clip(xr[ii], 0.05, 1.5)
+ sigma0 = diams[iblur] / 30. * gblur * xr[iblur]
+ sigma1 = sigma0.clone() / 10.
+ else:
+ xrand = np.random.exponential(1, size=iblur.sum())
+ xrand = np.clip(xrand * 0.5, 0.1, 1.0)
+ xrand *= gblur
+ sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to(
+ device)
+ sigma1 = sigma0.clone()
+ else:
+ sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device)
+ sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device)
+
+ # create gaussian filter
+ xr = max(8, sigma0.max().long() * 2)
+ gfilt0 = torch.exp(-torch.arange(-xr + 1, xr, device=device)**2 /
+ (2 * sigma0.unsqueeze(-1)**2))
+ gfilt0 /= gfilt0.sum(axis=-1, keepdims=True)
+ gfilt1 = torch.zeros_like(gfilt0)
+ gfilt1[sigma1 == sigma0] = gfilt0[sigma1 == sigma0]
+ gfilt1[sigma1 != sigma0] = torch.exp(
+ -torch.arange(-xr + 1, xr, device=device)**2 /
+ (2 * sigma1[sigma1 != sigma0].unsqueeze(-1)**2))
+ gfilt1[sigma1 == 0] = 0.
+ gfilt1[sigma1 == 0, xr] = 1.
+ gfilt1 /= gfilt1.sum(axis=-1, keepdims=True)
+ gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1)
+ gfilt /= gfilt.sum(axis=(1, 2), keepdims=True)
+
+ lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1),
+ padding=gfilt.shape[-1] // 2,
+ groups=gfilt.shape[0]).transpose(1, 0)
+ if partial_blur:
+ #yc, xc = np.random.randint(100, Ly-100), np.random.randint(100, Lx-100)
+ imgi[iblur] = lbl[iblur].clone()
+ Lxc = int(Lx * 0.85)
+ ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32),
+ torch.arange(0, Lxc, dtype=torch.float32),
+ indexing="ij")
+ mask = torch.exp(-(ym**2 + xm**2) / 2*(0.001**2))
+ mask -= mask.min()
+ mask /= mask.max()
+ lbl_blur_crop = lbl_blur[:, :, :, :Lxc]
+ imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask +
+ (1-mask) * imgi[iblur, :, :, :Lxc])
+ else:
+ imgi[iblur] = lbl_blur
+
+ imgi[~iblur] = lbl[~iblur]
+
+ # apply downsample
+ for k in ii:
+ i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]]
+ imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear")
+
+ # add poisson noise
+ ipoisson = np.random.rand(len(lbl)) < poisson
+ if ipoisson.sum() > 0:
+ if pscale is None:
+ pscale = torch.zeros(len(lbl))
+ m = torch.distributions.gamma.Gamma(alpha, beta)
+ pscale = torch.clamp(m.rsample(sample_shape=(ipoisson.sum(),)), 1.)
+ #pscale = torch.clamp(20 * (torch.rand(size=(len(lbl),), device=lbl.device)), 1.5)
+ pscale = pscale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device)
+ else:
+ pscale = pscale * torch.ones((ipoisson.sum(), 1, 1, 1), device=device)
+ imgi[ipoisson] = torch.poisson(pscale * imgi[ipoisson])
+ imgi[~ipoisson] = imgi[~ipoisson]
+
+ # renormalize
+ imgi = img_norm(imgi)
+
+ return imgi
+
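+# Illustrative sketch (not part of the library): degrading a clean batch with
+# the noise simulator above. `clean` is assumed to be a non-negative float32
+# tensor of shape (nimg, nchan, Ly, Lx); the argument values are examples only
+# and the helper is never called.
+def _example_add_noise(clean):
+    # poisson noise only, no blur or downsampling
+    return add_noise(clean, poisson=1.0, blur=0.0, downsample=0.0)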
+
+def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7,
+ downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30,
+ ds_max=7, uniform_blur=False, iso=True, rotate=True,
+ device=torch.device("cuda"), xy=(224, 224),
+ nchan_noise=1, keep_raw=True):
+ """
+ Applies random rotation, resizing, and noise to the input data.
+
+ Args:
+ data (numpy.ndarray): The input data.
+ labels (numpy.ndarray, optional): The flow and cellprob labels associated with the data. Defaults to None.
+ diams (float, optional): The diameter of the objects. Defaults to None.
+ poisson (float, optional): The Poisson noise probability. Defaults to 0.7.
+ blur (float, optional): The blur probability. Defaults to 0.7.
+ downsample (float, optional): The downsample probability. Defaults to 0.0.
+ beta (float, optional): The beta value for the poisson noise distribution. Defaults to 0.7.
+ gblur (float, optional): The Gaussian blur level. Defaults to 1.0.
+ diam_mean (float, optional): The mean diameter. Defaults to 30.
+ ds_max (int, optional): The maximum downsample value. Defaults to 7.
+ iso (bool, optional): Whether to apply isotropic augmentation. Defaults to True.
+ rotate (bool, optional): Whether to apply rotation augmentation. Defaults to True.
+ device (torch.device, optional): The device to use. Defaults to torch.device("cuda").
+ xy (tuple, optional): The size of the output image. Defaults to (224, 224).
+ nchan_noise (int, optional): The number of channels to add noise to. Defaults to 1.
+ keep_raw (bool, optional): Whether to keep the raw image. Defaults to True.
+
+ Returns:
+ torch.Tensor: The augmented image and augmented noisy/blurry/downsampled version of image.
+ torch.Tensor: The augmented labels.
+ float: The scale factor applied to the image.
+ """
+    if device is None:
+        device = (torch.device("cuda") if torch.cuda.is_available()
+                  else torch.device("mps") if torch.backends.mps.is_available()
+                  else None)
+
+ diams = 30 if diams is None else diams
+ random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1))
+ random_rsc = diams / random_diam #/ random_diam
+ #rsc /= random_scale
+ xy0 = (340, 340)
+ nchan = data[0].shape[0]
+ data_new = np.zeros((len(data), (1 + keep_raw) * nchan, xy0[0], xy0[1]), "float32")
+ labels_new = np.zeros((len(data), 3, xy0[0], xy0[1]), "float32")
+ for i in range(
+ len(data)): #, (sc, img, lbl) in enumerate(zip(random_rsc, data, labels)):
+ sc = random_rsc[i]
+ img = data[i]
+ lbl = labels[i] if labels is not None else None
+ # create affine transform to resize
+ Ly, Lx = img.shape[-2:]
+ dxy = np.maximum(0, np.array([Lx / sc - xy0[1], Ly / sc - xy0[0]]))
+ dxy = (np.random.rand(2,) - .5) * dxy
+ cc = np.array([Lx / 2, Ly / 2])
+ cc1 = cc - np.array([Lx - xy0[1], Ly - xy0[0]]) / 2 + dxy
+ pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
+ pts2 = np.float32(
+ [cc1, cc1 + np.array([1, 0]) / sc, cc1 + np.array([0, 1]) / sc])
+ M = cv2.getAffineTransform(pts1, pts2)
+
+ # apply to image
+ for c in range(nchan):
+ img_rsz = cv2.warpAffine(img[c], M, xy0, flags=cv2.INTER_LINEAR)
+ #img_noise = add_noise(torch.from_numpy(img_rsz).to(device).unsqueeze(0)).cpu().numpy().squeeze(0)
+ data_new[i, c] = img_rsz
+ if keep_raw:
+ data_new[i, c + nchan] = img_rsz
+
+ if lbl is not None:
+ # apply to labels
+ labels_new[i, 0] = cv2.warpAffine(lbl[0], M, xy0, flags=cv2.INTER_NEAREST)
+ labels_new[i, 1] = cv2.warpAffine(lbl[1], M, xy0, flags=cv2.INTER_LINEAR)
+ labels_new[i, 2] = cv2.warpAffine(lbl[2], M, xy0, flags=cv2.INTER_LINEAR)
+
+ rsc = random_diam / diam_mean
+
+ # add noise before augmentations
+ img = torch.from_numpy(data_new).to(device)
+ img = torch.clamp(img, 0.)
+ # just add noise to cyto if nchan_noise=1
+ img[:, :nchan_noise] = add_noise(
+ img[:, :nchan_noise], poisson=poisson, blur=blur, ds_max=ds_max, iso=iso,
+ downsample=downsample, beta=beta, gblur=gblur,
+ diams=torch.from_numpy(random_diam).to(device).float())
+ # img -= img.mean(dim=(-2,-1), keepdim=True)
+ # img /= img.std(dim=(-2,-1), keepdim=True) + 1e-3
+ img = img.cpu().numpy()
+
+ # augmentations
+ img, lbl, scale = transforms.random_rotate_and_resize(
+ img,
+ Y=labels_new,
+ xy=xy,
+ rotate=False if not iso else rotate,
+ #(iso and downsample==0),
+ rescale=rsc,
+ scale_range=0.5)
+ img = torch.from_numpy(img).to(device)
+ lbl = torch.from_numpy(lbl).to(device)
+
+ return img, lbl, scale
+
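+# Illustrative note (not part of the library): random_rotate_and_resize_noise
+# first crops/rescales each image to 340x340 around a randomly sampled
+# diameter, degrades the first `nchan_noise` channels with add_noise (keeping
+# a clean copy in the extra channels when keep_raw=True), and then calls
+# transforms.random_rotate_and_resize to produce the final xy-sized crops.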
+
+def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
+ """
+ Creates a Cellpose network with a single input channel.
+
+ Args:
+ device (str): The device to run the network on.
+ model_type (str, optional): The type of Cellpose model to use. Defaults to "cyto2".
+ pretrained_model (str, optional): The path to a pretrained model file. Defaults to None.
+
+ Returns:
+ torch.nn.Module: The Cellpose network with a single input channel.
+ """
+ if pretrained_model is not None and not os.path.exists(pretrained_model):
+ model_type = pretrained_model
+ pretrained_model = None
+ nbase = [32, 64, 128, 256]
+ nchan = 1
+ net1 = resnet_torch.CPnet([nchan, *nbase], nout=3, sz=3).to(device)
+ filename = model_path(model_type,
+ 0) if pretrained_model is None else pretrained_model
+ weights = torch.load(filename, weights_only=True)
+ zp = 0
+ print(filename)
+ for name in net1.state_dict():
+ if ("res_down_0.conv.conv_0" not in name and
+ #"output" not in name and
+ "res_down_0.proj" not in name and name != "diam_mean" and
+ name != "diam_labels"):
+ net1.state_dict()[name].copy_(weights[name])
+ elif "res_down_0" in name:
+ if len(weights[name].shape) > 0:
+ new_weight = torch.zeros_like(net1.state_dict()[name])
+ if weights[name].shape[0] == 2:
+ new_weight[:] = weights[name][0]
+ elif len(weights[name].shape) > 1 and weights[name].shape[1] == 2:
+ new_weight[:, zp] = weights[name][:, 0]
+ else:
+ new_weight = weights[name]
+ else:
+ new_weight = weights[name]
+ net1.state_dict()[name].copy_(new_weight)
+ return net1
+
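+# Note: the weight-copying loop above adapts a pretrained two-channel Cellpose
+# network to a single input channel by reusing the first-channel weights of
+# the initial "res_down_0" block and copying all remaining layers unchanged.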
+
+class CellposeDenoiseModel():
+ """ model to run Cellpose and Image restoration """
+
+ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
+ restore_type="denoise_cyto3", nchan=2,
+ chan2_restore=False, device=None):
+
+ self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore,
+ device=device)
+ self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan,
+ pretrained_model=pretrained_model, device=device)
+
+ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
+ normalize=True, rescale=None, diameter=None, tile_overlap=0.1,
+ augment=False, resample=True, invert=False, flow_threshold=0.4,
+ cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
+ min_size=15, niter=None, interp=True, bsize=224, flow3D_smooth=0):
+ """
+ Restore array or list of images using the image restoration model, and then segment.
+
+ Args:
+            x (list, np.ndarray): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
+ batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
+ (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
+ channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
+ First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
+ Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
+ For instance, to segment grayscale images, input [0,0]. To segment images with cells
+ in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
+ image with cells in green and nuclei in blue, input [[0,0], [2,3]].
+ Defaults to None.
+ channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
+ if None, channels dimension is attempted to be automatically determined. Defaults to None.
+ z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
+ if None, z dimension is attempted to be automatically determined. Defaults to None.
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
+ can also pass dictionary of parameters (all keys are optional, default values shown):
+ - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
+ - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
+ - "normalize"=True ; run normalization (if False, all following parameters ignored)
+ - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
+ - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
+ - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
+ Defaults to True.
+ rescale (float, optional): resize factor for each image, if None, set to 1.0;
+ (only used if diameter is None). Defaults to None.
+ diameter (float, optional): diameter for each image,
+ if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
+ tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
+ augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
+ resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
+ invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
+ flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
+ cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
+ do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False.
+ anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
+ stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
+ min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
+ flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0.
+ niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None.
+            interp (bool, optional): interpolate during 2D dynamics (not available in 3D). Defaults to True.
+
+ Returns:
+ A tuple containing (masks, flows, styles, imgs); masks: labelled image(s), where 0=no masks; 1,2,...=mask labels;
+ flows: list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration;
+ styles: style vector summarizing each image of size 256;
+ imgs: Restored images.
+ """
+
+ if isinstance(normalize, dict):
+ normalize_params = {**normalize_default, **normalize}
+ elif not isinstance(normalize, bool):
+ raise ValueError("normalize parameter must be a bool or a dict")
+ else:
+ normalize_params = normalize_default
+ normalize_params["normalize"] = normalize
+ normalize_params["invert"] = invert
+
+ img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels,
+ channel_axis=channel_axis, z_axis=z_axis,
+ do_3D=do_3D,
+ normalize=normalize_params, rescale=rescale,
+ diameter=diameter,
+ tile_overlap=tile_overlap, bsize=bsize)
+
+ # turn off special normalization for segmentation
+ normalize_params = normalize_default
+
+ # change channels for segmentation
+ if channels is not None:
+ channels_new = [0, 0] if channels[0] == 0 else [1, 2]
+ else:
+ channels_new = None
+ # change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean)
+ diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
+ masks, flows, styles = self.cp.eval(
+ img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
+ z_axis=0 if not isinstance(img_restore, list) and img_restore.ndim > 3 and img_restore.shape[0] > 0 else None,
+ normalize=normalize_params, rescale=rescale, diameter=diameter,
+ tile_overlap=tile_overlap, augment=augment, resample=resample,
+ invert=invert, flow_threshold=flow_threshold,
+ cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy,
+ stitch_threshold=stitch_threshold, min_size=min_size, niter=niter,
+ interp=interp, bsize=bsize)
+
+ return masks, flows, styles, img_restore
+
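+# Illustrative sketch (not part of the library): restoring and then segmenting
+# a single grayscale image with the combined model above. The model/restore
+# type strings follow MODEL_NAMES, the file name is hypothetical, and the
+# helper is never called.
+def _example_restore_and_segment():
+    import tifffile
+    img = tifffile.imread("example.tif")  # hypothetical 2D grayscale image
+    model = CellposeDenoiseModel(gpu=False, model_type="cyto3",
+                                 restore_type="denoise_cyto3")
+    masks, flows, styles, restored = model.eval(img, channels=[0, 0],
+                                                diameter=30.)
+    return masks, restored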
+
+class DenoiseModel():
+ """
+ DenoiseModel class for denoising images using Cellpose denoising model.
+
+ Args:
+ gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
+ pretrained_model (bool or str or Path, optional): Pretrained model to use for denoising.
+ Can be a string or path. Defaults to False.
+        nchan (int, optional): Number of channels in the input images; all Cellpose 3 models were trained with nchan=1. Defaults to 1.
+ model_type (str, optional): Type of pretrained model to use ("denoise_cyto3", "deblur_cyto3", "upsample_cyto3", ...). Defaults to None.
+ chan2 (bool, optional): Whether to use a separate model for the second channel. Defaults to False.
+ diam_mean (float, optional): Mean diameter of the objects in the images. Defaults to 30.0.
+ device (torch.device, optional): Device to use for computation. Defaults to None.
+
+ Attributes:
+ nchan (int): Number of channels in the input images.
+ diam_mean (float): Mean diameter of the objects in the images.
+ net (CPnet): Cellpose network for denoising.
+ pretrained_model (bool or str or Path): Pretrained model path to use for denoising.
+ net_chan2 (CPnet or None): Cellpose network for the second channel, if applicable.
+ net_type (str): Type of the denoising network.
+
+ Methods:
+ eval(x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
+ normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1)
+ Denoise array or list of images using the denoising model.
+
+ _eval(net, x, normalize=True, rescale=None, diameter=None, tile=True,
+ tile_overlap=0.1)
+ Run denoising model on a single channel.
+ """
+
+ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
+ chan2=False, diam_mean=30., device=None):
+ self.nchan = nchan
+ if pretrained_model and (not isinstance(pretrained_model, str) and
+ not isinstance(pretrained_model, Path)):
+ raise ValueError("pretrained_model must be a string or path")
+
+ self.diam_mean = diam_mean
+ builtin = True
+ if model_type is not None or (pretrained_model and
+ not os.path.exists(pretrained_model)):
+ pretrained_model_string = model_type if model_type is not None else "denoise_cyto3"
+ if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]):
+ pretrained_model_string = "denoise_cyto3"
+ pretrained_model = model_path(pretrained_model_string)
+ if (pretrained_model and not os.path.exists(pretrained_model)):
+ denoise_logger.warning("pretrained model has incorrect path")
+ denoise_logger.info(f">> {pretrained_model_string} << model set to be used")
+ self.diam_mean = 17. if "nuclei" in pretrained_model_string else 30.
+ else:
+ if pretrained_model:
+ builtin = False
+ pretrained_model_string = pretrained_model
+ denoise_logger.info(f">>>> loading model {pretrained_model_string}")
+
+ # assign network device
+ self.mkldnn = None
+ if device is None:
+ sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
+ self.device = device if device is not None else sdevice
+ if device is not None:
+ device_gpu = self.device.type == "cuda"
+ self.gpu = gpu if device is None else device_gpu
+ if not self.gpu:
+ self.mkldnn = check_mkl(True)
+
+ # create network
+ self.nchan = nchan
+ self.nclasses = 1
+ nbase = [32, 64, 128, 256]
+ self.nchan = nchan
+ self.nbase = [nchan, *nbase]
+
+ self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn,
+ max_pool=True, diam_mean=diam_mean).to(self.device)
+
+ self.pretrained_model = pretrained_model
+ self.net_chan2 = None
+ if self.pretrained_model:
+ self.net.load_model(self.pretrained_model, device=self.device)
+ denoise_logger.info(
+ f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)"
+ )
+ if chan2 and builtin:
+ chan2_path = model_path(
+ os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei")
+ print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
+ self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
+ mkldnn=self.mkldnn, max_pool=True,
+ diam_mean=17.).to(self.device)
+ self.net_chan2.load_model(chan2_path, device=self.device)
+ self.net_type = "cellpose_denoise"
+
+ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
+ normalize=True, rescale=None, diameter=None, tile=True, do_3D=False,
+ tile_overlap=0.1, bsize=224):
+ """
+ Restore array or list of images using the image restoration model.
+
+ Args:
+            x (list, np.ndarray): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
+ batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
+ (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
+ channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
+ First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
+ Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
+ For instance, to segment grayscale images, input [0,0]. To segment images with cells
+ in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
+ image with cells in green and nuclei in blue, input [[0,0], [2,3]].
+ Defaults to None.
+ channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
+ if None, channels dimension is attempted to be automatically determined. Defaults to None.
+ z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
+ if None, z dimension is attempted to be automatically determined. Defaults to None.
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
+ can also pass dictionary of parameters (all keys are optional, default values shown):
+ - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
+ - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
+ - "normalize"=True ; run normalization (if False, all following parameters ignored)
+ - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
+ - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
+ - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
+ Defaults to True.
+ rescale (float, optional): resize factor for each image, if None, set to 1.0;
+ (only used if diameter is None). Defaults to None.
+ diameter (float, optional): diameter for each image,
+ if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
+ tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
+
+ Returns:
+ list: A list of 2D/3D arrays of restored images
+
+ """
+ if isinstance(x, list) or x.squeeze().ndim == 5:
+ tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO)
+ nimg = len(x)
+ iterator = trange(nimg, file=tqdm_out,
+ mininterval=30) if nimg > 1 else range(nimg)
+ imgs = []
+ for i in iterator:
+ imgi = self.eval(
+ x[i], batch_size=batch_size,
+ channels=channels[i] if channels is not None and
+ ((len(channels) == len(x) and
+ (isinstance(channels[i], list) or
+ isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2))
+ else channels, channel_axis=channel_axis, z_axis=z_axis,
+ normalize=normalize,
+ do_3D=do_3D,
+ rescale=rescale[i] if isinstance(rescale, list) or
+ isinstance(rescale, np.ndarray) else rescale,
+ diameter=diameter[i] if isinstance(diameter, list) or
+ isinstance(diameter, np.ndarray) else diameter,
+ tile_overlap=tile_overlap, bsize=bsize)
+ imgs.append(imgi)
+ if isinstance(x, np.ndarray):
+ imgs = np.array(imgs)
+ return imgs
+
+ else:
+ # reshape image
+ x = transforms.convert_image(x, channels, channel_axis=channel_axis,
+ z_axis=z_axis, do_3D=do_3D, nchan=None)
+ if x.ndim < 4:
+ squeeze = True
+ x = x[np.newaxis, ...]
+ else:
+ squeeze = False
+
+ # may need to interpolate image before running upsampling
+ self.ratio = 1.
+ if "upsample" in self.pretrained_model:
+ Ly, Lx = x.shape[-3:-1]
+ if diameter is not None and 3 <= diameter < self.diam_mean:
+ self.ratio = self.diam_mean / diameter
+ denoise_logger.info(
+ f"upsampling image to {self.diam_mean} pixel diameter ({self.ratio:0.2f} times)"
+ )
+ Lyr, Lxr = int(Ly * self.ratio), int(Lx * self.ratio)
+ x = transforms.resize_image(x, Ly=Lyr, Lx=Lxr)
+ else:
+ denoise_logger.warning(
+ f"not interpolating image before upsampling because diameter is set >= {self.diam_mean}"
+ )
+ #raise ValueError(f"diameter is set to {diameter}, needs to be >=3 and < {self.dn.diam_mean}")
+
+ self.batch_size = batch_size
+
+ if diameter is not None and diameter > 0:
+ rescale = self.diam_mean / diameter
+ elif rescale is None:
+ rescale = 1.0
+
+ if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0):
+ x = x[..., :1]
+
+ for c in range(x.shape[-1]):
+ rescale0 = rescale * 30. / 17. if c == 1 else rescale
+ if c == 0 or self.net_chan2 is None:
+ x[...,
+ c] = self._eval(self.net, x[..., c:c + 1], batch_size=batch_size,
+ normalize=normalize, rescale=rescale0,
+ tile_overlap=tile_overlap, bsize=bsize)[...,0]
+ else:
+ x[...,
+ c] = self._eval(self.net_chan2, x[...,
+ c:c + 1], batch_size=batch_size,
+ normalize=normalize, rescale=rescale0,
+ tile_overlap=tile_overlap, bsize=bsize)[...,0]
+ x = x[0] if squeeze else x
+ return x
+
+ def _eval(self, net, x, batch_size=8, normalize=True, rescale=None,
+ tile_overlap=0.1, bsize=224):
+ """
+ Run image restoration model on a single channel.
+
+ Args:
+            x (list, np.ndarray): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
+ batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
+ (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
+ can also pass dictionary of parameters (all keys are optional, default values shown):
+ - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
+ - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
+ - "normalize"=True ; run normalization (if False, all following parameters ignored)
+ - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
+ - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
+ - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
+ Defaults to True.
+ rescale (float, optional): resize factor for each image, if None, set to 1.0;
+ (only used if diameter is None). Defaults to None.
+ tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
+
+ Returns:
+ list: A list of 2D/3D arrays of restored images
+
+ """
+ if isinstance(normalize, dict):
+ normalize_params = {**normalize_default, **normalize}
+ elif not isinstance(normalize, bool):
+ raise ValueError("normalize parameter must be a bool or a dict")
+ else:
+ normalize_params = normalize_default
+ normalize_params["normalize"] = normalize
+
+ tic = time.time()
+ shape = x.shape
+ nimg = shape[0]
+
+ do_normalization = True if normalize_params["normalize"] else False
+
+ img = np.asarray(x)
+ if do_normalization:
+ img = transforms.normalize_img(img, **normalize_params)
+ if rescale != 1.0:
+ img = transforms.resize_image(img, rsz=rescale)
+ yf, style = run_net(self.net, img, bsize=bsize,
+ tile_overlap=tile_overlap)
+ yf = transforms.resize_image(yf, shape[1], shape[2])
+ imgs = yf
+ del yf, style
+
+ # imgs = np.zeros((*x.shape[:-1], 1), np.float32)
+ # for i in iterator:
+ # img = np.asarray(x[i])
+ # if do_normalization:
+ # img = transforms.normalize_img(img, **normalize_params)
+ # if rescale != 1.0:
+ # img = transforms.resize_image(img, rsz=[rescale, rescale])
+ # if img.ndim == 2:
+ # img = img[:, :, np.newaxis]
+ # yf, style = run_net(net, img, batch_size=batch_size, augment=False,
+ # tile=tile, tile_overlap=tile_overlap, bsize=bsize)
+ # img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2])
+
+ # if img.ndim == 2:
+ # img = img[:, :, np.newaxis]
+ # imgs[i] = img
+ # del yf, style
+ net_time = time.time() - tic
+ if nimg > 1:
+ denoise_logger.info("imgs denoised in %2.2fs" % (net_time))
+
+ return imgs
+
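+# Illustrative sketch (not part of the library): running the restoration model
+# on its own on a noisy 2D image. The file name and diameter are hypothetical
+# and the helper is never called.
+def _example_denoise_only():
+    import tifffile
+    img = tifffile.imread("noisy.tif")  # hypothetical 2D grayscale image
+    dn = DenoiseModel(gpu=False, model_type="denoise_cyto3")
+    return dn.eval(img, channels=[0, 0], diameter=30.)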
+
+def train(net, train_data=None, train_labels=None, train_files=None, test_data=None,
+ test_labels=None, test_files=None, train_probs=None, test_probs=None,
+ lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None,
+ save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0,
+ iso=True, uniform_blur=False, downsample=0., ds_max=7,
+ learning_rate=0.005, n_epochs=500,
+ weight_decay=0.00001, batch_size=8, nimg_per_epoch=None,
+ nimg_test_per_epoch=None, model_name=None):
+
+ # net properties
+ device = net.device
+ nchan = net.nchan
+ diam_mean = net.diam_mean.item()
+
+ args = np.array([poisson, beta, blur, gblur, downsample])
+ if args.ndim == 1:
+ args = args[:, np.newaxis]
+ poisson, beta, blur, gblur, downsample = args
+ nnoise = len(poisson)
+
+ d = datetime.datetime.now()
+ if save_path is not None:
+ if model_name is None:
+ filename = ""
+ lstrs = ["per", "seg", "rec"]
+ for k, (l, s) in enumerate(zip(lam, lstrs)):
+ filename += f"{s}_{l:.2f}_"
+ if not iso:
+ filename += "aniso_"
+ if poisson.sum() > 0:
+ filename += "poisson_"
+ if blur.sum() > 0:
+ filename += "blur_"
+ if downsample.sum() > 0:
+ filename += "downsample_"
+ filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f")
+ filename = os.path.join(save_path, filename)
+ else:
+ filename = os.path.join(save_path, model_name)
+ print(filename)
+ for i in range(len(poisson)):
+ denoise_logger.info(
+ f"poisson: {poisson[i]: 0.2f}, beta: {beta[i]: 0.2f}, blur: {blur[i]: 0.2f}, gblur: {gblur[i]: 0.2f}, downsample: {downsample[i]: 0.2f}"
+ )
+ net1 = one_chan_cellpose(device=device, pretrained_model=seg_model_type)
+
+ learning_rate_const = learning_rate
+ LR = np.linspace(0, learning_rate_const, 10)
+ LR = np.append(LR, learning_rate_const * np.ones(n_epochs - 100))
+ for i in range(10):
+ LR = np.append(LR, LR[-1] / 2 * np.ones(10))
+ learning_rate = LR
+
+ batch_size = 8
+ optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate[0],
+ weight_decay=weight_decay)
+ if train_data is not None:
+ nimg = len(train_data)
+ diam_train = np.array(
+ [utils.diameters(train_labels[k])[0] for k in trange(len(train_labels))])
+ diam_train[diam_train < 5] = 5.
+ if test_data is not None:
+ diam_test = np.array(
+ [utils.diameters(test_labels[k])[0] for k in trange(len(test_labels))])
+ diam_test[diam_test < 5] = 5.
+ nimg_test = len(test_data)
+ else:
+ nimg = len(train_files)
+ denoise_logger.info(">>> using files instead of loading dataset")
+ train_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in train_files]
+ denoise_logger.info(">>> computing diameters")
+ diam_train = np.array([
+ utils.diameters(io.imread(train_labels_files[k])[0])[0]
+ for k in trange(len(train_labels_files))
+ ])
+ diam_train[diam_train < 5] = 5.
+ if test_files is not None:
+ nimg_test = len(test_files)
+ test_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in test_files]
+ diam_test = np.array([
+ utils.diameters(io.imread(test_labels_files[k])[0])[0]
+ for k in trange(len(test_labels_files))
+ ])
+ diam_test[diam_test < 5] = 5.
+ train_probs = 1. / nimg * np.ones(nimg,
+ "float64") if train_probs is None else train_probs
+ if test_files is not None or test_data is not None:
+ test_probs = 1. / nimg_test * np.ones(
+ nimg_test, "float64") if test_probs is None else test_probs
+
+ tic = time.time()
+
+ nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
+ if test_files is not None or test_data is not None:
+ nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch
+
+ nbatch = 0
+ train_losses, test_losses = [], []
+ for iepoch in range(n_epochs):
+ np.random.seed(iepoch)
+ rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
+ p=train_probs)
+ torch.manual_seed(iepoch)
+ np.random.seed(iepoch)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = learning_rate[iepoch]
+ lavg, lavg_per, nsum = 0, 0, 0
+ for ibatch in range(0, nimg_per_epoch, batch_size * nnoise):
+ inds = rperm[ibatch : ibatch + batch_size * nnoise]
+ if train_data is None:
+ imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds]
+ lbls = [io.imread(train_labels_files[i])[1:] for i in inds]
+ else:
+ imgs = [train_data[i][:nchan] for i in inds]
+ lbls = [train_labels[i][1:] for i in inds]
+ #inoise = nbatch % nnoise
+ rnoise = np.random.permutation(nnoise)
+ for i, inoise in enumerate(rnoise):
+ if i * batch_size < len(imgs):
+ imgi, lbli, scale = random_rotate_and_resize_noise(
+ imgs[i * batch_size : (i + 1) * batch_size],
+ lbls[i * batch_size : (i + 1) * batch_size],
+ diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(),
+ poisson=poisson[inoise],
+ beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso,
+ downsample=downsample[inoise], uniform_blur=uniform_blur,
+ diam_mean=diam_mean, ds_max=ds_max,
+ device=device)
+ if i == 0:
+ img = imgi
+ lbl = lbli
+ else:
+ img = torch.cat((img, imgi), axis=0)
+ lbl = torch.cat((lbl, lbli), axis=0)
+
+ if nnoise > 0:
+ iperm = np.random.permutation(img.shape[0])
+ img, lbl = img[iperm], lbl[iperm]
+
+ for i in range(nnoise):
+ optimizer.zero_grad()
+ imgi = img[i * batch_size: (i + 1) * batch_size]
+ lbli = lbl[i * batch_size: (i + 1) * batch_size]
+ if imgi.shape[0] > 0:
+ loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1,
+ img=imgi[:, nchan:], lbl=lbli, lam=lam)
+ loss.backward()
+ optimizer.step()
+ lavg += loss.item() * imgi.shape[0]
+ lavg_per += loss_per.item() * imgi.shape[0]
+
+ nsum += len(img)
+ nbatch += 1
+
+ if iepoch % 5 == 0 or iepoch < 10:
+ lavg = lavg / nsum
+ lavg_per = lavg_per / nsum
+ if test_data is not None or test_files is not None:
+ lavgt, nsum = 0., 0
+ np.random.seed(42)
+ rperm = np.random.choice(np.arange(0, nimg_test),
+ size=(nimg_test_per_epoch,), p=test_probs)
+ inoise = iepoch % nnoise
+ torch.manual_seed(inoise)
+ for ibatch in range(0, nimg_test_per_epoch, batch_size):
+ inds = rperm[ibatch:ibatch + batch_size]
+ if test_data is None:
+ imgs = [
+ np.maximum(0,
+ io.imread(test_files[i])[:nchan]) for i in inds
+ ]
+ lbls = [io.imread(test_labels_files[i])[1:] for i in inds]
+ else:
+ imgs = [test_data[i][:nchan] for i in inds]
+ lbls = [test_labels[i][1:] for i in inds]
+ img, lbl, scale = random_rotate_and_resize_noise(
+ imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise],
+ beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise],
+ iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur,
+ diam_mean=diam_mean, ds_max=ds_max, device=device)
+ loss, loss_per = test_loss(net, img[:, :nchan], net1=net1,
+ img=img[:, nchan:], lbl=lbl, lam=lam)
+
+ lavgt += loss.item() * img.shape[0]
+ nsum += len(img)
+ lavgt = lavgt / nsum
+ denoise_logger.info(
+ "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f"
+ % (iepoch, time.time() - tic, lavg, lavg_per, lavgt,
+ learning_rate[iepoch]))
+ test_losses.append(lavgt)
+ else:
+ denoise_logger.info(
+ "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" %
+ (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch]))
+ train_losses.append(lavg)
+
+ if save_path is not None:
+ if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
+ if save_each: #separate files as model progresses
+                    filename0 = str(filename) + f"_epoch_{iepoch:04d}"
+ else:
+ filename0 = filename
+ denoise_logger.info(f"saving network parameters to {filename0}")
+ net.save_model(filename0)
+ else:
+ filename = save_path
+
+ return filename, train_losses, test_losses
+
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser(description="cellpose parameters")
+
+ input_img_args = parser.add_argument_group("input image arguments")
+ input_img_args.add_argument("--dir", default=[], type=str,
+ help="folder containing data to run or train on.")
+ input_img_args.add_argument("--img_filter", default=[], type=str,
+ help="end string for images to run on")
+
+ model_args = parser.add_argument_group("model arguments")
+ model_args.add_argument("--pretrained_model", default=[], type=str,
+ help="pretrained denoising model")
+
+ training_args = parser.add_argument_group("training arguments")
+ training_args.add_argument("--test_dir", default=[], type=str,
+ help="folder containing test data (optional)")
+ training_args.add_argument("--file_list", default=[], type=str,
+ help="npy file containing list of train and test files")
+ training_args.add_argument("--seg_model_type", default="cyto2", type=str,
+ help="model to use for seg training loss")
+ training_args.add_argument(
+ "--noise_type", default=[], type=str,
+ help="noise type to use (if input, then other noise params are ignored)")
+ training_args.add_argument("--poisson", default=0.8, type=float,
+ help="fraction of images to add poisson noise to")
+ training_args.add_argument("--beta", default=0.7, type=float,
+ help="scale of poisson noise")
+ training_args.add_argument("--blur", default=0., type=float,
+ help="fraction of images to blur")
+ training_args.add_argument("--gblur", default=1.0, type=float,
+ help="scale of gaussian blurring stddev")
+ training_args.add_argument("--downsample", default=0., type=float,
+ help="fraction of images to downsample")
+ training_args.add_argument("--ds_max", default=7, type=int,
+ help="max downsampling factor")
+ training_args.add_argument("--lam_per", default=1.0, type=float,
+ help="weighting of perceptual loss")
+ training_args.add_argument("--lam_seg", default=1.5, type=float,
+ help="weighting of segmentation loss")
+ training_args.add_argument("--lam_rec", default=0., type=float,
+ help="weighting of reconstruction loss")
+ training_args.add_argument(
+ "--diam_mean", default=30., type=float, help=
+ "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0"
+ )
+ training_args.add_argument("--learning_rate", default=0.001, type=float,
+ help="learning rate. Default: %(default)s")
+ training_args.add_argument("--n_epochs", default=2000, type=int,
+ help="number of epochs. Default: %(default)s")
+ training_args.add_argument(
+ "--save_each", default=False, action="store_true",
+ help="save each epoch as separate model")
+ training_args.add_argument(
+ "--nimg_per_epoch", default=0, type=int,
+ help="number of images per epoch. Default is length of training images")
+ training_args.add_argument(
+ "--nimg_test_per_epoch", default=0, type=int,
+ help="number of test images per epoch. Default is length of testing images")
+
+ io.logger_setup()
+
+ args = parser.parse_args()
+ lams = [args.lam_per, args.lam_seg, args.lam_rec]
+ print("lam", lams)
+
+ if len(args.noise_type) > 0:
+ noise_type = args.noise_type
+ uniform_blur = False
+ iso = True
+ if noise_type == "poisson":
+ poisson = 0.8
+ blur = 0.
+ downsample = 0.
+ beta = 0.7
+ gblur = 1.0
+ elif noise_type == "blur_expr":
+ poisson = 0.8
+ blur = 0.8
+ downsample = 0.
+ beta = 0.1
+ gblur = 0.5
+ elif noise_type == "blur":
+ poisson = 0.8
+ blur = 0.8
+ downsample = 0.
+ beta = 0.1
+ gblur = 10.0
+ uniform_blur = True
+ elif noise_type == "downsample_expr":
+ poisson = 0.8
+ blur = 0.8
+ downsample = 0.8
+ beta = 0.03
+ gblur = 1.0
+ elif noise_type == "downsample":
+ poisson = 0.8
+ blur = 0.8
+ downsample = 0.8
+ beta = 0.03
+ gblur = 5.0
+ uniform_blur = True
+ elif noise_type == "all":
+ poisson = [0.8, 0.8, 0.8]
+ blur = [0., 0.8, 0.8]
+ downsample = [0., 0., 0.8]
+ beta = [0.7, 0.1, 0.03]
+ gblur = [0., 10.0, 5.0]
+ uniform_blur = True
+ elif noise_type == "aniso":
+ poisson = 0.8
+ blur = 0.8
+ downsample = 0.8
+ beta = 0.1
+ gblur = args.ds_max * 1.5
+ iso = False
+ else:
+ raise ValueError(f"{noise_type} noise_type is not supported")
+ else:
+ poisson, beta = args.poisson, args.beta
+ blur, gblur = args.blur, args.gblur
+ downsample = args.downsample
+
+ pretrained_model = None if len(
+ args.pretrained_model) == 0 else args.pretrained_model
+ model = DenoiseModel(gpu=True, nchan=1, diam_mean=args.diam_mean,
+ pretrained_model=pretrained_model)
+
+ train_data, labels, train_files, train_probs = None, None, None, None
+ test_data, test_labels, test_files, test_probs = None, None, None, None
+ if len(args.file_list) == 0:
+ output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0)
+ images, labels, image_names, test_images, test_labels, image_names_test = output
+ train_data = []
+ for i in range(len(images)):
+ img = images[i].astype("float32")
+ if img.ndim > 2:
+ img = img[0]
+ train_data.append(
+ np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
+ if len(args.test_dir) > 0:
+ test_data = []
+ for i in range(len(test_images)):
+ img = test_images[i].astype("float32")
+ if img.ndim > 2:
+ img = img[0]
+ test_data.append(
+ np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
+ save_path = os.path.join(args.dir, "../models/")
+ else:
+ root = args.dir
+ denoise_logger.info(
+ ">>> using file_list (assumes images are normalized and have flows!)")
+ dat = np.load(args.file_list, allow_pickle=True).item()
+ train_files = dat["train_files"]
+ test_files = dat["test_files"]
+ train_probs = dat["train_probs"] if "train_probs" in dat else None
+ test_probs = dat["test_probs"] if "test_probs" in dat else None
+ if str(train_files[0])[:len(str(root))] != str(root):
+ for i in range(len(train_files)):
+ new_path = root / Path(*train_files[i].parts[-3:])
+ if i == 0:
+ print(f"changing path from {train_files[i]} to {new_path}")
+ train_files[i] = new_path
+
+ for i in range(len(test_files)):
+ new_path = root / Path(*test_files[i].parts[-3:])
+ test_files[i] = new_path
+ save_path = os.path.join(args.dir, "models/")
+
+ os.makedirs(save_path, exist_ok=True)
+
+ nimg_per_epoch = None if args.nimg_per_epoch == 0 else args.nimg_per_epoch
+ nimg_test_per_epoch = None if args.nimg_test_per_epoch == 0 else args.nimg_test_per_epoch
+
+ model_path = train(
+ model.net, train_data=train_data, train_labels=labels, train_files=train_files,
+ test_data=test_data, test_labels=test_labels, test_files=test_files,
+ train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta,
+ blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max,
+ iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs,
+ learning_rate=args.learning_rate,
+ lam=lams,
+ seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch,
+ nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path)
+
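+# Example invocation of the training entry point above (illustrative only;
+# the paths are hypothetical):
+#   python -m cellpose.denoise --dir /data/train --test_dir /data/test \
+#       --noise_type poisson --n_epochs 300 --learning_rate 0.001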
+
+def seg_train_noisy(model, train_data, train_labels, test_data=None, test_labels=None,
+ poisson=0.8, blur=0.0, downsample=0.0, save_path=None,
+ save_every=100, save_each=False, learning_rate=0.2, n_epochs=500,
+ momentum=0.9, weight_decay=0.00001, SGD=True, batch_size=8,
+ nimg_per_epoch=None, diameter=None, rescale=True, z_masking=False,
+ model_name=None):
+ """ train function uses loss function model.loss_fn in models.py
+
+ (data should already be normalized)
+
+ """
+
+ d = datetime.datetime.now()
+
+ model.n_epochs = n_epochs
+ if isinstance(learning_rate, (list, np.ndarray)):
+ if isinstance(learning_rate, np.ndarray) and learning_rate.ndim > 1:
+ raise ValueError("learning_rate.ndim must equal 1")
+ elif len(learning_rate) != n_epochs:
+ raise ValueError(
+ "if learning_rate given as list or np.ndarray it must have length n_epochs"
+ )
+ model.learning_rate = learning_rate
+ model.learning_rate_const = mode(learning_rate)[0][0]
+ else:
+ model.learning_rate_const = learning_rate
+ # set learning rate schedule
+ if SGD:
+ LR = np.linspace(0, model.learning_rate_const, 10)
+ if model.n_epochs > 250:
+ LR = np.append(
+ LR, model.learning_rate_const * np.ones(model.n_epochs - 100))
+ for i in range(10):
+ LR = np.append(LR, LR[-1] / 2 * np.ones(10))
+ else:
+ LR = np.append(
+ LR,
+ model.learning_rate_const * np.ones(max(0, model.n_epochs - 10)))
+ else:
+ LR = model.learning_rate_const * np.ones(model.n_epochs)
+ model.learning_rate = LR
+
+ model.batch_size = batch_size
+ model._set_optimizer(model.learning_rate[0], momentum, weight_decay, SGD)
+ model._set_criterion()
+
+ nimg = len(train_data)
+
+ # compute average cell diameter
+ if diameter is None:
+ diam_train = np.array(
+ [utils.diameters(train_labels[k][0])[0] for k in range(len(train_labels))])
+ diam_train_mean = diam_train[diam_train > 0].mean()
+ model.diam_labels = diam_train_mean
+ if rescale:
+ diam_train[diam_train < 5] = 5.
+ if test_data is not None:
+ diam_test = np.array([
+ utils.diameters(test_labels[k][0])[0]
+ for k in range(len(test_labels))
+ ])
+ diam_test[diam_test < 5] = 5.
+ denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
+ elif rescale:
+ diam_train_mean = diameter
+ model.diam_labels = diameter
+ denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
+ diam_train = diameter * np.ones(len(train_labels), "float32")
+ if test_data is not None:
+ diam_test = diameter * np.ones(len(test_labels), "float32")
+
+ denoise_logger.info(
+ f">>>> mean of training label mask diameters (saved to model) {diam_train_mean:.3f}"
+ )
+ model.net.diam_labels.data = torch.ones(1, device=model.device) * diam_train_mean
+
+ nchan = train_data[0].shape[0]
+ denoise_logger.info(">>>> training network with %d channel input <<<<" % nchan)
+ denoise_logger.info(">>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f" %
+ (model.learning_rate_const, model.batch_size, weight_decay))
+
+ if test_data is not None:
+ denoise_logger.info(f">>>> ntrain = {nimg}, ntest = {len(test_data)}")
+ else:
+ denoise_logger.info(f">>>> ntrain = {nimg}")
+
+ tic = time.time()
+
+ lavg, nsum = 0, 0
+
+ if save_path is not None:
+ _, file_label = os.path.split(save_path)
+ file_path = os.path.join(save_path, "models/")
+
+ if not os.path.exists(file_path):
+ os.makedirs(file_path)
+ else:
+ denoise_logger.warning("WARNING: no save_path given, model not saving")
+
+ ksave = 0
+
+ # cannot train with mkldnn
+ model.net.mkldnn = False
+
+ # get indices for each epoch for training
+ np.random.seed(0)
+ inds_all = np.zeros((0,), "int32")
+ if nimg_per_epoch is None or nimg > nimg_per_epoch:
+ nimg_per_epoch = nimg
+ denoise_logger.info(f">>>> nimg_per_epoch = {nimg_per_epoch}")
+ while len(inds_all) < n_epochs * nimg_per_epoch:
+ rperm = np.random.permutation(nimg)
+ inds_all = np.hstack((inds_all, rperm))
+
+ for iepoch in range(model.n_epochs):
+ if SGD:
+ model._set_learning_rate(model.learning_rate[iepoch])
+ np.random.seed(iepoch)
+ rperm = inds_all[iepoch * nimg_per_epoch:(iepoch + 1) * nimg_per_epoch]
+ for ibatch in range(0, nimg_per_epoch, batch_size):
+ inds = rperm[ibatch:ibatch + batch_size]
+ imgi, lbl, scale = random_rotate_and_resize_noise(
+ [train_data[i] for i in inds], [train_labels[i][1:] for i in inds],
+ poisson=poisson, blur=blur, downsample=downsample,
+ diams=diam_train[inds], diam_mean=model.diam_mean)
+ imgi = imgi[:, :1] # keep noisy only
+ if z_masking:
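+                # with ~75% probability each, zero out a random number of leading and/or
+                # trailing channels (z-planes) to mimic partially acquired stacks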
+ nc = imgi.shape[1]
+ nb = imgi.shape[0]
+ ncmin = (np.random.rand(nb) > 0.25) * (np.random.randint(
+ nc // 2 - 1, size=nb))
+ ncmax = nc - (np.random.rand(nb) > 0.25) * (np.random.randint(
+ nc // 2 - 1, size=nb))
+ for b in range(nb):
+ imgi[b, :ncmin[b]] = 0
+ imgi[b, ncmax[b]:] = 0
+
+ train_loss = model._train_step(imgi, lbl)
+ lavg += train_loss
+ nsum += len(imgi)
+
+ if iepoch % 10 == 0 or iepoch == 5:
+ lavg = lavg / nsum
+ if test_data is not None:
+ lavgt, nsum = 0., 0
+ np.random.seed(42)
+ rperm = np.arange(0, len(test_data), 1, int)
+ for ibatch in range(0, len(test_data), batch_size):
+ inds = rperm[ibatch:ibatch + batch_size]
+ imgi, lbl, scale = random_rotate_and_resize_noise(
+ [test_data[i] for i in inds],
+ [test_labels[i][1:] for i in inds], poisson=poisson, blur=blur,
+ downsample=downsample, diams=diam_test[inds],
+ diam_mean=model.diam_mean)
+ imgi = imgi[:, :1] # keep noisy only
+ test_loss = model._test_eval(imgi, lbl)
+ lavgt += test_loss
+ nsum += len(imgi)
+
+ denoise_logger.info(
+ "Epoch %d, Time %4.1fs, Loss %2.4f, Loss Test %2.4f, LR %2.4f" %
+ (iepoch, time.time() - tic, lavg, lavgt / nsum,
+ model.learning_rate[iepoch]))
+ else:
+ denoise_logger.info(
+ "Epoch %d, Time %4.1fs, Loss %2.4f, LR %2.4f" %
+ (iepoch, time.time() - tic, lavg, model.learning_rate[iepoch]))
+
+ lavg, nsum = 0, 0
+
+ if save_path is not None:
+ if iepoch == model.n_epochs - 1 or iepoch % save_every == 1:
+ # save model at the end
+ if save_each: #separate files as model progresses
+ if model_name is None:
+ filename = "{}_{}_{}_{}".format(
+ model.net_type, file_label,
+ d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch))
+ else:
+ filename = "{}_{}".format(model_name, "epoch_" + str(iepoch))
+ else:
+ if model_name is None:
+ filename = "{}_{}_{}".format(model.net_type, file_label,
+ d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
+ else:
+ filename = model_name
+ filename = os.path.join(file_path, filename)
+ ksave += 1
+ denoise_logger.info(f"saving network parameters to {filename}")
+ model.net.save_model(filename)
+ else:
+ filename = save_path
+
+ # reset to mkldnn if available
+ model.net.mkldnn = model.mkldnn
+ return filename
diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py
new file mode 100644
index 0000000000000000000000000000000000000000..db163ded6e76eac60ad8a2a563d7505083899f2b
--- /dev/null
+++ b/cellpose/dynamics.py
@@ -0,0 +1,1041 @@
+"""
+Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import time, os
+from scipy.ndimage import maximum_filter1d, find_objects, center_of_mass
+import torch
+import numpy as np
+import tifffile
+from tqdm import trange
+from numba import njit, prange, float32, int32, vectorize
+import cv2
+import fastremap
+
+import logging
+
+dynamics_logger = logging.getLogger(__name__)
+
+from . import utils, metrics, transforms
+
+import torch
+from torch import optim, nn
+import torch.nn.functional as F
+from . import resnet_torch
+
+@njit("(float64[:], int32[:], int32[:], int32, int32, int32, int32)", nogil=True)
+def _extend_centers(T, y, x, ymed, xmed, Lx, niter):
+ """Run diffusion from the center of the mask on the mask pixels.
+
+ Args:
+ T (numpy.ndarray): Array of shape (Ly * Lx) where diffusion is run.
+ y (numpy.ndarray): Array of y-coordinates of pixels inside the mask.
+ x (numpy.ndarray): Array of x-coordinates of pixels inside the mask.
+ ymed (int): Center of the mask in the y-coordinate.
+ xmed (int): Center of the mask in the x-coordinate.
+ Lx (int): Size of the x-dimension of the masks.
+ niter (int): Number of iterations to run diffusion.
+
+ Returns:
+ numpy.ndarray: Array of shape (Ly * Lx) representing the amount of diffused particles at each pixel.
+ """
+ for t in range(niter):
+ T[ymed * Lx + xmed] += 1
+ T[y * Lx +
+ x] = 1 / 9. * (T[y * Lx + x] + T[(y - 1) * Lx + x] + T[(y + 1) * Lx + x] +
+ T[y * Lx + x - 1] + T[y * Lx + x + 1] +
+ T[(y - 1) * Lx + x - 1] + T[(y - 1) * Lx + x + 1] +
+ T[(y + 1) * Lx + x - 1] + T[(y + 1) * Lx + x + 1])
+ return T
+
+
+def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
+ device=torch.device("cpu")):
+ """Runs diffusion on GPU to generate flows for training images or quality control.
+
+ Args:
+ neighbors (torch.Tensor): 9 x pixels in masks.
+ meds (torch.Tensor): Mask centers.
+ isneighbor (torch.Tensor): Valid neighbor boolean 9 x pixels.
+ shape (tuple): Shape of the tensor.
+ n_iter (int, optional): Number of iterations. Defaults to 200.
+ device (torch.device, optional): Device to run the computation on. Defaults to torch.device("cpu").
+
+ Returns:
+ torch.Tensor: Generated flows.
+
+ """
+ if torch.prod(torch.tensor(shape)) > 4e7 or device.type == "mps":
+ T = torch.zeros(shape, dtype=torch.float, device=device)
+ else:
+ T = torch.zeros(shape, dtype=torch.double, device=device)
+
+ for i in range(n_iter):
+ T[tuple(meds.T)] += 1
+ Tneigh = T[tuple(neighbors)]
+ Tneigh *= isneighbor
+ T[tuple(neighbors[:, 0])] = Tneigh.mean(axis=0)
+ del meds, isneighbor, Tneigh
+
+ if T.ndim == 2:
+ grads = T[neighbors[0, [2, 1, 4, 3]], neighbors[1, [2, 1, 4, 3]]]
+ del neighbors
+ dy = grads[0] - grads[1]
+ dx = grads[2] - grads[3]
+ del grads
+ mu_torch = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
+ else:
+ grads = T[tuple(neighbors[:, 1:])]
+ del neighbors
+ dz = grads[0] - grads[1]
+ dy = grads[2] - grads[3]
+ dx = grads[4] - grads[5]
+ del grads
+ mu_torch = np.stack(
+ (dz.cpu().squeeze(0), dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
+ return mu_torch
+
+@njit(nogil=True)
+def get_centers(masks, slices):
+ """
+ Get the centers of the masks and their extents.
+
+ Args:
+ masks (ndarray): The labeled masks.
+ slices (ndarray): The slices of the masks.
+
+ Returns:
+ A tuple containing the centers of the masks and the extents of the masks.
+ """
+ centers = np.zeros((len(slices), 2), "int32")
+ ext = np.zeros((len(slices),), "int32")
+ for p in prange(len(slices)):
+ si = slices[p]
+ i = si[0]
+ sr, sc = si[1:3], si[3:5]
+ # find center in slice around mask
+ yi, xi = np.nonzero(masks[sr[0]:sr[-1], sc[0]:sc[-1]] == (i + 1))
+ ymed = yi.mean()
+ xmed = xi.mean()
+ # center is closest point to (ymed, xmed) within mask
+ imin = ((xi - xmed)**2 + (yi - ymed)**2).argmin()
+ ymed = yi[imin] + sr[0]
+ xmed = xi[imin] + sc[0]
+ centers[p] = np.array([ymed, xmed])
+ ext[p] = (sr[-1] - sr[0]) + (sc[-1] - sc[0]) + 2
+ return centers, ext
+
+
+def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
+ """Convert masks to flows using diffusion from center pixel.
+
+ Center of masks where diffusion starts is defined by pixel closest to median within the mask.
+
+ Args:
+ masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
+ device (torch.device, optional): The device to run the computation on. Defaults to torch.device("cpu").
+ niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
+
+    Returns:
+        A tuple containing (mu, meds_p). mu is a float array of shape (2, Ly, Lx) with the
+        y- and x-flows; meds_p are the cell center coordinates.
+ """
+ if device is None:
+        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
+
+ Ly0, Lx0 = masks.shape
+ Ly, Lx = Ly0 + 2, Lx0 + 2
+
+ masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
+ masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
+ shape = masks_padded.shape
+
+ ### get mask pixel neighbors
+ y, x = torch.nonzero(masks_padded, as_tuple=True)
+ y = y.int()
+ x = x.int()
+ neighbors = torch.zeros((2, 9, y.shape[0]), dtype=torch.long, device=device)
+ yxi = [[0, -1, 1, 0, 0, -1, -1, 1, 1], [0, 0, 0, -1, 1, -1, 1, -1, 1]]
+ for i in range(9):
+ neighbors[0, i] = y + yxi[0][i]
+ neighbors[1, i] = x + yxi[1][i]
+ isneighbor = torch.ones((9, y.shape[0]), dtype=torch.bool, device=device)
+ m0 = masks_padded[neighbors[0, 0], neighbors[1, 0]]
+ for i in range(1, 9):
+ isneighbor[i] = masks_padded[neighbors[0, i], neighbors[1, i]] == m0
+ del m0, masks_padded
+
+ ### get center-of-mass within cell
+ slices = find_objects(masks)
+ # turn slices into array
+ slices = np.array([
+ np.array([i, si[0].start, si[0].stop, si[1].start, si[1].stop])
+ for i, si in enumerate(slices)
+ if si is not None
+ ])
+ centers, ext = get_centers(masks, slices)
+ meds_p = torch.from_numpy(centers).to(device).long()
+ meds_p += 1 # for padding
+
+ ### run diffusion
+ n_iter = 2 * ext.max() if niter is None else niter
+ mu = _extend_centers_gpu(neighbors, meds_p, isneighbor, shape, n_iter=n_iter,
+ device=device)
+ mu = mu.astype("float64")
+
+ # new normalization
+ mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)
+
+ # put into original image
+ mu0 = np.zeros((2, Ly0, Lx0))
+ mu0[:, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
+
+ return mu0, meds_p.cpu().numpy() - 1
+
+
+def masks_to_flows_gpu_3d(masks, device=None, niter=None):
+ """Convert masks to flows using diffusion from center pixel.
+
+ Args:
+ masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
+ device (torch.device, optional): The device to run the computation on. Defaults to None.
+ niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
+
+ Returns:
+        np.ndarray: A 4D array of shape (3, Lz, Ly, Lx) with the z-, y-, and x-flows for each pixel.
+
+ """
+ if device is None:
+        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
+
+ Lz0, Ly0, Lx0 = masks.shape
+ Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2
+
+ masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
+ masks_padded = F.pad(masks_padded, (1, 1, 1, 1, 1, 1))
+
+ # get mask pixel neighbors
+ z, y, x = torch.nonzero(masks_padded).T
+ neighborsZ = torch.stack((z, z + 1, z - 1, z, z, z, z))
+ neighborsY = torch.stack((y, y, y, y + 1, y - 1, y, y), axis=0)
+ neighborsX = torch.stack((x, x, x, x, x, x + 1, x - 1), axis=0)
+
+ neighbors = torch.stack((neighborsZ, neighborsY, neighborsX), axis=0)
+
+ # get mask centers
+ slices = find_objects(masks)
+
+ centers = np.zeros((masks.max(), 3), "int")
+ for i, si in enumerate(slices):
+ if si is not None:
+ sz, sy, sx = si
+ #lz, ly, lx = sr.stop - sr.start + 1, sc.stop - sc.start + 1
+ zi, yi, xi = np.nonzero(masks[sz, sy, sx] == (i + 1))
+ zi = zi.astype(np.int32) + 1 # add padding
+ yi = yi.astype(np.int32) + 1 # add padding
+ xi = xi.astype(np.int32) + 1 # add padding
+ zmed = np.mean(zi)
+ ymed = np.mean(yi)
+ xmed = np.mean(xi)
+ imin = np.argmin((zi - zmed)**2 + (xi - xmed)**2 + (yi - ymed)**2)
+ zmed = zi[imin]
+ ymed = yi[imin]
+ xmed = xi[imin]
+ centers[i, 0] = zmed + sz.start
+ centers[i, 1] = ymed + sy.start
+ centers[i, 2] = xmed + sx.start
+
+ # get neighbor validator (not all neighbors are in same mask)
+ neighbor_masks = masks_padded[tuple(neighbors)]
+ isneighbor = neighbor_masks == neighbor_masks[0]
+ ext = np.array(
+ [[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1]
+ for sz, sy, sx in slices])
+ n_iter = 6 * (ext.sum(axis=1)).max() if niter is None else niter
+
+ # run diffusion
+ shape = masks_padded.shape
+ mu = _extend_centers_gpu(neighbors, centers, isneighbor, shape, n_iter=n_iter,
+ device=device)
+ # normalize
+ mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)
+
+ # put into original image
+ mu0 = np.zeros((3, Lz0, Ly0, Lx0))
+ mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
+ return mu0
+
+
+def masks_to_flows_cpu(masks, niter=None, device=None):
+ """Convert masks to flows using diffusion from center pixel.
+
+ Center of masks where diffusion starts is defined to be the closest pixel to the mean of all pixels that is inside the mask.
+ Result of diffusion is converted into flows by computing the gradients of the diffusion density map.
+
+ Args:
+ masks (int, 2D or 3D array): Labelled masks 0=NO masks; 1,2,...=mask labels
+ niter (int, optional): Number of iterations for computing flows. Defaults to None.
+
+ Returns:
+        A tuple containing (mu, meds). mu is a float array of shape (2, Ly, Lx) with the
+        y- and x-flows; meds is the list of cell centers.
+ """
+ Ly, Lx = masks.shape
+ mu = np.zeros((2, Ly, Lx), np.float64)
+
+ slices = find_objects(masks)
+ meds = []
+ for i in prange(len(slices)):
+ si = slices[i]
+ if si is not None:
+ sr, sc = si
+ ly, lx = sr.stop - sr.start + 2, sc.stop - sc.start + 2
+ ### get center-of-mass within cell
+ y, x = np.nonzero(masks[sr, sc] == (i + 1))
+ y = y.astype(np.int32) + 1
+ x = x.astype(np.int32) + 1
+ ymed = y.mean()
+ xmed = x.mean()
+ imin = ((x - xmed)**2 + (y - ymed)**2).argmin()
+ xmed = x[imin]
+ ymed = y[imin]
+
+ n_iter = 2 * np.int32(ly + lx) if niter is None else niter
+ T = np.zeros((ly) * (lx), np.float64)
+ T = _extend_centers(T, y, x, ymed, xmed, np.int32(lx), np.int32(n_iter))
+ dy = T[(y + 1) * lx + x] - T[(y - 1) * lx + x]
+ dx = T[y * lx + x + 1] - T[y * lx + x - 1]
+ mu[:, sr.start + y - 1, sc.start + x - 1] = np.stack((dy, dx))
+ meds.append([ymed - 1, xmed - 1])
+
+ # new normalization
+ mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)
+
+ return mu, meds
+
+
+def masks_to_flows(masks, device=torch.device("cpu"), niter=None):
+ """Convert masks to flows using diffusion from center pixel.
+
+ Center of masks where diffusion starts is defined to be the closest pixel to the mean of all pixels that is inside the mask.
+ Result of diffusion is converted into flows by computing the gradients of the diffusion density map.
+
+ Args:
+ masks (int, 2D or 3D array): Labelled masks 0=NO masks; 1,2,...=mask labels
+
+ Returns:
+ np.ndarray: mu is float 3D or 4D array of flows in (Z)XY.
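+
+    Example (illustrative only; the mask below is a made-up placeholder):
+        masks = np.zeros((64, 64), "int32")
+        masks[10:30, 10:30] = 1
+        mu = masks_to_flows(masks, device=torch.device("cpu"))
+        # mu.shape == (2, 64, 64): unit-normalized y- and x-flows inside the mask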
+ """
+ if masks.max() == 0:
+ dynamics_logger.warning("empty masks!")
+ return np.zeros((2, *masks.shape), "float32")
+
+ if device.type == "cuda" or device.type == "mps":
+ masks_to_flows_device = masks_to_flows_gpu
+ else:
+ masks_to_flows_device = masks_to_flows_cpu
+
+ if masks.ndim == 3:
+ Lz, Ly, Lx = masks.shape
+ mu = np.zeros((3, Lz, Ly, Lx), np.float32)
+ for z in range(Lz):
+ mu0 = masks_to_flows_device(masks[z], device=device, niter=niter)[0]
+ mu[[1, 2], z] += mu0
+ for y in range(Ly):
+ mu0 = masks_to_flows_device(masks[:, y], device=device, niter=niter)[0]
+ mu[[0, 2], :, y] += mu0
+ for x in range(Lx):
+ mu0 = masks_to_flows_device(masks[:, :, x], device=device, niter=niter)[0]
+ mu[[0, 1], :, :, x] += mu0
+ return mu
+ elif masks.ndim == 2:
+ mu, mu_c = masks_to_flows_device(masks, device=device, niter=niter)
+ return mu
+
+ else:
+ raise ValueError("masks_to_flows only takes 2D or 3D arrays")
+
+
+def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None,
+ return_flows=True):
+ """Converts labels (list of masks or flows) to flows for training model.
+
+ Args:
+ labels (list of ND-arrays): The labels to convert. labels[k] can be 2D or 3D. If [3 x Ly x Lx],
+ it is assumed that flows were precomputed. Otherwise, labels[k][0] or labels[k] (if 2D)
+ is used to create flows and cell probabilities.
+ files (list of str, optional): The files to save the flows to. If provided, flows are saved to
+ files to be reused. Defaults to None.
+ device (str, optional): The device to use for computation. Defaults to None.
+ redo_flows (bool, optional): Whether to recompute the flows. Defaults to False.
+ niter (int, optional): The number of iterations for computing flows. Defaults to None.
+
+ Returns:
+        list of [4 x Ly x Lx] arrays: The flows for training the model. flows[k][0] is labels[k],
+        flows[k][1] is the binary cell mask (labels > 0.5), flows[k][2] is the Y flow, and
+        flows[k][3] is the X flow.
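+
+    Example (illustrative only; the label array below is a made-up placeholder):
+        labels = [np.zeros((64, 64), "int32")]
+        labels[0][5:25, 5:25] = 1
+        flows = labels_to_flows(labels, device=torch.device("cpu"))
+        # flows[0].shape == (4, 64, 64)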
+ """
+ nimg = len(labels)
+ if labels[0].ndim < 3:
+ labels = [labels[n][np.newaxis, :, :] for n in range(nimg)]
+
+ flows = []
+ # flows need to be recomputed
+ if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows:
+ dynamics_logger.info("computing flows for labels")
+
+ # compute flows; labels are fixed here to be unique, so they need to be passed back
+ # make sure labels are unique!
+ labels = [fastremap.renumber(label, in_place=True)[0] for label in labels]
+ iterator = trange if nimg > 1 else range
+ for n in iterator(nimg):
+ labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0]
+ vecn = masks_to_flows(labels[n][0].astype(int), device=device, niter=niter)
+
+ # concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations)
+ flow = np.concatenate((labels[n], labels[n] > 0.5, vecn),
+ axis=0).astype(np.float32)
+ if files is not None:
+ file_name = os.path.splitext(files[n])[0]
+ tifffile.imwrite((file_name + "_flows.tif").replace("/y/", "/flows/"), flow)
+ if return_flows:
+ flows.append(flow)
+ else:
+ dynamics_logger.info("flows precomputed")
+ if return_flows:
+ flows = [labels[n].astype(np.float32) for n in range(nimg)]
+ return flows
+
+
+@njit([
+ "(int16[:,:,:], float32[:], float32[:], float32[:,:])",
+ "(float32[:,:,:], float32[:], float32[:], float32[:,:])"
+], cache=True)
+def map_coordinates(I, yc, xc, Y):
+ """
+ Bilinear interpolation of image "I" in-place with y-coordinates yc and x-coordinates xc to Y.
+
+ Args:
+ I (numpy.ndarray): Input image of shape (C, Ly, Lx).
+ yc (numpy.ndarray): New y-coordinates.
+ xc (numpy.ndarray): New x-coordinates.
+ Y (numpy.ndarray): Output array of shape (C, ni).
+
+ Returns:
+ None
+ """
+ C, Ly, Lx = I.shape
+ yc_floor = yc.astype(np.int32)
+ xc_floor = xc.astype(np.int32)
+ yc = yc - yc_floor
+ xc = xc - xc_floor
+ for i in range(yc_floor.shape[0]):
+ yf = min(Ly - 1, max(0, yc_floor[i]))
+ xf = min(Lx - 1, max(0, xc_floor[i]))
+ yf1 = min(Ly - 1, yf + 1)
+ xf1 = min(Lx - 1, xf + 1)
+ y = yc[i]
+ x = xc[i]
+ for c in range(C):
+ Y[c, i] = (np.float32(I[c, yf, xf]) * (1 - y) * (1 - x) +
+ np.float32(I[c, yf, xf1]) * (1 - y) * x +
+ np.float32(I[c, yf1, xf]) * y * (1 - x) +
+ np.float32(I[c, yf1, xf1]) * y * x)
+
+
+def steps_interp(dP, inds, niter, device=torch.device("cpu")):
+ """ Run dynamics of pixels to recover masks in 2D/3D, with interpolation between pixel values.
+
+ Euler integration of dynamics dP for niter steps.
+
+ Args:
+        dP (numpy.ndarray): Array of shape (2, Ly, Lx) or (3, Lz, Ly, Lx) representing the flow field.
+        inds (tuple of numpy.ndarray): Coordinates (one array per axis) of the pixels to evolve.
+        niter (int): Number of iterations to perform.
+        device (torch.device, optional): Device to use for computation. Defaults to torch.device("cpu").
+
+    Returns:
+        torch.Tensor or numpy.ndarray: Array of shape (2 or 3, n_points) with the final pixel locations.
+
+ """
+
+ shape = dP.shape[1:]
+ ndim = len(shape)
+ if (device.type == "cuda" or device.type == "mps") or ndim==3:
+ pt = torch.zeros((*[1]*ndim, len(inds[0]), ndim), dtype=torch.float32, device=device)
+ im = torch.zeros((1, ndim, *shape), dtype=torch.float32, device=device)
+ # Y and X dimensions, flipped X-1, Y-1
+ # pt is [1 1 1 3 n_points]
+ for n in range(ndim):
+ if ndim==3:
+ pt[0, 0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
+ else:
+ pt[0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
+ im[0, ndim - n - 1] = torch.from_numpy(dP[n]).to(device, dtype=torch.float32)
+ shape = np.array(shape)[::-1].astype("float") - 1
+
+ # normalize pt between 0 and 1, normalize the flow
+ for k in range(ndim):
+ im[:, k] *= 2. / shape[k]
+ pt[..., k] /= shape[k]
+
+ # normalize to between -1 and 1
+ pt *= 2
+ pt -= 1
+
+ # dynamics
+ for t in range(niter):
+ dPt = torch.nn.functional.grid_sample(im, pt, align_corners=False)
+ for k in range(ndim): #clamp the final pixel locations
+ pt[..., k] = torch.clamp(pt[..., k] + dPt[:, k], -1., 1.)
+
+ #undo the normalization from before, reverse order of operations
+ pt += 1
+ pt *= 0.5
+ for k in range(ndim):
+ pt[..., k] *= shape[k]
+
+ if ndim==3:
+ pt = pt[..., [2, 1, 0]].squeeze()
+ pt = pt.unsqueeze(0) if pt.ndim==1 else pt
+ return pt.T
+ else:
+ pt = pt[..., [1, 0]].squeeze()
+ pt = pt.unsqueeze(0) if pt.ndim==1 else pt
+ return pt.T
+
+ else:
+ p = np.zeros((ndim, len(inds[0])), "float32")
+ for n in range(ndim):
+ p[n] = inds[n]
+ dPt = np.zeros(p.shape, "float32")
+ for t in range(niter):
+ map_coordinates(dP, p[0], p[1], dPt)
+ for k in range(len(p)):
+ p[k] = np.minimum(shape[k] - 1, np.maximum(0, p[k] + dPt[k]))
+ return p
+
+@njit("(float32[:,:],float32[:,:,:,:], int32)", nogil=True)
+def steps3D(p, dP, niter):
+ """ Run dynamics of pixels to recover masks in 3D.
+
+ Euler integration of dynamics dP for niter steps.
+
+ Args:
+ p (np.ndarray): Pixels with cellprob > cellprob_threshold [3 x npts].
+ dP (np.ndarray): Flows [3 x Lz x Ly x Lx].
+ niter (int): Number of iterations of dynamics to run.
+
+ Returns:
+ np.ndarray: Final locations of each pixel after dynamics.
+ """
+ shape = dP.shape[1:]
+ for t in range(niter):
+ for j in range(p.shape[1]):
+ p0, p1, p2 = int(p[0, j]), int(p[1, j]), int(p[2, j])
+ step = dP[:, p0, p1, p2]
+ for k in range(3):
+ p[k, j] = min(shape[k] - 1, max(0, p[k, j] + step[k]))
+ return p
+
+@njit("(float32[:,:], float32[:,:,:], int32)", nogil=True)
+def steps2D(p, dP, niter):
+ """Run dynamics of pixels to recover masks in 2D.
+
+ Euler integration of dynamics dP for niter steps.
+
+ Args:
+ p (np.ndarray): Pixels with cellprob > cellprob_threshold [2 x npts].
+ dP (np.ndarray): Flows [2 x Ly x Lx].
+ niter (int): Number of iterations of dynamics to run.
+
+ Returns:
+ np.ndarray: Final locations of each pixel after dynamics.
+ """
+ shape = dP.shape[1:]
+ for t in range(niter):
+ for j in range(p.shape[1]):
+ # starting coordinates
+ p0, p1 = int(p[0, j]), int(p[1, j])
+ step = dP[:, p0, p1]
+ for k in range(p.shape[0]):
+ p[k, j] = min(shape[k] - 1, max(0, p[k, j] + step[k]))
+ return p
+
+def follow_flows(dP, inds, niter=200, interp=True, device=torch.device("cpu")):
+ """ Run dynamics to recover masks in 2D or 3D.
+
+ Pixels are represented as a meshgrid. Only pixels with non-zero cell-probability
+ are used (as defined by inds).
+
+ Args:
+ dP (np.ndarray): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
+        inds (tuple of np.ndarray): Coordinates (one array per axis) of the pixels with non-zero cell probability.
+        niter (int, optional): Number of iterations of dynamics to run. Default is 200.
+        interp (bool, optional): Interpolate during dynamics. Default is True.
+        device (torch.device, optional): Device to use for computation. Default is torch.device("cpu").
+
+    Returns:
+        p (np.ndarray or torch.Tensor): Final locations of each pixel after dynamics, size [axis x n_points].
+ """
+ shape = np.array(dP.shape[1:]).astype(np.int32)
+ ndim = len(inds)
+ niter = np.uint32(niter)
+
+ if interp:
+ p = steps_interp(dP, inds, niter, device=device)
+ else:
+ p = np.zeros((ndim, len(inds[0])), "float32")
+ for n in range(ndim):
+ p[n] = inds[n]
+ steps_fcn = steps2D if ndim == 2 else steps3D
+ p = steps_fcn(p, dP, niter)
+
+ return p
+
+
+def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu")):
+ """Remove masks which have inconsistent flows.
+
+ Uses metrics.flow_error to compute flows from predicted masks
+ and compare flows to predicted flows from the network. Discards
+ masks with flow errors greater than the threshold.
+
+ Args:
+ masks (int, 2D or 3D array): Labelled masks, 0=NO masks; 1,2,...=mask labels,
+ size [Ly x Lx] or [Lz x Ly x Lx].
+ flows (float, 3D or 4D array): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
+ threshold (float, optional): Masks with flow error greater than threshold are discarded.
+ Default is 0.4.
+
+ Returns:
+ masks (int, 2D or 3D array): Masks with inconsistent flow masks removed,
+ 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx].
+ """
+ device0 = device
+ if masks.size > 10000 * 10000 and (device is not None and device.type == "cuda"):
+
+ major_version, minor_version = torch.__version__.split(".")[:2]
+ torch.cuda.empty_cache()
+ if major_version == "1" and int(minor_version) < 10:
+ # for PyTorch version lower than 1.10
+ def mem_info():
+ total_mem = torch.cuda.get_device_properties(device0.index).total_memory
+ used_mem = torch.cuda.memory_allocated(device0.index)
+ free_mem = total_mem - used_mem
+ return total_mem, free_mem
+ else:
+ # for PyTorch version 1.10 and above
+ def mem_info():
+ free_mem, total_mem = torch.cuda.mem_get_info(device0.index)
+ return total_mem, free_mem
+ total_mem, free_mem = mem_info()
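+        # heuristic: assume roughly 32 bytes of GPU memory per mask pixel for the QC flow
+        # recomputation, and fall back to the CPU when that estimate exceeds free memory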
+ if masks.size * 32 > free_mem:
+ dynamics_logger.warning(
+ "WARNING: image is very large, not using gpu to compute flows from masks for QC step flow_threshold"
+ )
+ dynamics_logger.info("turn off QC step with flow_threshold=0 if too slow")
+ device0 = torch.device("cpu")
+
+ merrors, _ = metrics.flow_error(masks, flows, device0)
+ badi = 1 + (merrors > threshold).nonzero()[0]
+ masks[np.isin(masks, badi)] = 0
+ return masks
+
+
+def max_pool3d(h, kernel_size=5):
+ """ memory efficient max_pool thanks to Mark Kittisopikul
+
+ for stride=1, padding=kernel_size//2, requires odd kernel_size >= 3
+
+ """
+ _, nd, ny, nx = h.shape
+ m = h.clone().detach()
+ kruns, k0 = kernel_size // 2, 1
+ for k in range(kruns):
+ for d in range(-k0, k0+1):
+ for y in range(-k0, k0+1):
+ for x in range(-k0, k0+1):
+ mv = m[:, max(-d,0):min(nd-d,nd), max(-y,0):min(ny-y,ny), max(-x,0):min(nx-x,nx)]
+ hv = h[:, max(d,0):min(nd+d,nd), max(y,0):min(ny+y,ny), max(x,0):min(nx+x,nx)]
+ torch.maximum(mv, hv, out=mv)
+ return m
+
+def max_pool2d(h, kernel_size=5):
+ """ memory efficient max_pool thanks to Mark Kittisopikul """
+ _, ny, nx = h.shape
+ m = h.clone().detach()
+ k0 = kernel_size // 2
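+    # pairwise torch.maximum over all (2*k0+1)^2 shifted views of h is equivalent to a
+    # stride-1 max filter with window kernel_size, without allocating unfold buffers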
+ for y in range(-k0, k0+1):
+ for x in range(-k0, k0+1):
+ mv = m[:, max(-y,0):min(ny-y,ny), max(-x,0):min(nx-x,nx)]
+ hv = h[:, max(y,0):min(ny+y,ny), max(x,0):min(nx+x,nx)]
+ torch.maximum(mv, hv, out=mv)
+ return m
+
+def max_pool1d(h, kernel_size=5, axis=1, out=None):
+ """ memory efficient max_pool thanks to Mark Kittisopikul
+
+ for stride=1, padding=kernel_size//2, requires odd kernel_size >= 3
+
+ """
+ if out is None:
+ out = h.clone()
+ else:
+ out.copy_(h)
+
+ nd = h.shape[axis]
+ k0 = kernel_size // 2
+ for d in range(-k0, k0+1):
+ if axis==1:
+ mv = out[:, max(-d,0):min(nd-d,nd)]
+ hv = h[:, max(d,0):min(nd+d,nd)]
+ elif axis==2:
+ mv = out[:, :, max(-d,0):min(nd-d,nd)]
+ hv = h[:, :, max(d,0):min(nd+d,nd)]
+ elif axis==3:
+ mv = out[:, :, :, max(-d,0):min(nd-d,nd)]
+ hv = h[:, :, :, max(d,0):min(nd+d,nd)]
+ torch.maximum(mv, hv, out=mv)
+ return out
+
+def max_pool_nd(h, kernel_size=5):
+ """ memory efficient max_pool in 2d or 3d """
+ ndim = h.ndim - 1
+ hmax = max_pool1d(h, kernel_size=kernel_size, axis=1)
+ hmax2 = max_pool1d(hmax, kernel_size=kernel_size, axis=2)
+ if ndim==2:
+ del hmax
+ return hmax2
+ else:
+ hmax = max_pool1d(hmax2, kernel_size=kernel_size, axis=3, out=hmax)
+ del hmax2
+ return hmax
+
+# from torch.nn.functional import max_pool2d
+def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
+ """Create masks using pixel convergence after running dynamics.
+
+ Makes a histogram of final pixel locations p, initializes masks
+ at peaks of histogram and extends the masks from the peaks so that
+    they include all pixels with more than 2 final pixels p. (Flow-error filtering
+    is applied separately in remove_bad_flow_masks.)
+
+ Parameters:
+        pt (torch.Tensor): Final locations of each pixel after dynamics, size [axis x n_points].
+        inds (tuple of np.ndarray): Original coordinates of the evolved pixels, used to place
+            the labels back into the full image.
+        shape0 (tuple): Shape of the original image, [Ly x Lx] or [Lz x Ly x Lx].
+ rpad (int, optional): Histogram edge padding. Default is 20.
+ max_size_fraction (float, optional): Masks larger than max_size_fraction of
+ total image size are removed. Default is 0.4.
+
+ Returns:
+ M0 (int, 2D or 3D array): Masks with inconsistent flow masks removed,
+ 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx].
+ """
+
+ ndim = len(shape0)
+ device = pt.device
+
+ rpad = 20
+ pt += rpad
+ pt = torch.clamp(pt, min=0)
+ for i in range(len(pt)):
+ pt[i] = torch.clamp(pt[i], max=shape0[i]+rpad-1)
+
+ # # add extra padding to make divisible by 5
+ # shape = tuple((np.ceil((shape0 + 2*rpad)/5) * 5).astype(int))
+ shape = tuple(np.array(shape0) + 2*rpad)
+
+ # sparse coo torch
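+    # duplicate coordinates in pt are summed when the sparse tensor is densified, giving a
+    # histogram of how many pixels converged to each location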
+ coo = torch.sparse_coo_tensor(pt, torch.ones(pt.shape[1], device=pt.device, dtype=torch.int),
+ shape)
+ h1 = coo.to_dense()
+ del coo
+
+ hmax1 = max_pool_nd(h1.unsqueeze(0), kernel_size=5)
+ hmax1 = hmax1.squeeze()
+ seeds1 = torch.nonzero((h1 - hmax1 > -1e-6) * (h1 > 10))
+ del hmax1
+ if len(seeds1) == 0:
+ dynamics_logger.warning("no seeds found in get_masks_torch - no masks found.")
+ return np.zeros(shape0, dtype="uint16")
+
+ npts = h1[tuple(seeds1.T)]
+ isort1 = npts.argsort()
+ seeds1 = seeds1[isort1]
+
+ n_seeds = len(seeds1)
+ h_slc = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
+ for k in range(n_seeds):
+ slc = tuple([slice(seeds1[k][j]-5, seeds1[k][j]+6) for j in range(ndim)])
+ h_slc[k] = h1[slc]
+ del h1
+ seed_masks = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
+ if ndim==2:
+ seed_masks[:,5,5] = 1
+ else:
+ seed_masks[:,5,5,5] = 1
+
+ for iter in range(5):
+ # extend
+ seed_masks = max_pool_nd(seed_masks, kernel_size=3)
+ seed_masks *= h_slc > 2
+ del h_slc
+ seeds_new = [tuple((torch.nonzero(seed_masks[k]) + seeds1[k] - 5).T)
+ for k in range(n_seeds)]
+ del seed_masks
+
+ dtype = torch.int32 if n_seeds < 2**16 else torch.int64
+ M1 = torch.zeros(shape, dtype=dtype, device=device)
+ for k in range(n_seeds):
+ M1[seeds_new[k]] = 1 + k
+
+ M1 = M1[tuple(pt.long())]
+ M1 = M1.cpu().numpy()
+
+ dtype = "uint16" if n_seeds < 2**16 else "uint32"
+ M0 = np.zeros(shape0, dtype=dtype)
+ M0[inds] = M1
+
+ # remove big masks
+ uniq, counts = fastremap.unique(M0, return_counts=True)
+ big = np.prod(shape0) * max_size_fraction
+ bigc = uniq[counts > big]
+ if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
+ M0 = fastremap.mask(M0, bigc)
+ fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
+ M0 = M0.reshape(tuple(shape0))
+
+ #print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9:.3f} gb")
+ return M0
+
+
+def resize_and_compute_masks(dP, cellprob, niter=200, cellprob_threshold=0.0,
+ flow_threshold=0.4, interp=True, do_3D=False, min_size=15,
+ max_size_fraction=0.4, resize=None, device=torch.device("cpu")):
+ """Compute masks using dynamics from dP and cellprob, and resizes masks if resize is not None.
+
+ Args:
+ dP (numpy.ndarray): The dynamics flow field array.
+ cellprob (numpy.ndarray): The cell probability array.
+ niter (int, optional): The number of iterations for mask computation. Defaults to 200.
+ cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0.
+ flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4.
+ interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True.
+ do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
+ min_size (int, optional): The minimum size of the masks. Defaults to 15.
+ max_size_fraction (float, optional): Masks larger than max_size_fraction of
+ total image size are removed. Default is 0.4.
+ resize (tuple, optional): The desired size for resizing the masks. Defaults to None.
+ device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu").
+
+ Returns:
+        np.ndarray: The computed masks, resized to the original image size if resize is not None.
+ """
+ mask = compute_masks(dP, cellprob, niter=niter,
+ cellprob_threshold=cellprob_threshold,
+ flow_threshold=flow_threshold, interp=interp, do_3D=do_3D,
+ max_size_fraction=max_size_fraction,
+ device=device)
+
+ if resize is not None:
+ if len(resize) == 2:
+ mask = transforms.resize_image(mask, resize[0], resize[1], no_channels=True,
+ interpolation=cv2.INTER_NEAREST)
+ else:
+ Lz, Ly, Lx = resize
+ if mask.shape[0] != Lz or mask.shape[1] != Ly:
+ dynamics_logger.info("resizing 3D masks to original image size")
+ if mask.shape[1] != Ly:
+ mask = transforms.resize_image(mask, Ly=Ly, Lx=Lx,
+ no_channels=True,
+ interpolation=cv2.INTER_NEAREST)
+ if mask.shape[0] != Lz:
+ mask = transforms.resize_image(mask.transpose(1,0,2),
+ Ly=Lz, Lx=Lx,
+ no_channels=True,
+ interpolation=cv2.INTER_NEAREST).transpose(1,0,2)
+
+ mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)
+
+ return mask
+
+def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
+ flow_threshold=0.4, interp=True, do_3D=False, min_size=-1,
+ max_size_fraction=0.4, device=torch.device("cpu")):
+ """Compute masks using dynamics from dP and cellprob.
+
+ Args:
+ dP (numpy.ndarray): The dynamics flow field array.
+ cellprob (numpy.ndarray): The cell probability array.
+ p (numpy.ndarray, optional): The pixels on which to run dynamics. Defaults to None
+ niter (int, optional): The number of iterations for mask computation. Defaults to 200.
+ cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0.
+ flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4.
+ interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True.
+ do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
+        min_size (int, optional): The minimum size of the masks. Defaults to -1 (no size filtering).
+ max_size_fraction (float, optional): Masks larger than max_size_fraction of
+ total image size are removed. Default is 0.4.
+ device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu").
+
+ Returns:
+        np.ndarray: The computed masks (labels), size [Ly x Lx] or [Lz x Ly x Lx].
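+
+    Example (illustrative only; dP and cellprob normally come from the network output):
+        dP = np.zeros((2, 128, 128), "float32")
+        cellprob = -np.ones((128, 128), "float32")
+        masks = compute_masks(dP, cellprob)
+        # no pixels exceed cellprob_threshold, so an all-zero uint16 mask is returned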
+ """
+
+ if (cellprob > cellprob_threshold).sum(): #mask at this point is a cell cluster binary map, not labels
+ inds = np.nonzero(cellprob > cellprob_threshold)
+ if len(inds[0]) == 0:
+ dynamics_logger.info("No cell pixels found.")
+ shape = cellprob.shape
+ mask = np.zeros(shape, "uint16")
+ return mask
+
+ p_final = follow_flows(dP * (cellprob > cellprob_threshold) / 5.,
+ inds=inds, niter=niter, interp=interp,
+ device=device)
+ if not torch.is_tensor(p_final):
+ p_final = torch.from_numpy(p_final).to(device, dtype=torch.int)
+ else:
+ p_final = p_final.int()
+ # calculate masks
+ if device.type == "mps":
+ p_final = p_final.to(torch.device("cpu"))
+ mask = get_masks_torch(p_final, inds, dP.shape[1:],
+ max_size_fraction=max_size_fraction)
+ del p_final
+ # flow thresholding factored out of get_masks
+ if not do_3D:
+ if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0:
+ # make sure labels are unique at output of get_masks
+ mask = remove_bad_flow_masks(mask, dP, threshold=flow_threshold,
+ device=device)
+
+ if mask.max() < 2**16 and mask.dtype != "uint16":
+ mask = mask.astype("uint16")
+
+ else: # nothing to compute, just make it compatible
+ dynamics_logger.info("No cell pixels found.")
+ shape = cellprob.shape
+ mask = np.zeros(cellprob.shape, "uint16")
+ return mask
+
+ if min_size > 0:
+ mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)
+
+ if mask.dtype == np.uint32:
+ dynamics_logger.warning(
+ "more than 65535 masks in image, masks returned as np.uint32")
+
+ return mask
+
+def get_masks_orig(p, iscell=None, rpad=20, max_size_fraction=0.4):
+ """Create masks using pixel convergence after running dynamics.
+
+ Original implementation on CPU with histogramdd
+ (histogramdd uses excessive memory with large images)
+
+ Makes a histogram of final pixel locations p, initializes masks
+ at peaks of histogram and extends the masks from the peaks so that
+ they include all pixels with more than 2 final pixels p. Discards
+ masks with flow errors greater than the threshold.
+
+ Parameters:
+ p (float32, 3D or 4D array): Final locations of each pixel after dynamics,
+ size [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
+ iscell (bool, 2D or 3D array): If iscell is not None, set pixels that are
+ iscell False to stay in their original location.
+ rpad (int, optional): Histogram edge padding. Default is 20.
+ max_size_fraction (float, optional): Masks larger than max_size_fraction of
+ total image size are removed. Default is 0.4.
+
+ Returns:
+ M0 (int, 2D or 3D array): Masks with inconsistent flow masks removed,
+ 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx].
+ """
+ pflows = []
+ edges = []
+ shape0 = p.shape[1:]
+ dims = len(p)
+ if iscell is not None:
+ if dims == 3:
+ inds = np.meshgrid(np.arange(shape0[0]), np.arange(shape0[1]),
+ np.arange(shape0[2]), indexing="ij")
+ elif dims == 2:
+ inds = np.meshgrid(np.arange(shape0[0]), np.arange(shape0[1]),
+ indexing="ij")
+ for i in range(dims):
+ p[i, ~iscell] = inds[i][~iscell]
+
+ for i in range(dims):
+ pflows.append(p[i].flatten().astype("int32"))
+ edges.append(np.arange(-.5 - rpad, shape0[i] + .5 + rpad, 1))
+
+ h, _ = np.histogramdd(tuple(pflows), bins=edges)
+ hmax = h.copy()
+ for i in range(dims):
+ hmax = maximum_filter1d(hmax, 5, axis=i)
+
+ seeds = np.nonzero(np.logical_and(h - hmax > -1e-6, h > 10))
+ Nmax = h[seeds]
+ isort = np.argsort(Nmax)[::-1]
+ for s in seeds:
+ s[:] = s[isort]
+
+ pix = list(np.array(seeds).T)
+
+ shape = h.shape
+ if dims == 3:
+ expand = np.nonzero(np.ones((3, 3, 3)))
+ else:
+ expand = np.nonzero(np.ones((3, 3)))
+
+ for iter in range(5):
+ for k in range(len(pix)):
+ if iter == 0:
+ pix[k] = list(pix[k])
+ newpix = []
+ iin = []
+ for i, e in enumerate(expand):
+ epix = e[:, np.newaxis] + np.expand_dims(pix[k][i], 0) - 1
+ epix = epix.flatten()
+ iin.append(np.logical_and(epix >= 0, epix < shape[i]))
+ newpix.append(epix)
+ iin = np.all(tuple(iin), axis=0)
+            newpix = tuple(q[iin] for q in newpix)  # keep only in-bounds neighbor pixels
+ igood = h[newpix] > 2
+ for i in range(dims):
+ pix[k][i] = newpix[i][igood]
+ if iter == 4:
+ pix[k] = tuple(pix[k])
+
+ M = np.zeros(h.shape, np.uint32)
+ for k in range(len(pix)):
+ M[pix[k]] = 1 + k
+
+ for i in range(dims):
+ pflows[i] = pflows[i] + rpad
+ M0 = M[tuple(pflows)]
+
+ # remove big masks
+ uniq, counts = fastremap.unique(M0, return_counts=True)
+ big = np.prod(shape0) * max_size_fraction
+ bigc = uniq[counts > big]
+ if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
+ M0 = fastremap.mask(M0, bigc)
+ fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
+ M0 = np.reshape(M0, shape0)
+ return M0
diff --git a/cellpose/export.py b/cellpose/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b240620206054310577a3ffa4a631070aa1fe56
--- /dev/null
+++ b/cellpose/export.py
@@ -0,0 +1,410 @@
+"""Auxiliary module for bioimageio format export
+
+Example usage:
+
+```bash
+#!/bin/bash
+
+# Define default paths and parameters
+DEFAULT_CHANNELS="1 0"
+DEFAULT_PATH_PRETRAINED_MODEL="/home/qinyu/models/cp/cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995"
+DEFAULT_PATH_README="/home/qinyu/models/cp/README.md"
+DEFAULT_LIST_PATH_COVER_IMAGES="/home/qinyu/images/cp/cellpose_raw_and_segmentation.jpg /home/qinyu/images/cp/cellpose_raw_and_probability.jpg /home/qinyu/images/cp/cellpose_raw.jpg"
+DEFAULT_MODEL_ID="philosophical-panda"
+DEFAULT_MODEL_ICON="🐼"
+DEFAULT_MODEL_VERSION="0.1.0"
+DEFAULT_MODEL_NAME="My Cool Cellpose"
+DEFAULT_MODEL_DOCUMENTATION="A cool Cellpose model trained for my cool dataset."
+DEFAULT_MODEL_AUTHORS='[{"name": "Qin Yu", "affiliation": "EMBL", "github_user": "qin-yu", "orcid": "0000-0002-4652-0795"}]'
+DEFAULT_MODEL_CITE='[{"text": "For more details of the model itself, see the manuscript", "doi": "10.1242/dev.202800", "url": null}]'
+DEFAULT_MODEL_TAGS="cellpose 3d 2d"
+DEFAULT_MODEL_LICENSE="MIT"
+DEFAULT_MODEL_REPO="https://github.com/kreshuklab/go-nuclear"
+
+# Run the Python script with default parameters
+python export.py \
+ --channels $DEFAULT_CHANNELS \
+ --path_pretrained_model "$DEFAULT_PATH_PRETRAINED_MODEL" \
+ --path_readme "$DEFAULT_PATH_README" \
+ --list_path_cover_images $DEFAULT_LIST_PATH_COVER_IMAGES \
+ --model_version "$DEFAULT_MODEL_VERSION" \
+ --model_name "$DEFAULT_MODEL_NAME" \
+ --model_documentation "$DEFAULT_MODEL_DOCUMENTATION" \
+ --model_authors "$DEFAULT_MODEL_AUTHORS" \
+ --model_cite "$DEFAULT_MODEL_CITE" \
+ --model_tags $DEFAULT_MODEL_TAGS \
+ --model_license "$DEFAULT_MODEL_LICENSE" \
+ --model_repo "$DEFAULT_MODEL_REPO"
+```
+"""
+
+import os
+import sys
+import json
+import argparse
+from pathlib import Path
+from urllib.parse import urlparse
+
+import torch
+import numpy as np
+
+from cellpose.io import imread
+from cellpose.utils import download_url_to_file
+from cellpose.transforms import pad_image_ND, normalize_img, convert_image
+from cellpose.resnet_torch import CPnetBioImageIO
+
+from bioimageio.spec.model.v0_5 import (
+ ArchitectureFromFileDescr,
+ Author,
+ AxisId,
+ ChannelAxis,
+ CiteEntry,
+ Doi,
+ FileDescr,
+ Identifier,
+ InputTensorDescr,
+ IntervalOrRatioDataDescr,
+ LicenseId,
+ ModelDescr,
+ ModelId,
+ OrcidId,
+ OutputTensorDescr,
+ ParameterizedSize,
+ PytorchStateDictWeightsDescr,
+ SizeReference,
+ SpaceInputAxis,
+ SpaceOutputAxis,
+ TensorId,
+ TorchscriptWeightsDescr,
+ Version,
+ WeightsDescr,
+)
+# Define ARBITRARY_SIZE if it is not available in the module
+try:
+ from bioimageio.spec.model.v0_5 import ARBITRARY_SIZE
+except ImportError:
+ ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)
+
+from bioimageio.spec.common import HttpUrl
+from bioimageio.spec import save_bioimageio_package
+from bioimageio.core import test_model
+
+DEFAULT_CHANNELS = [2, 1]
+DEFAULT_NORMALIZE_PARAMS = {
+ "axis": -1,
+ "lowhigh": None,
+ "percentile": None,
+ "normalize": True,
+ "norm3D": False,
+ "sharpen_radius": 0,
+ "smooth_radius": 0,
+ "tile_norm_blocksize": 0,
+ "tile_norm_smooth3D": 1,
+ "invert": False,
+}
+IMAGE_URL = "http://www.cellpose.org/static/data/rgb_3D.tif"
+
+
+def download_and_normalize_image(path_dir_temp, channels=DEFAULT_CHANNELS):
+ """
+ Download and normalize image.
+ """
+ filename = os.path.basename(urlparse(IMAGE_URL).path)
+ path_image = path_dir_temp / filename
+ if not path_image.exists():
+ sys.stderr.write(f'Downloading: "{IMAGE_URL}" to {path_image}\n')
+ download_url_to_file(IMAGE_URL, path_image)
+ img = imread(path_image).astype(np.float32)
+ img = convert_image(img, channels, channel_axis=1, z_axis=0, do_3D=False, nchan=2)
+ img = normalize_img(img, **DEFAULT_NORMALIZE_PARAMS)
+ img = np.transpose(img, (0, 3, 1, 2))
+ img, _, _ = pad_image_ND(img)
+ return img
+
+
+def load_bioimageio_cpnet_model(path_model_weight, nchan=2):
+ cpnet_kwargs = {
+ "nbase": [nchan, 32, 64, 128, 256],
+ "nout": 3,
+ "sz": 3,
+ "mkldnn": False,
+ "conv_3D": False,
+ "max_pool": True,
+ }
+ cpnet_biio = CPnetBioImageIO(**cpnet_kwargs)
+ state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True)
+ cpnet_biio.load_state_dict(state_dict_cuda)
+ cpnet_biio.eval() # crucial for the prediction results
+ return cpnet_biio, cpnet_kwargs
+
+
+def descr_gen_input(path_test_input, nchan=2):
+ input_axes = [
+ SpaceInputAxis(id=AxisId("z"), size=ARBITRARY_SIZE),
+ ChannelAxis(channel_names=[Identifier(f"c{i+1}") for i in range(nchan)]),
+ SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=16, step=16)),
+ SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=16, step=16)),
+ ]
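+    # y and x sizes are constrained to multiples of 16 so inputs fit the network's downsampling
+    # levels; z is unconstrained and the channel count is fixed by nchan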
+ data_descr = IntervalOrRatioDataDescr(type="float32")
+ path_test_input = Path(path_test_input)
+ descr_input = InputTensorDescr(
+ id=TensorId("raw"),
+ axes=input_axes,
+ test_tensor=FileDescr(source=path_test_input),
+ data=data_descr,
+ )
+ return descr_input
+
+
+def descr_gen_output_flow(path_test_output):
+ output_axes_output_tensor = [
+ SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
+ ChannelAxis(channel_names=[Identifier("flow1"), Identifier("flow2"), Identifier("flow3")]),
+ SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))),
+ SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"))),
+ ]
+ path_test_output = Path(path_test_output)
+ descr_output = OutputTensorDescr(
+ id=TensorId("flow"),
+ axes=output_axes_output_tensor,
+ test_tensor=FileDescr(source=path_test_output),
+ )
+ return descr_output
+
+
+def descr_gen_output_downsampled(path_dir_temp, nbase=None):
+ if nbase is None:
+ nbase = [32, 64, 128, 256]
+
+ output_axes_downsampled_tensors = [
+ [
+ SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
+ ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(base)]),
+ SpaceOutputAxis(
+ id=AxisId("y"),
+ size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
+ scale=2**offset,
+ ),
+ SpaceOutputAxis(
+ id=AxisId("x"),
+ size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")),
+ scale=2**offset,
+ ),
+ ]
+ for offset, base in enumerate(nbase)
+ ]
+ path_downsampled_tensors = [
+ Path(path_dir_temp / f"test_downsampled_{i}.npy") for i in range(len(output_axes_downsampled_tensors))
+ ]
+ descr_output_downsampled_tensors = [
+ OutputTensorDescr(
+ id=TensorId(f"downsampled_{i}"),
+ axes=axes,
+ test_tensor=FileDescr(source=path),
+ )
+ for i, (axes, path) in enumerate(zip(output_axes_downsampled_tensors, path_downsampled_tensors))
+ ]
+ return descr_output_downsampled_tensors
+
+
+def descr_gen_output_style(path_test_style, nchannel=256):
+ output_axes_style_tensor = [
+ SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
+ ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(nchannel)]),
+ ]
+ path_style_tensor = Path(path_test_style)
+ descr_output_style_tensor = OutputTensorDescr(
+ id=TensorId("style"),
+ axes=output_axes_style_tensor,
+ test_tensor=FileDescr(source=path_style_tensor),
+ )
+ return descr_output_style_tensor
+
+
+def descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper=None):
+ if path_cpnet_wrapper is None:
+ path_cpnet_wrapper = Path(__file__).parent / "resnet_torch.py"
+ pytorch_architecture = ArchitectureFromFileDescr(
+ callable=Identifier("CPnetBioImageIO"),
+ source=Path(path_cpnet_wrapper),
+ kwargs=cpnet_kwargs,
+ )
+ return pytorch_architecture
+
+
+def descr_gen_documentation(path_doc, markdown_text):
+ with open(path_doc, "w") as f:
+ f.write(markdown_text)
+
+
+def package_to_bioimageio(
+ path_pretrained_model,
+ path_save_trace,
+ path_readme,
+ list_path_cover_images,
+ descr_input,
+ descr_output,
+ descr_output_downsampled_tensors,
+ descr_output_style_tensor,
+ pytorch_version,
+ pytorch_architecture,
+ model_id,
+ model_icon,
+ model_version,
+ model_name,
+ model_documentation,
+ model_authors,
+ model_cite,
+ model_tags,
+ model_license,
+ model_repo,
+):
+ """Package model description to BioImage.IO format."""
+ my_model_descr = ModelDescr(
+ id=ModelId(model_id) if model_id is not None else None,
+ id_emoji=model_icon,
+ version=Version(model_version),
+ name=model_name,
+ description=model_documentation,
+ authors=[
+ Author(
+ name=author["name"],
+ affiliation=author["affiliation"],
+ github_user=author["github_user"],
+ orcid=OrcidId(author["orcid"]),
+ )
+ for author in model_authors
+ ],
+ cite=[CiteEntry(text=cite["text"], doi=Doi(cite["doi"]), url=cite["url"]) for cite in model_cite],
+ covers=[Path(img) for img in list_path_cover_images],
+ license=LicenseId(model_license),
+ tags=model_tags,
+ documentation=Path(path_readme),
+ git_repo=HttpUrl(model_repo),
+ inputs=[descr_input],
+ outputs=[descr_output, descr_output_style_tensor] + descr_output_downsampled_tensors,
+ weights=WeightsDescr(
+ pytorch_state_dict=PytorchStateDictWeightsDescr(
+ source=Path(path_pretrained_model),
+ architecture=pytorch_architecture,
+ pytorch_version=pytorch_version,
+ ),
+ torchscript=TorchscriptWeightsDescr(
+ source=Path(path_save_trace),
+ pytorch_version=pytorch_version,
+ parent="pytorch_state_dict", # these weights were converted from the pytorch_state_dict weights.
+ ),
+ ),
+ )
+
+ return my_model_descr
+
+
+def parse_args():
+ # fmt: off
+ parser = argparse.ArgumentParser(description="BioImage.IO model packaging for Cellpose")
+ parser.add_argument("--channels", nargs=2, default=[2, 1], type=int, help="Cyto-only = [2, 0], Cyto + Nuclei = [2, 1], Nuclei-only = [1, 0]")
+ parser.add_argument("--path_pretrained_model", required=True, type=str, help="Path to pretrained model file, e.g., cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995")
+ parser.add_argument("--path_readme", required=True, type=str, help="Path to README file")
+ parser.add_argument("--list_path_cover_images", nargs='+', required=True, type=str, help="List of paths to cover images")
+ parser.add_argument("--model_id", type=str, help="Model ID, provide if already exists", default=None)
+ parser.add_argument("--model_icon", type=str, help="Model icon, provide if already exists", default=None)
+ parser.add_argument("--model_version", required=True, type=str, help="Model version, new model should be 0.1.0")
+ parser.add_argument("--model_name", required=True, type=str, help="Model name, e.g., My Cool Cellpose")
+ parser.add_argument("--model_documentation", required=True, type=str, help="Model documentation, e.g., A cool Cellpose model trained for my cool dataset.")
+ parser.add_argument("--model_authors", required=True, type=str, help="Model authors in JSON format, e.g., '[{\"name\": \"Qin Yu\", \"affiliation\": \"EMBL\", \"github_user\": \"qin-yu\", \"orcid\": \"0000-0002-4652-0795\"}]'")
+ parser.add_argument("--model_cite", required=True, type=str, help="Model citation in JSON format, e.g., '[{\"text\": \"For more details of the model itself, see the manuscript\", \"doi\": \"10.1242/dev.202800\", \"url\": null}]'")
+ parser.add_argument("--model_tags", nargs='+', required=True, type=str, help="Model tags, e.g., cellpose 3d 2d")
+ parser.add_argument("--model_license", required=True, type=str, help="Model license, e.g., MIT")
+ parser.add_argument("--model_repo", required=True, type=str, help="Model repository URL")
+ return parser.parse_args()
+ # fmt: on
+
+
+def main():
+ args = parse_args()
+
+ # Parse user-provided paths and arguments
+ channels = args.channels
+ model_cite = json.loads(args.model_cite)
+ model_authors = json.loads(args.model_authors)
+
+ path_readme = Path(args.path_readme)
+ path_pretrained_model = Path(args.path_pretrained_model)
+ list_path_cover_images = [Path(path_image) for path_image in args.list_path_cover_images]
+
+ # Auto-generated paths
+ path_cpnet_wrapper = Path(__file__).resolve().parent / "resnet_torch.py"
+ path_dir_temp = Path(__file__).resolve().parent.parent / "models" / path_pretrained_model.stem
+ path_dir_temp.mkdir(parents=True, exist_ok=True)
+
+ path_save_trace = path_dir_temp / "cp_traced.pt"
+ path_test_input = path_dir_temp / "test_input.npy"
+ path_test_output = path_dir_temp / "test_output.npy"
+ path_test_style = path_dir_temp / "test_style.npy"
+ path_bioimageio_package = path_dir_temp / "cellpose_model.zip"
+
+ # Download test input image
+ img_np = download_and_normalize_image(path_dir_temp, channels=channels)
+ np.save(path_test_input, img_np)
+ img = torch.tensor(img_np).float()
+
+ # Load model
+ cpnet_biio, cpnet_kwargs = load_bioimageio_cpnet_model(path_pretrained_model)
+
+ # Test model and save output
+ tuple_output_tensor = cpnet_biio(img)
+ np.save(path_test_output, tuple_output_tensor[0].detach().numpy())
+ np.save(path_test_style, tuple_output_tensor[1].detach().numpy())
+ for i, t in enumerate(tuple_output_tensor[2:]):
+ np.save(path_dir_temp / f"test_downsampled_{i}.npy", t.detach().numpy())
+
+ # Save traced model
+ model_traced = torch.jit.trace(cpnet_biio, img)
+ model_traced.save(path_save_trace)
+
+ # Generate model description
+ descr_input = descr_gen_input(path_test_input)
+ descr_output = descr_gen_output_flow(path_test_output)
+ descr_output_downsampled_tensors = descr_gen_output_downsampled(path_dir_temp, nbase=cpnet_biio.nbase[1:])
+ descr_output_style_tensor = descr_gen_output_style(path_test_style, cpnet_biio.nbase[-1])
+ pytorch_version = Version(torch.__version__)
+ pytorch_architecture = descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper)
+
+ # Package model
+ my_model_descr = package_to_bioimageio(
+ path_pretrained_model,
+ path_save_trace,
+ path_readme,
+ list_path_cover_images,
+ descr_input,
+ descr_output,
+ descr_output_downsampled_tensors,
+ descr_output_style_tensor,
+ pytorch_version,
+ pytorch_architecture,
+ args.model_id,
+ args.model_icon,
+ args.model_version,
+ args.model_name,
+ args.model_documentation,
+ model_authors,
+ model_cite,
+ args.model_tags,
+ args.model_license,
+ args.model_repo,
+ )
+
+ # Test model
+ summary = test_model(my_model_descr, weight_format="pytorch_state_dict")
+ summary.display()
+ summary = test_model(my_model_descr, weight_format="torchscript")
+ summary.display()
+
+ # Save BioImage.IO package
+ package_path = save_bioimageio_package(my_model_descr, output_path=Path(path_bioimageio_package))
+ print("package path:", package_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py
new file mode 100644
index 0000000000000000000000000000000000000000..614b877af9402e63f045a8ee0f32a8007f0f7787
--- /dev/null
+++ b/cellpose/gui/gui.py
@@ -0,0 +1,2526 @@
+"""
+Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import sys, os, pathlib, warnings, datetime, time, copy
+
+from qtpy import QtGui, QtCore
+from superqt import QRangeSlider, QCollapsible
+from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox
+import pyqtgraph as pg
+
+import numpy as np
+from scipy.stats import mode
+import cv2
+
+from . import guiparts, menus, io
+from .. import models, core, dynamics, version, denoise, train
+from ..utils import download_url_to_file, masks_to_outlines, diameters
+from ..io import get_image_files, imsave, imread
+from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img
+from ..models import normalize_default
+from ..plot import disk
+
+try:
+ import matplotlib.pyplot as plt
+ MATPLOTLIB = True
+except:
+ MATPLOTLIB = False
+
+try:
+ from google.cloud import storage
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), "key/cellpose-data-writer.json")
+ SERVER_UPLOAD = True
+except:
+ SERVER_UPLOAD = False
+
+Horizontal = QtCore.Qt.Orientation.Horizontal
+
+
+class Slider(QRangeSlider):
+
+ def __init__(self, parent, name, color):
+ super().__init__(Horizontal)
+ self.setEnabled(False)
+ self.valueChanged.connect(lambda: self.levelChanged(parent))
+ self.name = name
+
+ self.setStyleSheet(""" QSlider{
+ background-color: transparent;
+ }
+ """)
+ self.show()
+
+ def levelChanged(self, parent):
+ parent.level_change(self.name)
+
+
+class QHLine(QFrame):
+
+ def __init__(self):
+ super(QHLine, self).__init__()
+ self.setFrameShape(QFrame.HLine)
+ #self.setFrameShadow(QFrame.Sunken)
+ self.setLineWidth(8)
+
+
+def make_bwr():
+ # make a bwr colormap
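+    # low values map to blue, the midpoint to white, and high values to red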
+ b = np.append(255 * np.ones(128), np.linspace(0, 255, 128)[::-1])[:, np.newaxis]
+ r = np.append(np.linspace(0, 255, 128), 255 * np.ones(128))[:, np.newaxis]
+ g = np.append(np.linspace(0, 255, 128),
+ np.linspace(0, 255, 128)[::-1])[:, np.newaxis]
+ color = np.concatenate((r, g, b), axis=-1).astype(np.uint8)
+ bwr = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
+ return bwr
+
+
+def make_spectral():
+ # make spectral colormap
+ r = np.array([
+ 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80,
+ 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 128, 128, 128, 128, 128,
+ 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 120, 112, 104, 96, 88,
+ 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 7, 11, 15, 19, 23,
+ 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103,
+ 107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167,
+ 171, 175, 179, 183, 187, 191, 195, 199, 203, 207, 211, 215, 219, 223, 227, 231,
+ 235, 239, 243, 247, 251, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255
+ ])
+ g = np.array([
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 5, 4, 4, 3, 3,
+ 2, 2, 1, 1, 0, 0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111,
+ 119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239,
+ 247, 255, 247, 239, 231, 223, 215, 207, 199, 191, 183, 175, 167, 159, 151, 143,
+ 135, 128, 129, 131, 132, 134, 135, 137, 139, 140, 142, 143, 145, 147, 148, 150,
+ 151, 153, 154, 156, 158, 159, 161, 162, 164, 166, 167, 169, 170, 172, 174, 175,
+ 177, 178, 180, 181, 183, 185, 186, 188, 189, 191, 193, 194, 196, 197, 199, 201,
+ 202, 204, 205, 207, 208, 210, 212, 213, 215, 216, 218, 220, 221, 223, 224, 226,
+ 228, 229, 231, 232, 234, 235, 237, 239, 240, 242, 243, 245, 247, 248, 250, 251,
+ 253, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199,
+ 195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135,
+ 131, 127, 123, 119, 115, 111, 107, 103, 99, 95, 91, 87, 83, 79, 75, 71, 67, 63,
+ 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 0, 8, 16, 24, 32, 41,
+ 49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180,
+ 189, 197, 205, 213, 222, 230, 238, 246, 254
+ ])
+ b = np.array([
+ 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143,
+ 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 251, 247,
+ 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183,
+ 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 128, 126, 124,
+ 122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98, 96, 94, 92, 90,
+ 88, 86, 84, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50,
+ 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10,
+ 8, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74,
+ 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205,
+ 213, 222, 230, 238, 246, 254
+ ])
+ color = (np.vstack((r, g, b)).T).astype(np.uint8)
+ spectral = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
+ return spectral
+
+
+def make_cmap(cm=0):
+ # make a single channel colormap
+ r = np.arange(0, 256)
+ color = np.zeros((256, 3))
+ color[:, cm] = r
+ color = color.astype(np.uint8)
+ cmap = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
+ return cmap
+
+
+def run(image=None):
+ from ..io import logger_setup
+ logger, log_file = logger_setup()
+ # Always start by initializing Qt (only once per application)
+ warnings.filterwarnings("ignore")
+ app = QApplication(sys.argv)
+ icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
+ guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png")
+ if not icon_path.is_file():
+ cp_dir = pathlib.Path.home().joinpath(".cellpose")
+ cp_dir.mkdir(exist_ok=True)
+ print("downloading logo")
+ download_url_to_file(
+ "https://www.cellpose.org/static/images/cellpose_transparent.png",
+ icon_path, progress=True)
+ if not guip_path.is_file():
+ print("downloading help window image")
+ download_url_to_file("https://www.cellpose.org/static/images/cellpose_gui.png",
+ guip_path, progress=True)
+ icon_path = str(icon_path.resolve())
+ app_icon = QtGui.QIcon()
+ app_icon.addFile(icon_path, QtCore.QSize(16, 16))
+ app_icon.addFile(icon_path, QtCore.QSize(24, 24))
+ app_icon.addFile(icon_path, QtCore.QSize(32, 32))
+ app_icon.addFile(icon_path, QtCore.QSize(48, 48))
+ app_icon.addFile(icon_path, QtCore.QSize(64, 64))
+ app_icon.addFile(icon_path, QtCore.QSize(256, 256))
+ app.setWindowIcon(app_icon)
+ app.setStyle("Fusion")
+ app.setPalette(guiparts.DarkPalette())
+ #app.setStyleSheet("QLineEdit { color: yellow }")
+
+ # models.download_model_weights() # does not exist
+ MainW(image=image, logger=logger)
+ ret = app.exec_()
+ sys.exit(ret)
+
+
+class MainW(QMainWindow):
+
+ def __init__(self, image=None, logger=None):
+ super(MainW, self).__init__()
+
+ self.logger = logger
+ pg.setConfigOptions(imageAxisOrder="row-major")
+ self.setGeometry(50, 50, 1200, 1000)
+ self.setWindowTitle(f"cellpose v{version}")
+ self.cp_path = os.path.dirname(os.path.realpath(__file__))
+ app_icon = QtGui.QIcon()
+ icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
+ icon_path = str(icon_path.resolve())
+ app_icon.addFile(icon_path, QtCore.QSize(16, 16))
+ app_icon.addFile(icon_path, QtCore.QSize(24, 24))
+ app_icon.addFile(icon_path, QtCore.QSize(32, 32))
+ app_icon.addFile(icon_path, QtCore.QSize(48, 48))
+ app_icon.addFile(icon_path, QtCore.QSize(64, 64))
+ app_icon.addFile(icon_path, QtCore.QSize(256, 256))
+ self.setWindowIcon(app_icon)
+ # rgb(150,255,150)
+ self.setStyleSheet(guiparts.stylesheet())
+
+ menus.mainmenu(self)
+ menus.editmenu(self)
+ menus.modelmenu(self)
+ menus.helpmenu(self)
+
+ self.stylePressed = """QPushButton {Text-align: center;
+ background-color: rgb(150,50,150);
+ border-color: white;
+ color:white;}
+ QToolTip {
+ background-color: black;
+ color: white;
+ border: black solid 1px
+ }"""
+ self.styleUnpressed = """QPushButton {Text-align: center;
+ background-color: rgb(50,50,50);
+ border-color: white;
+ color:white;}
+ QToolTip {
+ background-color: black;
+ color: white;
+ border: black solid 1px
+ }"""
+ self.loaded = False
+
+ # ---- MAIN WIDGET LAYOUT ---- #
+ self.cwidget = QWidget(self)
+ self.lmain = QGridLayout()
+ self.cwidget.setLayout(self.lmain)
+ self.setCentralWidget(self.cwidget)
+ self.lmain.setVerticalSpacing(0)
+ self.lmain.setContentsMargins(0, 0, 0, 10)
+
+ self.imask = 0
+ self.scrollarea = QScrollArea()
+ self.scrollarea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
+ self.scrollarea.setStyleSheet("""QScrollArea { border: none }""")
+ self.scrollarea.setWidgetResizable(True)
+ self.swidget = QWidget(self)
+ self.scrollarea.setWidget(self.swidget)
+ self.l0 = QGridLayout()
+ self.swidget.setLayout(self.l0)
+ b = self.make_buttons()
+ self.lmain.addWidget(self.scrollarea, 0, 0, 39, 9)
+
+ # ---- drawing area ---- #
+ self.win = pg.GraphicsLayoutWidget()
+
+ self.lmain.addWidget(self.win, 0, 9, 40, 30)
+
+ self.win.scene().sigMouseClicked.connect(self.plot_clicked)
+ self.win.scene().sigMouseMoved.connect(self.mouse_moved)
+ self.make_viewbox()
+ self.lmain.setColumnStretch(10, 1)
+ bwrmap = make_bwr()
+ self.bwr = bwrmap.getLookupTable(start=0.0, stop=255.0, alpha=False)
+ self.cmap = []
+ # spectral colormap
+ self.cmap.append(make_spectral().getLookupTable(start=0.0, stop=255.0,
+ alpha=False))
+ # single channel colormaps
+ for i in range(3):
+ self.cmap.append(
+ make_cmap(i).getLookupTable(start=0.0, stop=255.0, alpha=False))
+
+ if MATPLOTLIB:
+ self.colormap = (plt.get_cmap("gist_ncar")(np.linspace(0.0, .9, 1000000)) *
+ 255).astype(np.uint8)
+ np.random.seed(42) # make colors stable
+ self.colormap = self.colormap[np.random.permutation(1000000)]
+ else:
+ np.random.seed(42) # make colors stable
+ self.colormap = ((np.random.rand(1000000, 3) * 0.8 + 0.1) * 255).astype(
+ np.uint8)
+ self.NZ = 1
+ self.restore = None
+ self.ratio = 1.
+ self.reset()
+
+ # if called with image, load it
+ if image is not None:
+ self.filename = image
+ io._load_image(self, self.filename)
+
+ # training settings
+ d = datetime.datetime.now()
+ self.training_params = {
+ "model_index": 0,
+ "learning_rate": 0.1,
+ "weight_decay": 0.0001,
+ "n_epochs": 100,
+ "SGD": True,
+ "model_name": "CP" + d.strftime("_%Y%m%d_%H%M%S"),
+ }
+
+ self.load_3D = False
+ self.stitch_threshold = 0.
+ self.flow3D_smooth = 0.
+ self.anisotropy = 1.
+ self.min_size = 15
+ self.resample = True
+
+ self.setAcceptDrops(True)
+ self.win.show()
+ self.show()
+
+ def help_window(self):
+ HW = guiparts.HelpWindow(self)
+ HW.show()
+
+ def train_help_window(self):
+ THW = guiparts.TrainHelpWindow(self)
+ THW.show()
+
+ def gui_window(self):
+ EG = guiparts.ExampleGUI(self)
+ EG.show()
+
+ def make_buttons(self):
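+ """ build the left-hand control panel: Views, Drawing, Segmentation, Other models, and Image restoration boxes """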
+ self.boldfont = QtGui.QFont("Arial", 11, QtGui.QFont.Bold)
+ self.boldmedfont = QtGui.QFont("Arial", 9, QtGui.QFont.Bold)
+ self.medfont = QtGui.QFont("Arial", 9)
+ self.smallfont = QtGui.QFont("Arial", 8)
+
+ b = 0
+ self.satBox = QGroupBox("Views")
+ self.satBox.setFont(self.boldfont)
+ self.satBoxG = QGridLayout()
+ self.satBox.setLayout(self.satBoxG)
+ self.l0.addWidget(self.satBox, b, 0, 1, 9)
+
+ b0 = 0
+ self.view = 0 # 0=image, 1=flowsXY, 2=flowsZ, 3=cellprob
+ self.color = 0 # 0=RGB, 1=gray, 2=R, 3=G, 4=B
+ self.RGBDropDown = QComboBox()
+ self.RGBDropDown.addItems(
+ ["RGB", "red=R", "green=G", "blue=B", "gray", "spectral"])
+ self.RGBDropDown.setFont(self.medfont)
+ self.RGBDropDown.currentIndexChanged.connect(self.color_choose)
+ self.satBoxG.addWidget(self.RGBDropDown, b0, 0, 1, 3)
+
+ label = QLabel("[↑ / ↓ or W/S]")
+ label.setFont(self.smallfont)
+ self.satBoxG.addWidget(label, b0, 3, 1, 3)
+ label = QLabel("[R / G / B \n toggles color ]")
+ label.setFont(self.smallfont)
+ self.satBoxG.addWidget(label, b0, 6, 1, 3)
+
+ b0 += 1
+ self.ViewDropDown = QComboBox()
+ self.ViewDropDown.addItems(["image", "gradXY", "cellprob", "restored"])
+ self.ViewDropDown.setFont(self.medfont)
+ self.ViewDropDown.model().item(3).setEnabled(False)
+ self.ViewDropDown.currentIndexChanged.connect(self.update_plot)
+ self.satBoxG.addWidget(self.ViewDropDown, b0, 0, 2, 3)
+
+ label = QLabel("[pageup / pagedown]")
+ label.setFont(self.smallfont)
+ self.satBoxG.addWidget(label, b0, 3, 1, 5)
+
+ b0 += 2
+ label = QLabel("")
+ label.setToolTip(
+ "NOTE: manually changing the saturation bars does not affect normalization in segmentation"
+ )
+ self.satBoxG.addWidget(label, b0, 0, 1, 5)
+
+ self.autobtn = QCheckBox("auto-adjust saturation")
+ self.autobtn.setToolTip("sets scale-bars as normalized for segmentation")
+ self.autobtn.setFont(self.medfont)
+ self.autobtn.setChecked(True)
+ self.satBoxG.addWidget(self.autobtn, b0, 1, 1, 8)
+
+ b0 += 1
+ self.sliders = []
+ colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [100, 100, 100]]
+ colornames = ["red", "Chartreuse", "DodgerBlue"]
+ names = ["red", "green", "blue"]
+ for r in range(3):
+ b0 += 1
+ if r == 0:
+ label = QLabel('gray/red:')
+ else:
+ label = QLabel(names[r] + ":")
+ label.setStyleSheet(f"color: {colornames[r]}")
+ label.setFont(self.boldmedfont)
+ self.satBoxG.addWidget(label, b0, 0, 1, 2)
+ self.sliders.append(Slider(self, names[r], colors[r]))
+ self.sliders[-1].setMinimum(-.1)
+ self.sliders[-1].setMaximum(255.1)
+ self.sliders[-1].setValue([0, 255])
+ self.sliders[-1].setToolTip(
+ "NOTE: manually changing the saturation bars does not affect normalization in segmentation"
+ )
+ #self.sliders[-1].setTickPosition(QSlider.TicksRight)
+ self.satBoxG.addWidget(self.sliders[-1], b0, 2, 1, 7)
+
+ b += 1
+ self.drawBox = QGroupBox("Drawing")
+ self.drawBox.setFont(self.boldfont)
+ self.drawBoxG = QGridLayout()
+ self.drawBox.setLayout(self.drawBoxG)
+ self.l0.addWidget(self.drawBox, b, 0, 1, 9)
+ self.autosave = True
+
+ b0 = 0
+ self.brush_size = 3
+ self.BrushChoose = QComboBox()
+ self.BrushChoose.addItems(["1", "3", "5", "7", "9"])
+ self.BrushChoose.currentIndexChanged.connect(self.brush_choose)
+ self.BrushChoose.setFixedWidth(40)
+ self.BrushChoose.setFont(self.medfont)
+ self.drawBoxG.addWidget(self.BrushChoose, b0, 3, 1, 2)
+ label = QLabel("brush size:")
+ label.setFont(self.medfont)
+ self.drawBoxG.addWidget(label, b0, 0, 1, 3)
+
+ b0 += 1
+ # turn off masks
+ self.layer_off = False
+ self.masksOn = True
+ self.MCheckBox = QCheckBox("MASKS ON [X]")
+ self.MCheckBox.setFont(self.medfont)
+ self.MCheckBox.setChecked(True)
+ self.MCheckBox.toggled.connect(self.toggle_masks)
+ self.drawBoxG.addWidget(self.MCheckBox, b0, 0, 1, 5)
+
+ b0 += 1
+ # turn off outlines
+ self.outlinesOn = False # turn off by default
+ self.OCheckBox = QCheckBox("outlines on [Z]")
+ self.OCheckBox.setFont(self.medfont)
+ self.drawBoxG.addWidget(self.OCheckBox, b0, 0, 1, 5)
+ self.OCheckBox.setChecked(False)
+ self.OCheckBox.toggled.connect(self.toggle_masks)
+
+ b0 += 1
+ self.SCheckBox = QCheckBox("single stroke")
+ self.SCheckBox.setFont(self.medfont)
+ self.SCheckBox.setChecked(True)
+ self.SCheckBox.toggled.connect(self.autosave_on)
+ self.SCheckBox.setEnabled(True)
+ self.drawBoxG.addWidget(self.SCheckBox, b0, 0, 1, 5)
+
+ # buttons for deleting multiple cells
+ self.deleteBox = QGroupBox("delete multiple ROIs")
+ self.deleteBox.setStyleSheet("color: rgb(200, 200, 200)")
+ self.deleteBox.setFont(self.medfont)
+ self.deleteBoxG = QGridLayout()
+ self.deleteBox.setLayout(self.deleteBoxG)
+ self.drawBoxG.addWidget(self.deleteBox, 0, 5, 4, 4)
+ self.MakeDeletionRegionButton = QPushButton("region-select")
+ self.MakeDeletionRegionButton.clicked.connect(self.remove_region_cells)
+ self.deleteBoxG.addWidget(self.MakeDeletionRegionButton, 0, 0, 1, 4)
+ self.MakeDeletionRegionButton.setFont(self.smallfont)
+ self.MakeDeletionRegionButton.setFixedWidth(70)
+ self.DeleteMultipleROIButton = QPushButton("click-select")
+ self.DeleteMultipleROIButton.clicked.connect(self.delete_multiple_cells)
+ self.deleteBoxG.addWidget(self.DeleteMultipleROIButton, 1, 0, 1, 4)
+ self.DeleteMultipleROIButton.setFont(self.smallfont)
+ self.DeleteMultipleROIButton.setFixedWidth(70)
+ self.DoneDeleteMultipleROIButton = QPushButton("done")
+ self.DoneDeleteMultipleROIButton.clicked.connect(
+ self.done_remove_multiple_cells)
+ self.deleteBoxG.addWidget(self.DoneDeleteMultipleROIButton, 2, 0, 1, 2)
+ self.DoneDeleteMultipleROIButton.setFont(self.smallfont)
+ self.DoneDeleteMultipleROIButton.setFixedWidth(35)
+ self.CancelDeleteMultipleROIButton = QPushButton("cancel")
+ self.CancelDeleteMultipleROIButton.clicked.connect(self.cancel_remove_multiple)
+ self.deleteBoxG.addWidget(self.CancelDeleteMultipleROIButton, 2, 2, 1, 2)
+ self.CancelDeleteMultipleROIButton.setFont(self.smallfont)
+ self.CancelDeleteMultipleROIButton.setFixedWidth(35)
+
+ b += 1
+ b0 = 0
+ self.segBox = QGroupBox("Segmentation")
+ self.segBoxG = QGridLayout()
+ self.segBox.setLayout(self.segBoxG)
+ self.l0.addWidget(self.segBox, b, 0, 1, 9)
+ self.segBox.setFont(self.boldfont)
+
+ self.diameter = 30
+ label = QLabel("diameter (pixels):")
+ label.setFont(self.medfont)
+ label.setToolTip(
+ 'you can manually enter the approximate diameter for your cells, \nor press "calibrate" to let the model estimate it. \nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking "scale disk on")'
+ )
+ self.segBoxG.addWidget(label, b0, 0, 1, 4)
+ self.Diameter = QLineEdit()
+ self.Diameter.setToolTip(
+ 'you can manually enter the approximate diameter for your cells, \nor press "calibrate" to let the "cyto3" model estimate it. \nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking "scale disk on")'
+ )
+ self.Diameter.setText(str(self.diameter))
+ self.Diameter.setFont(self.medfont)
+ self.Diameter.returnPressed.connect(self.update_scale)
+ self.Diameter.setFixedWidth(50)
+ self.segBoxG.addWidget(self.Diameter, b0, 4, 1, 2)
+
+ # compute diameter
+ self.SizeButton = QPushButton("calibrate")
+ self.SizeButton.setFont(self.medfont)
+ self.SizeButton.clicked.connect(self.calibrate_size)
+ self.segBoxG.addWidget(self.SizeButton, b0, 6, 1, 3)
+ #self.SizeButton.setFixedWidth(65)
+ self.SizeButton.setEnabled(False)
+ self.SizeButton.setToolTip(
+ 'you can manually enter the approximate diameter for your cells, \nor press "calibrate" to let the cyto3 model estimate it. \nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking "scale disk on")'
+ )
+
+ b0 += 1
+ # choose channel
+ self.ChannelChoose = [QComboBox(), QComboBox()]
+ self.ChannelChoose[0].addItems(["0: gray", "1: red", "2: green", "3: blue"])
+ self.ChannelChoose[1].addItems(["0: none", "1: red", "2: green", "3: blue"])
+ cstr = ["chan to segment:", "chan2 (optional): "]
+ for i in range(2):
+ self.ChannelChoose[i].setFont(self.medfont)
+ label = QLabel(cstr[i])
+ label.setFont(self.medfont)
+ if i == 0:
+ label.setToolTip(
+ "this is the channel in which the cytoplasm or nuclei exist that you want to segment"
+ )
+ self.ChannelChoose[i].setToolTip(
+ "this is the channel in which the cytoplasm or nuclei exist that you want to segment"
+ )
+ else:
+ label.setToolTip(
+ "if cytoplasm model is chosen, and you also have a nuclear channel, then choose the nuclear channel for this option"
+ )
+ self.ChannelChoose[i].setToolTip(
+ "if cytoplasm model is chosen, and you also have a nuclear channel, then choose the nuclear channel for this option"
+ )
+ self.segBoxG.addWidget(label, b0 + i, 0, 1, 4)
+ self.segBoxG.addWidget(self.ChannelChoose[i], b0 + i, 4, 1, 5)
+
+ b0 += 2
+
+ # use GPU
+ self.useGPU = QCheckBox("use GPU")
+ self.useGPU.setToolTip(
+ "if you have specially installed the cuda version of torch, then you can activate this"
+ )
+ self.useGPU.setFont(self.medfont)
+ self.check_gpu()
+ self.segBoxG.addWidget(self.useGPU, b0, 0, 1, 3)
+
+ # compute segmentation with general models
+ self.net_text = ["run cyto3"]
+ nett = ["cellpose super-generalist model"]
+
+ #label = QLabel("Run:")
+ #label.setFont(self.boldfont)
+ #label.setFont(self.medfont)
+ #self.segBoxG.addWidget(label, b0, 0, 1, 2)
+ self.StyleButtons = []
+ jj = 4
+ for j in range(len(self.net_text)):
+ self.StyleButtons.append(
+ guiparts.ModelButton(self, self.net_text[j], self.net_text[j]))
+ w = 5
+ self.segBoxG.addWidget(self.StyleButtons[-1], b0, jj, 1, w)
+ jj += w
+ #self.StyleButtons[-1].setFixedWidth(140)
+ self.StyleButtons[-1].setToolTip(nett[j])
+
+ b0 += 1
+ self.roi_count = QLabel("0 ROIs")
+ self.roi_count.setFont(self.boldfont)
+ self.roi_count.setAlignment(QtCore.Qt.AlignLeft)
+ self.segBoxG.addWidget(self.roi_count, b0, 0, 1, 4)
+
+ self.progress = QProgressBar(self)
+ self.segBoxG.addWidget(self.progress, b0, 4, 1, 5)
+
+ b0 += 1
+ self.segaBox = QCollapsible("additional settings")
+ self.segaBox.setFont(self.medfont)
+ self.segaBox._toggle_btn.setFont(self.medfont)
+ self.segaBoxG = QGridLayout()
+ _content = QWidget()
+ _content.setLayout(self.segaBoxG)
+ _content.setMaximumHeight(0)
+ _content.setMinimumHeight(0)
+ #_content.layout().setContentsMargins(QtCore.QMargins(0, -20, -20, -20))
+ self.segaBox.setContent(_content)
+ self.segBoxG.addWidget(self.segaBox, b0, 0, 1, 9)
+
+ b0 = 0
+ # post-hoc parameter tuning
+ label = QLabel("flow\nthreshold:")
+ label.setToolTip(
+ "threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run"
+ )
+ label.setFont(self.medfont)
+ self.segaBoxG.addWidget(label, b0, 0, 1, 2)
+ self.flow_threshold = QLineEdit()
+ self.flow_threshold.setText("0.4")
+ self.flow_threshold.returnPressed.connect(self.compute_cprob)
+ self.flow_threshold.setFixedWidth(40)
+ self.flow_threshold.setFont(self.medfont)
+ self.segaBoxG.addWidget(self.flow_threshold, b0, 2, 1, 2)
+ self.flow_threshold.setToolTip(
+ "threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run"
+ )
+
+ label = QLabel("cellprob\nthreshold:")
+ label.setToolTip(
+ "threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run"
+ )
+ label.setFont(self.medfont)
+ self.segaBoxG.addWidget(label, b0, 4, 1, 2)
+ self.cellprob_threshold = QLineEdit()
+ self.cellprob_threshold.setText("0.0")
+ self.cellprob_threshold.returnPressed.connect(self.compute_cprob)
+ self.cellprob_threshold.setFixedWidth(40)
+ self.cellprob_threshold.setFont(self.medfont)
+ self.cellprob_threshold.setToolTip(
+ "threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run"
+ )
+ self.segaBoxG.addWidget(self.cellprob_threshold, b0, 6, 1, 2)
+
+ b0 += 1
+ label = QLabel("norm percentiles:")
+ label.setToolTip(
+ "sets normalization percentiles for segmentation and denoising\n(pixels at lower percentile set to 0.0 and at upper set to 1.0 for network)"
+ )
+ label.setFont(self.medfont)
+ self.segaBoxG.addWidget(label, b0, 0, 1, 8)
+
+ b0 += 1
+ self.norm_vals = [1., 99.]
+ self.norm_edits = []
+ labels = ["lower", "upper"]
+ tooltips = [
+ "pixels at this percentile set to 0 (default 1.0)",
+ "pixels at this percentile set to 1 (default 99.0)"
+ ]
+ for p in range(2):
+ label = QLabel(f"{labels[p]}:")
+ label.setToolTip(tooltips[p])
+ label.setFont(self.medfont)
+ self.segaBoxG.addWidget(label, b0, 4 * (p % 2), 1, 2)
+ self.norm_edits.append(QLineEdit())
+ self.norm_edits[p].setText(str(self.norm_vals[p]))
+ self.norm_edits[p].setFixedWidth(40)
+ self.norm_edits[p].setFont(self.medfont)
+ self.segaBoxG.addWidget(self.norm_edits[p], b0, 4 * (p % 2) + 2, 1, 2)
+ self.norm_edits[p].setToolTip(tooltips[p])
+
+ b0 += 1
+ label = QLabel("niter dynamics:")
+ label.setFont(self.medfont)
+ label.setToolTip(
+ "number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria"
+ )
+ self.segaBoxG.addWidget(label, b0, 0, 1, 4)
+ self.niter = QLineEdit()
+ self.niter.setText("0")
+ self.niter.setFixedWidth(40)
+ self.niter.setFont(self.medfont)
+ self.niter.setToolTip(
+ "number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria"
+ )
+ self.segaBoxG.addWidget(self.niter, b0, 4, 1, 2)
+
+ b += 1
+ b0 = 0
+ self.modelBox = QGroupBox("Other models")
+ self.modelBoxG = QGridLayout()
+ self.modelBox.setLayout(self.modelBoxG)
+ self.l0.addWidget(self.modelBox, b, 0, 1, 9)
+ self.modelBox.setFont(self.boldfont)
+ # choose models
+ self.ModelChooseC = QComboBox()
+ self.ModelChooseC.setFont(self.medfont)
+ current_index = 0
+ self.ModelChooseC.addItems(["custom models"])
+ if len(self.model_strings) > 0:
+ self.ModelChooseC.addItems(self.model_strings)
+ self.ModelChooseC.setFixedWidth(175)
+ self.ModelChooseC.setCurrentIndex(current_index)
+ tipstr = 'add or train your own models in the "Models" file menu and choose model here'
+ self.ModelChooseC.setToolTip(tipstr)
+ self.ModelChooseC.activated.connect(lambda: self.model_choose(custom=True))
+ self.modelBoxG.addWidget(self.ModelChooseC, b0, 0, 1, 8)
+
+ # compute segmentation w/ custom model
+ self.ModelButtonC = QPushButton(u"run")
+ self.ModelButtonC.setFont(self.medfont)
+ self.ModelButtonC.setFixedWidth(35)
+ self.ModelButtonC.clicked.connect(
+ lambda: self.compute_segmentation(custom=True))
+ self.modelBoxG.addWidget(self.ModelButtonC, b0, 8, 1, 1)
+ self.ModelButtonC.setEnabled(False)
+
+ self.net_names = [
+ "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3",
+ "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3",
+ "cyto", "cyto2", "CPx"]
+
+ nett = [
+ "nuclei", "cellpose (cyto2_cp3)", "tissuenet_cp3", "livecell_cp3",
+ "yeast_PhC_cp3", "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3",
+ "deepbacs_cp3", "cyto", "cyto2",
+ "CPx (from Cellpose2)"
+ ]
+ b0 += 1
+ self.ModelChooseB = QComboBox()
+ self.ModelChooseB.setFont(self.medfont)
+ self.ModelChooseB.addItems(["dataset-specific models"])
+ self.ModelChooseB.addItems(nett)
+ self.ModelChooseB.setFixedWidth(175)
+ tipstr = "dataset-specific models"
+ self.ModelChooseB.setToolTip(tipstr)
+ self.ModelChooseB.activated.connect(lambda: self.model_choose(custom=False))
+ self.modelBoxG.addWidget(self.ModelChooseB, b0, 0, 1, 8)
+
+ # compute segmentation w/ cp model
+ self.ModelButtonB = QPushButton(u"run")
+ self.ModelButtonB.setFont(self.medfont)
+ self.ModelButtonB.setFixedWidth(35)
+ self.ModelButtonB.clicked.connect(
+ lambda: self.compute_segmentation(custom=False))
+ self.modelBoxG.addWidget(self.ModelButtonB, b0, 8, 1, 1)
+ self.ModelButtonB.setEnabled(False)
+
+ b += 1
+ self.denoiseBox = QGroupBox("Image restoration")
+ self.denoiseBox.setFont(self.boldfont)
+ self.denoiseBoxG = QGridLayout()
+ self.denoiseBox.setLayout(self.denoiseBoxG)
+ self.l0.addWidget(self.denoiseBox, b, 0, 1, 9)
+
+ b0 = 0
+
+ # DENOISING
+ self.DenoiseButtons = []
+ nett = [
+ "clear restore/filter",
+ "filter image (settings below)",
+ "denoise (please set cell diameter first)",
+ "deblur (please set cell diameter first)",
+ "upsample to 30. diameter (cyto3) or 17. diameter (nuclei) (please set cell diameter first) (disabled in 3D)",
+ "one-click model trained to denoise+deblur+upsample (please set cell diameter first)"
+ ]
+ self.denoise_text = ["none", "filter", "denoise", "deblur", "upsample", "one-click"]
+ self.restore = None
+ self.ratio = 1.
+ jj = 0
+ w = 3
+ for j in range(len(self.denoise_text)):
+ self.DenoiseButtons.append(
+ guiparts.DenoiseButton(self, self.denoise_text[j]))
+ self.denoiseBoxG.addWidget(self.DenoiseButtons[-1], b0, jj, 1, w)
+ self.DenoiseButtons[-1].setFixedWidth(75)
+ self.DenoiseButtons[-1].setToolTip(nett[j])
+ self.DenoiseButtons[-1].setFont(self.medfont)
+ b0 += 1 if j%2==1 else 0
+ jj = 0 if j%2==1 else jj + w
+
+ # b0+=1
+ self.save_norm = QCheckBox("save restored/filtered image")
+ self.save_norm.setFont(self.medfont)
+ self.save_norm.setToolTip("save restored/filtered image in _seg.npy file")
+ self.save_norm.setChecked(True)
+ # self.denoiseBoxG.addWidget(self.save_norm, b0, 0, 1, 8)
+
+ b0 -= 3
+ label = QLabel("restore-dataset:")
+ label.setToolTip(
+ "choose dataset and click [denoise], [deblur], [upsample], or [one-click]")
+ label.setFont(self.medfont)
+ self.denoiseBoxG.addWidget(label, b0, 6, 1, 3)
+
+ b0 += 1
+ self.DenoiseChoose = QComboBox()
+ self.DenoiseChoose.setFont(self.medfont)
+ self.DenoiseChoose.addItems(["cyto3", "cyto2", "nuclei"])
+ self.DenoiseChoose.setFixedWidth(85)
+ tipstr = "choose model type and click [denoise], [deblur], or [upsample]"
+ self.DenoiseChoose.setToolTip(tipstr)
+ self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 6, 1, 3)
+
+ b0 += 2
+ # FILTERING
+ self.filtBox = QCollapsible("custom filter settings")
+ self.filtBox._toggle_btn.setFont(self.medfont)
+ self.filtBoxG = QGridLayout()
+ _content = QWidget()
+ _content.setLayout(self.filtBoxG)
+ _content.setMaximumHeight(0)
+ _content.setMinimumHeight(0)
+ #_content.layout().setContentsMargins(QtCore.QMargins(0, -20, -20, -20))
+ self.filtBox.setContent(_content)
+ self.denoiseBoxG.addWidget(self.filtBox, b0, 0, 1, 9)
+
+ self.filt_vals = [0., 0., 0., 0.]
+ self.filt_edits = []
+ labels = [
+ "sharpen\nradius", "smooth\nradius", "tile_norm\nblocksize",
+ "tile_norm\nsmooth3D"
+ ]
+ tooltips = [
+ "set size of surround-subtraction filter for sharpening image",
+ "set size of gaussian filter for smoothing image",
+ "set size of tiles to use to normalize image",
+ "set amount of smoothing of normalization values across planes"
+ ]
+
+ for p in range(4):
+ label = QLabel(f"{labels[p]}:")
+ label.setToolTip(tooltips[p])
+ label.setFont(self.medfont)
+ self.filtBoxG.addWidget(label, b0 + p // 2, 4 * (p % 2), 1, 2)
+ self.filt_edits.append(QLineEdit())
+ self.filt_edits[p].setText(str(self.filt_vals[p]))
+ self.filt_edits[p].setFixedWidth(40)
+ self.filt_edits[p].setFont(self.medfont)
+ self.filtBoxG.addWidget(self.filt_edits[p], b0 + p // 2, 4 * (p % 2) + 2, 1,
+ 2)
+ self.filt_edits[p].setToolTip(tooltips[p])
+
+ b0 += 3
+ self.norm3D_cb = QCheckBox("norm3D")
+ self.norm3D_cb.setFont(self.medfont)
+ self.norm3D_cb.setChecked(True)
+ self.norm3D_cb.setToolTip("run same normalization across planes")
+ self.filtBoxG.addWidget(self.norm3D_cb, b0, 0, 1, 3)
+
+ self.invert_cb = QCheckBox("invert")
+ self.invert_cb.setFont(self.medfont)
+ self.invert_cb.setToolTip("invert image")
+ self.filtBoxG.addWidget(self.invert_cb, b0, 3, 1, 3)
+
+ b += 1
+ self.l0.addWidget(QLabel(""), b, 0, 1, 9)
+ self.l0.setRowStretch(b, 100)
+
+ b += 1
+ # scale toggle
+ self.scale_on = True
+ self.ScaleOn = QCheckBox("scale disk on")
+ self.ScaleOn.setFont(self.medfont)
+ self.ScaleOn.setStyleSheet("color: rgb(150,50,150);")
+ self.ScaleOn.setChecked(True)
+ self.ScaleOn.setToolTip("see current diameter as red disk at bottom")
+ self.ScaleOn.toggled.connect(self.toggle_scale)
+ self.l0.addWidget(self.ScaleOn, b, 0, 1, 5)
+
+ return b
+
+ def level_change(self, r):
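+ """ update the saturation range for one channel when its slider moves """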
+ r = ["red", "green", "blue"].index(r)
+ if self.loaded:
+ sval = self.sliders[r].value()
+ self.saturation[r][self.currentZ] = sval
+ if not self.autobtn.isChecked():
+ for r in range(3):
+ for i in range(len(self.saturation[r])):
+ self.saturation[r][i] = self.saturation[r][self.currentZ]
+ self.update_plot()
+
+ def keyPressEvent(self, event):
+ if self.loaded:
+ if not (event.modifiers() &
+ (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
+ QtCore.Qt.AltModifier) or self.in_stroke):
+ updated = False
+ if len(self.current_point_set) > 0:
+ if event.key() == QtCore.Qt.Key_Return:
+ self.add_set()
+ else:
+ nviews = self.ViewDropDown.count() - 1
+ nviews += int(
+ self.ViewDropDown.model().item(self.ViewDropDown.count() -
+ 1).isEnabled())
+ if event.key() == QtCore.Qt.Key_X:
+ self.MCheckBox.toggle()
+ if event.key() == QtCore.Qt.Key_Z:
+ self.OCheckBox.toggle()
+ if event.key() == QtCore.Qt.Key_Left or event.key(
+ ) == QtCore.Qt.Key_A:
+ self.get_prev_image()
+ elif event.key() == QtCore.Qt.Key_Right or event.key(
+ ) == QtCore.Qt.Key_D:
+ self.get_next_image()
+ elif event.key() == QtCore.Qt.Key_PageDown:
+ self.view = (self.view + 1) % (nviews)
+ self.ViewDropDown.setCurrentIndex(self.view)
+ elif event.key() == QtCore.Qt.Key_PageUp:
+ self.view = (self.view - 1) % (nviews)
+ self.ViewDropDown.setCurrentIndex(self.view)
+
+ # can change background or stroke size if cell not finished
+ if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
+ self.color = (self.color - 1) % (6)
+ self.RGBDropDown.setCurrentIndex(self.color)
+ elif event.key() == QtCore.Qt.Key_Down or event.key(
+ ) == QtCore.Qt.Key_S:
+ self.color = (self.color + 1) % (6)
+ self.RGBDropDown.setCurrentIndex(self.color)
+ elif event.key() == QtCore.Qt.Key_R:
+ if self.color != 1:
+ self.color = 1
+ else:
+ self.color = 0
+ self.RGBDropDown.setCurrentIndex(self.color)
+ elif event.key() == QtCore.Qt.Key_G:
+ if self.color != 2:
+ self.color = 2
+ else:
+ self.color = 0
+ self.RGBDropDown.setCurrentIndex(self.color)
+ elif event.key() == QtCore.Qt.Key_B:
+ if self.color != 3:
+ self.color = 3
+ else:
+ self.color = 0
+ self.RGBDropDown.setCurrentIndex(self.color)
+ elif (event.key() == QtCore.Qt.Key_Comma or
+ event.key() == QtCore.Qt.Key_Period):
+ count = self.BrushChoose.count()
+ gci = self.BrushChoose.currentIndex()
+ if event.key() == QtCore.Qt.Key_Comma:
+ gci = max(0, gci - 1)
+ else:
+ gci = min(count - 1, gci + 1)
+ self.BrushChoose.setCurrentIndex(gci)
+ self.brush_choose()
+ if not updated:
+ self.update_plot()
+ if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
+ self.p0.keyPressEvent(event)
+
+ def autosave_on(self):
+ if self.SCheckBox.isChecked():
+ self.autosave = True
+ else:
+ self.autosave = False
+
+ def check_gpu(self, torch=True):
+ # also decide whether or not to use torch
+ self.useGPU.setChecked(False)
+ self.useGPU.setEnabled(False)
+ if core.use_gpu(use_torch=True):
+ self.useGPU.setEnabled(True)
+ self.useGPU.setChecked(True)
+ else:
+ self.useGPU.setStyleSheet("color: rgb(80,80,80);")
+
+ def get_channels(self):
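+ """ read the channel-to-segment choices from the dropdowns, corrected for the number of channels in the loaded image """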
+ channels = [
+ self.ChannelChoose[0].currentIndex(), self.ChannelChoose[1].currentIndex()
+ ]
+ if hasattr(self, "current_model"):
+ if self.current_model == "nuclei":
+ channels[1] = 0
+ if channels[0] == 0:
+ channels[1] = 0
+ if self.nchan == 1:
+ channels = [0, 0]
+ elif self.nchan == 2:
+ if channels[0] == 3:
+ channels[0] = 1 if channels[1] != 1 else 2
+ print(
+ f"GUI_WARNING: only two channels in image, cannot use blue channel, changing channels"
+ )
+ if channels[1] == 3:
+ channels[1] = 1 if channels[0] != 1 else 2
+ print(
+ f"GUI_WARNING: only two channels in image, cannot use blue channel, changing channels"
+ )
+ self.ChannelChoose[0].setCurrentIndex(channels[0])
+ self.ChannelChoose[1].setCurrentIndex(channels[1])
+ return channels
+
+ def model_choose(self, custom=False):
+ index = self.ModelChooseC.currentIndex(
+ ) if custom else self.ModelChooseB.currentIndex()
+ if index > 0:
+ if custom:
+ model_name = self.ModelChooseC.currentText()
+ else:
+ model_name = self.net_names[index - 1]
+ print(f"GUI_INFO: selected model {model_name}, loading now")
+ self.initialize_model(model_name=model_name, custom=custom)
+ self.diameter = self.model.diam_labels
+ self.Diameter.setText("%0.2f" % self.diameter)
+ print(
+ f"GUI_INFO: diameter set to {self.diameter: 0.2f} (but can be changed)")
+
+ def calibrate_size(self):
+ self.initialize_model(model_name="cyto3")
+ diams, _ = self.model.sz.eval(self.stack[self.currentZ].copy(),
+ channels=self.get_channels(),
+ progress=self.progress)
+ diams = np.maximum(5.0, diams)
+ self.logger.info("estimated diameter of cells using %s model = %0.1f pixels" %
+ (self.current_model, diams))
+ self.Diameter.setText("%0.1f" % diams)
+ self.diameter = diams
+ self.update_scale()
+ self.progress.setValue(100)
+
+ def toggle_scale(self):
+ if self.scale_on:
+ self.p0.removeItem(self.scale)
+ self.scale_on = False
+ else:
+ self.p0.addItem(self.scale)
+ self.scale_on = True
+
+ def enable_buttons(self):
+ if len(self.model_strings) > 0:
+ self.ModelButtonC.setEnabled(True)
+ for i in range(len(self.StyleButtons)):
+ self.StyleButtons[i].setEnabled(True)
+ for i in range(len(self.DenoiseButtons)):
+ self.DenoiseButtons[i].setEnabled(True)
+ if self.load_3D:
+ self.DenoiseButtons[-2].setEnabled(False)
+ self.ModelButtonB.setEnabled(True)
+ self.SizeButton.setEnabled(True)
+ self.newmodel.setEnabled(True)
+ self.loadMasks.setEnabled(True)
+
+ for n in range(self.nchan):
+ self.sliders[n].setEnabled(True)
+ for n in range(self.nchan, 3):
+ self.sliders[n].setEnabled(True)
+
+ self.toggle_mask_ops()
+
+ self.update_plot()
+ self.setWindowTitle(self.filename)
+
+ def disable_buttons_removeROIs(self):
+ if len(self.model_strings) > 0:
+ self.ModelButtonC.setEnabled(False)
+ for i in range(len(self.StyleButtons)):
+ self.StyleButtons[i].setEnabled(False)
+ self.ModelButtonB.setEnabled(False)
+ self.SizeButton.setEnabled(False)
+ self.newmodel.setEnabled(False)
+ self.loadMasks.setEnabled(False)
+ self.saveSet.setEnabled(False)
+ self.savePNG.setEnabled(False)
+ self.saveFlows.setEnabled(False)
+ self.saveOutlines.setEnabled(False)
+ self.saveROIs.setEnabled(False)
+
+ self.MakeDeletionRegionButton.setEnabled(False)
+ self.DeleteMultipleROIButton.setEnabled(False)
+ self.DoneDeleteMultipleROIButton.setEnabled(True)
+ self.CancelDeleteMultipleROIButton.setEnabled(True)
+
+ def toggle_mask_ops(self):
+ self.update_layer()
+ self.toggle_saving()
+ self.toggle_removals()
+
+ def toggle_saving(self):
+ if self.ncells > 0:
+ self.saveSet.setEnabled(True)
+ self.savePNG.setEnabled(True)
+ self.saveFlows.setEnabled(True)
+ self.saveOutlines.setEnabled(True)
+ self.saveROIs.setEnabled(True)
+ else:
+ self.saveSet.setEnabled(False)
+ self.savePNG.setEnabled(False)
+ self.saveFlows.setEnabled(False)
+ self.saveOutlines.setEnabled(False)
+ self.saveROIs.setEnabled(False)
+
+ def toggle_removals(self):
+ if self.ncells > 0:
+ self.ClearButton.setEnabled(True)
+ self.remcell.setEnabled(True)
+ self.undo.setEnabled(True)
+ self.MakeDeletionRegionButton.setEnabled(True)
+ self.DeleteMultipleROIButton.setEnabled(True)
+ self.DoneDeleteMultipleROIButton.setEnabled(False)
+ self.CancelDeleteMultipleROIButton.setEnabled(False)
+ else:
+ self.ClearButton.setEnabled(False)
+ self.remcell.setEnabled(False)
+ self.undo.setEnabled(False)
+ self.MakeDeletionRegionButton.setEnabled(False)
+ self.DeleteMultipleROIButton.setEnabled(False)
+ self.DoneDeleteMultipleROIButton.setEnabled(False)
+ self.CancelDeleteMultipleROIButton.setEnabled(False)
+
+ def remove_action(self):
+ if self.selected > 0:
+ self.remove_cell(self.selected)
+
+ def undo_action(self):
+ if (len(self.strokes) > 0 and self.strokes[-1][0][0] == self.currentZ):
+ self.remove_stroke()
+ else:
+ # remove previous cell
+ if self.ncells > 0:
+ self.remove_cell(self.ncells)
+
+ def undo_remove_action(self):
+ self.undo_remove_cell()
+
+ def get_files(self):
+ folder = os.path.dirname(self.filename)
+ mask_filter = "_masks"
+ images = get_image_files(folder, mask_filter)
+ fnames = [os.path.split(images[k])[-1] for k in range(len(images))]
+ f0 = os.path.split(self.filename)[-1]
+ idx = np.nonzero(np.array(fnames) == f0)[0][0]
+ return images, idx
+
+ def get_prev_image(self):
+ images, idx = self.get_files()
+ idx = (idx - 1) % len(images)
+ io._load_image(self, filename=images[idx])
+
+ def get_next_image(self, load_seg=True):
+ images, idx = self.get_files()
+ idx = (idx + 1) % len(images)
+ io._load_image(self, filename=images[idx], load_seg=load_seg)
+
+ def dragEnterEvent(self, event):
+ if event.mimeData().hasUrls():
+ event.accept()
+ else:
+ event.ignore()
+
+ def dropEvent(self, event):
+ files = [u.toLocalFile() for u in event.mimeData().urls()]
+ if os.path.splitext(files[0])[-1] == ".npy":
+ io._load_seg(self, filename=files[0], load_3D=self.load_3D)
+ else:
+ io._load_image(self, filename=files[0], load_seg=True, load_3D=self.load_3D)
+
+ def toggle_masks(self):
+ if self.MCheckBox.isChecked():
+ self.masksOn = True
+ else:
+ self.masksOn = False
+ if self.OCheckBox.isChecked():
+ self.outlinesOn = True
+ else:
+ self.outlinesOn = False
+ if not self.masksOn and not self.outlinesOn:
+ self.p0.removeItem(self.layer)
+ self.layer_off = True
+ else:
+ if self.layer_off:
+ self.p0.addItem(self.layer)
+ self.draw_layer()
+ self.update_layer()
+ if self.loaded:
+ self.update_plot()
+ self.update_layer()
+
+ def make_viewbox(self):
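+ """ create the pyqtgraph view box holding the image, the mask/outline layer, and the scale disk """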
+ self.p0 = guiparts.ViewBoxNoRightDrag(parent=self, lockAspect=True,
+ name="plot1", border=[100, 100,
+ 100], invertY=True)
+ self.p0.setCursor(QtCore.Qt.CrossCursor)
+ self.brush_size = 3
+ self.win.addItem(self.p0, 0, 0, rowspan=1, colspan=1)
+ self.p0.setMenuEnabled(False)
+ self.p0.setMouseEnabled(x=True, y=True)
+ self.img = pg.ImageItem(viewbox=self.p0, parent=self)
+ self.img.autoDownsample = False
+ self.layer = guiparts.ImageDraw(viewbox=self.p0, parent=self)
+ self.layer.setLevels([0, 255])
+ self.scale = pg.ImageItem(viewbox=self.p0, parent=self)
+ self.scale.setLevels([0, 255])
+ self.p0.scene().contextMenuItem = self.p0
+ #self.p0.setMouseEnabled(x=False,y=False)
+ self.Ly, self.Lx = 512, 512
+ self.p0.addItem(self.img)
+ self.p0.addItem(self.layer)
+ self.p0.addItem(self.scale)
+
+ def reset(self):
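+ """ clear all loaded data and return the GUI to its empty default state """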
+ # ---- start sets of points ---- #
+ self.selected = 0
+ self.nchan = 3
+ self.loaded = False
+ self.channel = [0, 1]
+ self.current_point_set = []
+ self.in_stroke = False
+ self.strokes = []
+ self.stroke_appended = True
+ self.resize = False
+ self.ncells = 0
+ self.zdraw = []
+ self.removed_cell = []
+ self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
+
+ # -- zero out image stack -- #
+ self.opacity = 128 # how opaque masks should be
+ self.outcolor = [200, 200, 255, 200]
+ self.NZ, self.Ly, self.Lx = 1, 224, 224
+ self.saturation = []
+ for r in range(3):
+ self.saturation.append([[0, 255] for n in range(self.NZ)])
+ self.sliders[r].setValue([0, 255])
+ self.sliders[r].setEnabled(False)
+ self.sliders[r].show()
+ self.currentZ = 0
+ self.flows = [[], [], [], [], [[]]]
+ # masks matrix
+ # image matrix with a scale disk
+ self.stack = np.zeros((1, self.Ly, self.Lx, 3))
+ self.Lyr, self.Lxr = self.Ly, self.Lx
+ self.Ly0, self.Lx0 = self.Ly, self.Lx
+ self.radii = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
+ self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
+ self.cellpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
+ self.outpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
+ if self.restore and "upsample" in self.restore:
+ self.cellpix_resize = self.cellpix
+ self.cellpix_orig = self.cellpix
+ self.outpix_resize = self.cellpix
+ self.outpix_orig = self.cellpix
+ self.ismanual = np.zeros(0, "bool")
+
+ # -- set menus to default -- #
+ self.color = 0
+ self.RGBDropDown.setCurrentIndex(self.color)
+ self.view = 0
+ self.ViewDropDown.setCurrentIndex(0)
+ self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
+ self.delete_restore()
+
+ self.clear_all()
+
+ #self.update_plot()
+ self.filename = []
+ self.loaded = False
+ self.recompute_masks = False
+
+ self.deleting_multiple = False
+ self.removing_cells_list = []
+ self.removing_region = False
+ self.remove_roi_obj = None
+
+ def delete_restore(self):
+ """ delete restored imgs but don't reset settings """
+ if hasattr(self, "stack_filtered"):
+ del self.stack_filtered
+ if hasattr(self, "cellpix_orig"):
+ self.cellpix = self.cellpix_orig.copy()
+ self.outpix = self.outpix_orig.copy()
+ del self.outpix_orig, self.outpix_resize
+ del self.cellpix_orig, self.cellpix_resize
+
+ def clear_restore(self):
+ """ delete restored imgs and reset settings """
+ print("GUI_INFO: clearing restored image")
+ self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
+ if self.ViewDropDown.currentIndex() == self.ViewDropDown.count() - 1:
+ self.ViewDropDown.setCurrentIndex(0)
+ self.delete_restore()
+ self.restore = None
+ self.ratio = 1.
+ self.set_normalize_params(self.get_normalize_params())
+
+ def brush_choose(self):
+ self.brush_size = self.BrushChoose.currentIndex() * 2 + 1
+ if self.loaded:
+ self.layer.setDrawKernel(kernel_size=self.brush_size)
+ self.update_layer()
+
+ def clear_all(self):
+ self.prev_selected = 0
+ self.selected = 0
+ if self.restore and "upsample" in self.restore:
+ self.layerz = 0 * np.ones((self.Lyr, self.Lxr, 4), np.uint8)
+ self.cellpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16)
+ self.outpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16)
+ self.cellpix_resize = self.cellpix.copy()
+ self.outpix_resize = self.outpix.copy()
+ self.cellpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16)
+ self.outpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16)
+ else:
+ self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
+ self.cellpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16)
+ self.outpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16)
+
+ self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
+ self.ncells = 0
+ self.toggle_removals()
+ self.update_scale()
+ self.update_layer()
+
+ def select_cell(self, idx):
+ self.prev_selected = self.selected
+ self.selected = idx
+ if self.selected > 0:
+ z = self.currentZ
+ self.layerz[self.cellpix[z] == idx] = np.array(
+ [255, 255, 255, self.opacity])
+ self.update_layer()
+
+ def select_cell_multi(self, idx):
+ if idx > 0:
+ z = self.currentZ
+ self.layerz[self.cellpix[z] == idx] = np.array(
+ [255, 255, 255, self.opacity])
+ self.update_layer()
+
+ def unselect_cell(self):
+ if self.selected > 0:
+ idx = self.selected
+ if idx < self.ncells + 1:
+ z = self.currentZ
+ self.layerz[self.cellpix[z] == idx] = np.append(
+ self.cellcolors[idx], self.opacity)
+ if self.outlinesOn:
+ self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype(
+ np.uint8)
+ #[0,0,0,self.opacity])
+ self.update_layer()
+ self.selected = 0
+
+ def unselect_cell_multi(self, idx):
+ z = self.currentZ
+ self.layerz[self.cellpix[z] == idx] = np.append(self.cellcolors[idx],
+ self.opacity)
+ if self.outlinesOn:
+ self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype(
+ np.uint8)
+ # [0,0,0,self.opacity])
+ self.update_layer()
+
+ def remove_cell(self, idx):
+ if isinstance(idx, (int, np.integer)):
+ idx = [idx]
+ # because the function remove_single_cell updates the state of the cellpix and outpix arrays
+ # by reindexing cells to avoid gaps in the indices, we need to remove the cells in reverse order
+ # so that the indices are correct
+ idx.sort(reverse=True)
+ for i in idx:
+ self.remove_single_cell(i)
+ self.ncells -= len(idx) # _save_sets uses ncells
+
+ if self.ncells == 0:
+ self.ClearButton.setEnabled(False)
+ if self.NZ == 1:
+ io._save_sets_with_check(self)
+
+ self.update_layer()
+
+ def remove_single_cell(self, idx):
+ # remove from manual array
+ self.selected = 0
+ if self.NZ > 1:
+ zextent = ((self.cellpix == idx).sum(axis=(1, 2)) > 0).nonzero()[0]
+ else:
+ zextent = [0]
+ for z in zextent:
+ cp = self.cellpix[z] == idx
+ op = self.outpix[z] == idx
+ # remove from self.cellpix and self.outpix
+ self.cellpix[z, cp] = 0
+ self.outpix[z, op] = 0
+ if z == self.currentZ:
+ # remove from mask layer
+ self.layerz[cp] = np.array([0, 0, 0, 0])
+
+ # reduce other pixels by -1
+ self.cellpix[self.cellpix > idx] -= 1
+ self.outpix[self.outpix > idx] -= 1
+
+ if self.NZ == 1:
+ self.removed_cell = [
+ self.ismanual[idx - 1], self.cellcolors[idx],
+ np.nonzero(cp),
+ np.nonzero(op)
+ ]
+ self.redo.setEnabled(True)
+ ar, ac = self.removed_cell[2]
+ d = datetime.datetime.now()
+ self.track_changes.append(
+ [d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]])
+ # remove cell from lists
+ self.ismanual = np.delete(self.ismanual, idx - 1)
+ self.cellcolors = np.delete(self.cellcolors, [idx], axis=0)
+ del self.zdraw[idx - 1]
+ print("GUI_INFO: removed cell %d" % (idx - 1))
+
+ def remove_region_cells(self):
+ if self.removing_cells_list:
+ for idx in self.removing_cells_list:
+ self.unselect_cell_multi(idx)
+ self.removing_cells_list.clear()
+ self.disable_buttons_removeROIs()
+ self.removing_region = True
+
+ self.clear_multi_selected_cells()
+
+ # make roi region here in center of view, making ROI half the size of the view
+ roi_width = self.p0.viewRect().width() / 2
+ x_loc = self.p0.viewRect().x() + (roi_width / 2)
+ roi_height = self.p0.viewRect().height() / 2
+ y_loc = self.p0.viewRect().y() + (roi_height / 2)
+
+ pos = [x_loc, y_loc]
+ roi = pg.RectROI(pos, [roi_width, roi_height], pen=pg.mkPen("y", width=2),
+ removable=True)
+ roi.sigRemoveRequested.connect(self.remove_roi)
+ roi.sigRegionChangeFinished.connect(self.roi_changed)
+ self.p0.addItem(roi)
+ self.remove_roi_obj = roi
+ self.roi_changed(roi)
+
+ def delete_multiple_cells(self):
+ self.unselect_cell()
+ self.disable_buttons_removeROIs()
+ self.DoneDeleteMultipleROIButton.setEnabled(True)
+ self.MakeDeletionRegionButton.setEnabled(True)
+ self.CancelDeleteMultipleROIButton.setEnabled(True)
+ self.deleting_multiple = True
+
+ def done_remove_multiple_cells(self):
+ self.deleting_multiple = False
+ self.removing_region = False
+ self.DoneDeleteMultipleROIButton.setEnabled(False)
+ self.MakeDeletionRegionButton.setEnabled(False)
+ self.CancelDeleteMultipleROIButton.setEnabled(False)
+
+ if self.removing_cells_list:
+ self.removing_cells_list = list(set(self.removing_cells_list))
+ display_remove_list = [i - 1 for i in self.removing_cells_list]
+ print(f"GUI_INFO: removing cells: {display_remove_list}")
+ self.remove_cell(self.removing_cells_list)
+ self.removing_cells_list.clear()
+ self.unselect_cell()
+ self.enable_buttons()
+
+ if self.remove_roi_obj is not None:
+ self.remove_roi(self.remove_roi_obj)
+
+ def merge_cells(self, idx):
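+ """ merge the clicked ROI with the previously selected ROI and redraw the combined outline """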
+ self.prev_selected = self.selected
+ self.selected = idx
+ if self.selected != self.prev_selected:
+ for z in range(self.NZ):
+ ar0, ac0 = np.nonzero(self.cellpix[z] == self.prev_selected)
+ ar1, ac1 = np.nonzero(self.cellpix[z] == self.selected)
+ touching = np.logical_and((ar0[:, np.newaxis] - ar1) < 3,
+ (ac0[:, np.newaxis] - ac1) < 3).sum()
+ ar = np.hstack((ar0, ar1))
+ ac = np.hstack((ac0, ac1))
+ vr0, vc0 = np.nonzero(self.outpix[z] == self.prev_selected)
+ vr1, vc1 = np.nonzero(self.outpix[z] == self.selected)
+ self.outpix[z, vr0, vc0] = 0
+ self.outpix[z, vr1, vc1] = 0
+ if touching > 0:
+ mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8)
+ mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_NONE)
+ pvc, pvr = contours[-2][0].squeeze().T
+ vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
+
+ else:
+ vr = np.hstack((vr0, vr1))
+ vc = np.hstack((vc0, vc1))
+ color = self.cellcolors[self.prev_selected]
+ self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected)
+ self.remove_cell(self.selected)
+ print("GUI_INFO: merged two cells")
+ self.update_layer()
+ io._save_sets_with_check(self)
+ self.undo.setEnabled(False)
+ self.redo.setEnabled(False)
+
+ def undo_remove_cell(self):
+ if len(self.removed_cell) > 0:
+ z = 0
+ ar, ac = self.removed_cell[2]
+ vr, vc = self.removed_cell[3]
+ color = self.removed_cell[1]
+ self.draw_mask(z, ar, ac, vr, vc, color)
+ self.toggle_mask_ops()
+ self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], axis=0)
+ self.ncells += 1
+ self.ismanual = np.append(self.ismanual, self.removed_cell[0])
+ self.zdraw.append([])
+ print(">>> added back removed cell")
+ self.update_layer()
+ io._save_sets_with_check(self)
+ self.removed_cell = []
+ self.redo.setEnabled(False)
+
+ def remove_stroke(self, delete_points=True, stroke_ind=-1):
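+ """ erase a drawn stroke from the overlay and optionally discard its recorded points """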
+ stroke = np.array(self.strokes[stroke_ind])
+ cZ = self.currentZ
+ inZ = stroke[0, 0] == cZ
+ if inZ:
+ outpix = self.outpix[cZ, stroke[:, 1], stroke[:, 2]] > 0
+ self.layerz[stroke[~outpix, 1], stroke[~outpix, 2]] = np.array([0, 0, 0, 0])
+ cellpix = self.cellpix[cZ, stroke[:, 1], stroke[:, 2]]
+ ccol = self.cellcolors.copy()
+ if self.selected > 0:
+ ccol[self.selected] = np.array([255, 255, 255])
+ col2mask = ccol[cellpix]
+ if self.masksOn:
+ col2mask = np.concatenate(
+ (col2mask, self.opacity * (cellpix[:, np.newaxis] > 0)), axis=-1)
+ else:
+ col2mask = np.concatenate((col2mask, 0 * (cellpix[:, np.newaxis] > 0)),
+ axis=-1)
+ self.layerz[stroke[:, 1], stroke[:, 2], :] = col2mask
+ if self.outlinesOn:
+ self.layerz[stroke[outpix, 1], stroke[outpix,
+ 2]] = np.array(self.outcolor)
+ if delete_points:
+ # self.current_point_set = self.current_point_set[:-1*(stroke[:,-1]==1).sum()]
+ del self.current_point_set[stroke_ind]
+ self.update_layer()
+
+ del self.strokes[stroke_ind]
+
+ def plot_clicked(self, event):
+ if event.button()==QtCore.Qt.LeftButton \
+ and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\
+ and not self.removing_region:
+ if event.double():
+ try:
+ self.p0.setYRange(0, self.Ly + self.pr)
+ except:
+ self.p0.setYRange(0, self.Ly)
+ self.p0.setXRange(0, self.Lx)
+
+ def cancel_remove_multiple(self):
+ self.clear_multi_selected_cells()
+ self.done_remove_multiple_cells()
+
+ def clear_multi_selected_cells(self):
+ # unselect all previously selected cells:
+ for idx in self.removing_cells_list:
+ self.unselect_cell_multi(idx)
+ self.removing_cells_list.clear()
+
+ def add_roi(self, roi):
+ self.p0.addItem(roi)
+ self.remove_roi_obj = roi
+
+ def remove_roi(self, roi):
+ self.clear_multi_selected_cells()
+ assert roi == self.remove_roi_obj
+ self.remove_roi_obj = None
+ self.p0.removeItem(roi)
+ self.removing_region = False
+
+ def roi_changed(self, roi):
+ # find the overlapping cells and make them selected
+ pos = roi.pos()
+ size = roi.size()
+ x0 = int(pos.x())
+ y0 = int(pos.y())
+ x1 = int(pos.x() + size.x())
+ y1 = int(pos.y() + size.y())
+ if x0 < 0:
+ x0 = 0
+ if y0 < 0:
+ y0 = 0
+ if x1 > self.Lx:
+ x1 = self.Lx
+ if y1 > self.Ly:
+ y1 = self.Ly
+
+ # find cells in that region
+ cell_idxs = np.unique(self.cellpix[self.currentZ, y0:y1, x0:x1])
+ cell_idxs = np.trim_zeros(cell_idxs)
+ # deselect cells not in region by deselecting all and then selecting the ones in the region
+ self.clear_multi_selected_cells()
+
+ for idx in cell_idxs:
+ self.select_cell_multi(idx)
+ self.removing_cells_list.append(idx)
+
+ self.update_layer()
+
+ def mouse_moved(self, pos):
+ items = self.win.scene().items(pos)
+
+ def color_choose(self):
+ self.color = self.RGBDropDown.currentIndex()
+ self.view = 0
+ self.ViewDropDown.setCurrentIndex(self.view)
+ self.update_plot()
+
+ def update_plot(self):
+ self.view = self.ViewDropDown.currentIndex()
+ self.Ly, self.Lx, _ = self.stack[self.currentZ].shape
+
+ if self.restore and "upsample" in self.restore:
+ if self.view != 0:
+ if self.view == 3:
+ self.resize = True
+ elif len(self.flows[0]) > 0 and self.flows[0].shape[1] == self.Lyr:
+ self.resize = True
+ else:
+ self.resize = False
+ else:
+ self.resize = False
+ self.draw_layer()
+ self.update_scale()
+ self.update_layer()
+
+ if self.view == 0 or self.view == self.ViewDropDown.count() - 1:
+ image = self.stack[
+ self.currentZ] if self.view == 0 else self.stack_filtered[self.currentZ]
+ if self.nchan == 1:
+ # show single channel
+ image = image[..., 0]
+ if self.color == 0:
+ self.img.setImage(image, autoLevels=False, lut=None)
+ if self.nchan > 1:
+ levels = np.array([
+ self.saturation[0][self.currentZ],
+ self.saturation[1][self.currentZ],
+ self.saturation[2][self.currentZ]
+ ])
+ self.img.setLevels(levels)
+ else:
+ self.img.setLevels(self.saturation[0][self.currentZ])
+ elif self.color > 0 and self.color < 4:
+ if self.nchan > 1:
+ image = image[:, :, self.color - 1]
+ self.img.setImage(image, autoLevels=False, lut=self.cmap[self.color])
+ if self.nchan > 1:
+ self.img.setLevels(self.saturation[self.color - 1][self.currentZ])
+ else:
+ self.img.setLevels(self.saturation[0][self.currentZ])
+ elif self.color == 4:
+ if self.nchan > 1:
+ image = image.mean(axis=-1)
+ self.img.setImage(image, autoLevels=False, lut=None)
+ self.img.setLevels(self.saturation[0][self.currentZ])
+ elif self.color == 5:
+ if self.nchan > 1:
+ image = image.mean(axis=-1)
+ self.img.setImage(image, autoLevels=False, lut=self.cmap[0])
+ self.img.setLevels(self.saturation[0][self.currentZ])
+ else:
+ image = np.zeros((self.Ly, self.Lx), np.uint8)
+ if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0:
+ image = self.flows[self.view - 1][self.currentZ]
+ if self.view > 1:
+ self.img.setImage(image, autoLevels=False, lut=self.bwr)
+ else:
+ self.img.setImage(image, autoLevels=False, lut=None)
+ self.img.setLevels([0.0, 255.0])
+
+ for r in range(3):
+ self.sliders[r].setValue([
+ self.saturation[r][self.currentZ][0],
+ self.saturation[r][self.currentZ][1]
+ ])
+ self.win.show()
+ self.show()
+
+ def update_layer(self):
+ if self.masksOn or self.outlinesOn:
+ #self.draw_layer()
+ self.layer.setImage(self.layerz, autoLevels=False)
+ self.update_roi_count()
+ self.win.show()
+ self.show()
+
+ def update_roi_count(self):
+ self.roi_count.setText(f"{self.ncells} ROIs")
+
+ def add_set(self):
+ if len(self.current_point_set) > 0:
+ while len(self.strokes) > 0:
+ self.remove_stroke(delete_points=False)
+ if len(self.current_point_set[0]) > 8:
+ color = self.colormap[self.ncells, :3]
+ median = self.add_mask(points=self.current_point_set, color=color)
+ if median is not None:
+ self.removed_cell = []
+ self.toggle_mask_ops()
+ self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :],
+ axis=0)
+ self.ncells += 1
+ self.ismanual = np.append(self.ismanual, True)
+ if self.NZ == 1:
+ # only save after each cell if single image
+ io._save_sets_with_check(self)
+ else:
+ print("GUI_ERROR: cell too small, not drawn")
+ self.current_stroke = []
+ self.strokes = []
+ self.current_point_set = []
+ self.update_layer()
+
+ def add_mask(self, points=None, color=(100, 200, 50), dense=True):
+ # points is list of strokes
+ points_all = np.concatenate(points, axis=0)
+
+ # loop over z values
+ median = []
+ zdraw = np.unique(points_all[:, 0])
+ z = 0
+ ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
+ 0, "int"), np.zeros(0, "int")
+ for stroke in points:
+ stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
+ vr = stroke[:, 1]
+ vc = stroke[:, 2]
+ # get points inside drawn points
+ mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
+ pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
+ axis=-1)[:, np.newaxis, :]
+ mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
+ ar, ac = np.nonzero(mask)
+ ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
+ # get dense outline
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+ pvc, pvr = contours[-2][0][:,0].T
+ vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
+ # concatenate all points
+ ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
+ # if these pixels are overlapping with another cell, reassign them
+ ioverlap = self.cellpix[z][ar, ac] > 0
+ if (~ioverlap).sum() < 10:
+ print("GUI_ERROR: cell < 10 pixels without overlaps, not drawn")
+ return None
+ elif ioverlap.sum() > 0:
+ ar, ac = ar[~ioverlap], ac[~ioverlap]
+ # compute outline of new mask
+ mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
+ mask[ar - vr.min() + 2, ac - vc.min() + 2] = 1
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_NONE)
+ pvc, pvr = contours[-2][0][:,0].T
+ vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
+ ars = np.concatenate((ars, ar), axis=0)
+ acs = np.concatenate((acs, ac), axis=0)
+ vrs = np.concatenate((vrs, vr), axis=0)
+ vcs = np.concatenate((vcs, vc), axis=0)
+
+ self.draw_mask(z, ars, acs, vrs, vcs, color)
+ median.append(np.array([np.median(ars), np.median(acs)]))
+
+ self.zdraw.append(zdraw)
+ d = datetime.datetime.now()
+ self.track_changes.append(
+ [d.strftime("%m/%d/%Y, %H:%M:%S"), "added mask", [ar, ac]])
+ return median
+
+ def draw_mask(self, z, ar, ac, vr, vc, color, idx=None):
+ """ draw single mask using outlines and area """
+ if idx is None:
+ idx = self.ncells + 1
+ self.cellpix[z, vr, vc] = idx
+ self.cellpix[z, ar, ac] = idx
+ self.outpix[z, vr, vc] = idx
+ if self.restore and "upsample" in self.restore:
+ if self.resize:
+ self.cellpix_resize[z, vr, vc] = idx
+ self.cellpix_resize[z, ar, ac] = idx
+ self.outpix_resize[z, vr, vc] = idx
+ self.cellpix_orig[z, (vr / self.ratio).astype(int),
+ (vc / self.ratio).astype(int)] = idx
+ self.cellpix_orig[z, (ar / self.ratio).astype(int),
+ (ac / self.ratio).astype(int)] = idx
+ self.outpix_orig[z, (vr / self.ratio).astype(int),
+ (vc / self.ratio).astype(int)] = idx
+ else:
+ self.cellpix_orig[z, vr, vc] = idx
+ self.cellpix_orig[z, ar, ac] = idx
+ self.outpix_orig[z, vr, vc] = idx
+
+ # get upsampled mask
+ vrr = (vr.copy() * self.ratio).astype(int)
+ vcr = (vc.copy() * self.ratio).astype(int)
+ mask = np.zeros((np.ptp(vrr) + 4, np.ptp(vcr) + 4), np.uint8)
+ pts = np.stack((vcr - vcr.min() + 2, vrr - vrr.min() + 2),
+ axis=-1)[:, np.newaxis, :]
+ mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
+ arr, acr = np.nonzero(mask)
+ arr, acr = arr + vrr.min() - 2, acr + vcr.min() - 2
+ # get dense outline
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_NONE)
+ pvc, pvr = contours[-2][0].squeeze().T
+ vrr, vcr = pvr + vrr.min() - 2, pvc + vcr.min() - 2
+ # concatenate all points
+ arr, acr = np.hstack((np.vstack((vrr, vcr)), np.vstack((arr, acr))))
+ self.cellpix_resize[z, vrr, vcr] = idx
+ self.cellpix_resize[z, arr, acr] = idx
+ self.outpix_resize[z, vrr, vcr] = idx
+
+ if z == self.currentZ:
+ self.layerz[ar, ac, :3] = color
+ if self.masksOn:
+ self.layerz[ar, ac, -1] = self.opacity
+ if self.outlinesOn:
+ self.layerz[vr, vc] = np.array(self.outcolor)
+
+ def compute_scale(self):
+ self.diameter = float(self.Diameter.text())
+ self.pr = int(float(self.Diameter.text()))
+ self.radii_padding = int(self.pr * 1.25)
+ self.radii = np.zeros((self.Ly + self.radii_padding, self.Lx, 4), np.uint8)
+ yy, xx = disk([self.Ly + self.radii_padding / 2 - 1, self.pr / 2 + 1],
+ self.pr / 2, self.Ly + self.radii_padding, self.Lx)
+ # rgb(150,50,150)
+ self.radii[yy, xx, 0] = 150
+ self.radii[yy, xx, 1] = 50
+ self.radii[yy, xx, 2] = 150
+ self.radii[yy, xx, 3] = 255
+ self.p0.setYRange(0, self.Ly + self.radii_padding)
+ self.p0.setXRange(0, self.Lx)
+
+ def update_scale(self):
+ self.compute_scale()
+ self.scale.setImage(self.radii, autoLevels=False)
+ self.scale.setLevels([0.0, 255.0])
+ self.win.show()
+ self.show()
+
+ def redraw_masks(self, masks=True, outlines=True, draw=True):
+ self.draw_layer()
+
+ def draw_masks(self):
+ self.draw_layer()
+
+ def draw_layer(self):
+ if self.resize:
+ self.Ly, self.Lx = self.Lyr, self.Lxr
+ else:
+ self.Ly, self.Lx = self.Ly0, self.Lx0
+
+ if self.masksOn or self.outlinesOn:
+ if self.restore and "upsample" in self.restore:
+ if self.resize:
+ self.cellpix = self.cellpix_resize.copy()
+ self.outpix = self.outpix_resize.copy()
+ else:
+ self.cellpix = self.cellpix_orig.copy()
+ self.outpix = self.outpix_orig.copy()
+
+ #print(self.cellpix.shape, self.outpix.shape, self.cellpix.max(), self.outpix.max())
+ self.layerz = np.zeros((self.Ly, self.Lx, 4), np.uint8)
+ if self.masksOn:
+ self.layerz[..., :3] = self.cellcolors[self.cellpix[self.currentZ], :]
+ self.layerz[..., 3] = self.opacity * (self.cellpix[self.currentZ]
+ > 0).astype(np.uint8)
+ if self.selected > 0:
+ self.layerz[self.cellpix[self.currentZ] == self.selected] = np.array(
+ [255, 255, 255, self.opacity])
+ cZ = self.currentZ
+ stroke_z = np.array([s[0][0] for s in self.strokes])
+ inZ = np.nonzero(stroke_z == cZ)[0]
+ if len(inZ) > 0:
+ for i in inZ:
+ stroke = np.array(self.strokes[i])
+ self.layerz[stroke[:, 1], stroke[:,
+ 2]] = np.array([255, 0, 255, 100])
+ else:
+ self.layerz[..., 3] = 0
+
+ if self.outlinesOn:
+ self.layerz[self.outpix[self.currentZ] > 0] = np.array(
+ self.outcolor).astype(np.uint8)
+
+ def set_restore_button(self):
+ keys = self.denoise_text
+ for i, key in enumerate(keys):
+ if key != "none" and (self.restore and key in self.restore):
+ self.DenoiseButtons[i].setStyleSheet(self.stylePressed)
+ elif key == "none" and self.restore is None:
+ self.DenoiseButtons[i].setStyleSheet(self.stylePressed)
+ else:
+ if self.DenoiseButtons[i].isEnabled():
+ self.DenoiseButtons[i].setStyleSheet(self.styleUnpressed)
+
+ def set_normalize_params(self, normalize_params):
+ from cellpose.models import normalize_default
+ if self.restore != "filter":
+ keys = list(normalize_params.keys()).copy()
+ for key in keys:
+ if key != "percentile":
+ normalize_params[key] = normalize_default[key]
+ normalize_params = {**normalize_default, **normalize_params}
+ percentile = self.check_percentile_params(normalize_params["percentile"])
+ out = self.check_filter_params(normalize_params["sharpen_radius"],
+ normalize_params["smooth_radius"],
+ normalize_params["tile_norm_blocksize"],
+ normalize_params["tile_norm_smooth3D"],
+ normalize_params["norm3D"],
+ normalize_params["invert"])
+
+ def check_percentile_params(self, percentile):
+ # check normalization params
+ if percentile is not None and not (percentile[0] >= 0 and percentile[1] > 0 and
+ percentile[0] < 100 and percentile[1] <= 100
+ and percentile[1] > percentile[0]):
+ print(
+ "GUI_ERROR: percentiles need be between 0 and 100, and upper > lower, using defaults"
+ )
+ self.norm_edits[0].setText("1.")
+ self.norm_edits[1].setText("99.")
+ percentile = [1., 99.]
+ elif percentile is None:
+ percentile = [1., 99.]
+ self.norm_edits[0].setText(str(percentile[0]))
+ self.norm_edits[1].setText(str(percentile[1]))
+ return percentile
+
+ def check_filter_params(self, sharpen, smooth, tile_norm, smooth3D, norm3D, invert):
+ tile_norm = 0 if tile_norm < 0 else tile_norm
+ sharpen = 0 if sharpen < 0 else sharpen
+ smooth = 0 if smooth < 0 else smooth
+ smooth3D = 0 if smooth3D < 0 else smooth3D
+ norm3D = bool(norm3D)
+ invert = bool(invert)
+ if tile_norm > self.Ly and tile_norm > self.Lx:
+ print(
+ "GUI_ERROR: tile size (tile_norm) bigger than both image dimensions, disabling"
+ )
+ tile_norm = 0
+ self.filt_edits[0].setText(str(sharpen))
+ self.filt_edits[1].setText(str(smooth))
+ self.filt_edits[2].setText(str(tile_norm))
+ self.filt_edits[3].setText(str(smooth3D))
+ self.norm3D_cb.setChecked(norm3D)
+ self.invert_cb.setChecked(invert)
+ return sharpen, smooth, tile_norm, smooth3D, norm3D, invert
+
+ def get_normalize_params(self):
+ percentile = [
+ float(self.norm_edits[0].text()),
+ float(self.norm_edits[1].text())
+ ]
+ self.check_percentile_params(percentile)
+ normalize_params = {"percentile": percentile}
+ norm3D = self.norm3D_cb.isChecked()
+ normalize_params["norm3D"] = norm3D
+ if self.restore == "filter":
+ sharpen = float(self.filt_edits[0].text())
+ smooth = float(self.filt_edits[1].text())
+ tile_norm = float(self.filt_edits[2].text())
+ smooth3D = float(self.filt_edits[3].text())
+ invert = self.invert_cb.isChecked()
+ out = self.check_filter_params(sharpen, smooth, tile_norm, smooth3D, norm3D,
+ invert)
+ sharpen, smooth, tile_norm, smooth3D, norm3D, invert = out
+ normalize_params["sharpen_radius"] = sharpen
+ normalize_params["smooth_radius"] = smooth
+ normalize_params["tile_norm_blocksize"] = tile_norm
+ normalize_params["tile_norm_smooth3D"] = smooth3D
+ normalize_params["invert"] = invert
+
+ from cellpose.models import normalize_default
+ normalize_params = {**normalize_default, **normalize_params}
+
+ return normalize_params
+
+ def compute_saturation(self, return_img=False):
+ norm = self.get_normalize_params()
+ print(norm)
+ sharpen, smooth = norm["sharpen_radius"], norm["smooth_radius"]
+ percentile = norm["percentile"]
+ tile_norm = norm["tile_norm_blocksize"]
+ invert = norm["invert"]
+ norm3D = norm["norm3D"]
+ smooth3D = norm["tile_norm_smooth3D"]
+ tile_norm = norm["tile_norm_blocksize"]
+
+ # if grayscale, use gray img
+ channels = self.get_channels()
+ if channels[0] == 0:
+ img_norm = self.stack.mean(axis=-1, keepdims=True)
+ elif sharpen > 0 or smooth > 0 or tile_norm > 0:
+ img_norm = self.stack.copy()
+ else:
+ img_norm = self.stack
+
+ if sharpen > 0 or smooth > 0 or tile_norm > 0:
+ self.clear_restore()
+ self.restore = "filter"
+ print(
+ "GUI_INFO: computing filtered image because sharpen > 0 or tile_norm > 0"
+ )
+ print(
+ "GUI_WARNING: will use memory to create filtered image -- make sure to have RAM for this"
+ )
+ img_norm = self.stack.copy()
+ if sharpen > 0 or smooth > 0:
+ img_norm = smooth_sharpen_img(self.stack, sharpen_radius=sharpen,
+ smooth_radius=smooth)
+
+ if tile_norm > 0:
+ img_norm = normalize99_tile(img_norm, blocksize=tile_norm,
+ lower=percentile[0], upper=percentile[1],
+ smooth3D=smooth3D, norm3D=norm3D)
+ # convert to 0->255
+ img_norm_min = img_norm.min()
+ img_norm_max = img_norm.max()
+ for c in range(img_norm.shape[-1]):
+ if np.ptp(img_norm[..., c]) > 1e-3:
+ img_norm[..., c] -= img_norm_min
+ img_norm[..., c] /= (img_norm_max - img_norm_min)
+ img_norm *= 255
+ self.stack_filtered = img_norm
+ self.ViewDropDown.model().item(self.ViewDropDown.count() -
+ 1).setEnabled(True)
+ self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1)
+ elif invert:
+ img_norm = self.stack.copy()
+ else:
+ img_norm = self.stack if self.restore is None or self.restore == "filter" else self.stack_filtered
+
+ self.saturation = []
+ for c in range(img_norm.shape[-1]):
+ self.saturation.append([])
+ if np.ptp(img_norm[..., c]) > 1e-3:
+ if norm3D:
+ x01 = np.percentile(img_norm[..., c], percentile[0])
+ x99 = np.percentile(img_norm[..., c], percentile[1])
+ if invert:
+ x01i = 255. - x99
+ x99i = 255. - x01
+ x01, x99 = x01i, x99i
+ for n in range(self.NZ):
+ self.saturation[-1].append([x01, x99])
+ else:
+ for z in range(self.NZ):
+ if self.NZ > 1:
+ x01 = np.percentile(img_norm[z, :, :, c], percentile[0])
+ x99 = np.percentile(img_norm[z, :, :, c], percentile[1])
+ else:
+ x01 = np.percentile(img_norm[..., c], percentile[0])
+ x99 = np.percentile(img_norm[..., c], percentile[1])
+ if invert:
+ x01i = 255. - x99
+ x99i = 255. - x01
+ x01, x99 = x01i, x99i
+ self.saturation[-1].append([x01, x99])
+ else:
+ for n in range(self.NZ):
+ self.saturation[-1].append([0, 255.])
+ # if only 2 restore channels, add blue
+ if len(self.saturation) < 3:
+ for i in range(3 - len(self.saturation)):
+ self.saturation.append([])
+ for n in range(self.NZ):
+ self.saturation[-1].append([0, 255.])
+ print(self.saturation[2][self.currentZ])
+
+ if invert:
+ img_norm = 255. - img_norm
+ self.stack_filtered = img_norm
+ self.ViewDropDown.model().item(self.ViewDropDown.count() -
+ 1).setEnabled(True)
+ self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1)
+
+ if img_norm.shape[-1] == 1:
+ self.saturation.append(self.saturation[0])
+ self.saturation.append(self.saturation[0])
+
+ self.autobtn.setChecked(True)
+ self.update_plot()
+
+ def chanchoose(self, image):
+ if image.ndim > 2 and self.nchan > 1:
+ if self.ChannelChoose[0].currentIndex() == 0:
+ return image.mean(axis=-1, keepdims=True)
+ else:
+ chanid = [self.ChannelChoose[0].currentIndex() - 1]
+ if self.ChannelChoose[1].currentIndex() > 0:
+ chanid.append(self.ChannelChoose[1].currentIndex() - 1)
+ return image[:, :, chanid]
+ else:
+ return image
+
+ def get_model_path(self, custom=False):
+ if custom:
+ self.current_model = self.ModelChooseC.currentText()
+ self.current_model_path = os.fspath(
+ models.MODEL_DIR.joinpath(self.current_model))
+ else:
+ self.current_model = self.net_names[max(
+ 0,
+ self.ModelChooseB.currentIndex() - 1)]
+ self.current_model_path = models.model_path(self.current_model)
+
+ def initialize_model(self, model_name=None, custom=False):
+ if model_name == "dataset-specific models":
+ raise ValueError("need to specify model (use dropdown)")
+ elif model_name is None or custom:
+ self.get_model_path(custom=custom)
+ if not os.path.exists(self.current_model_path):
+ raise ValueError("need to specify model (use dropdown)")
+
+ if model_name is None or not isinstance(model_name, str):
+ self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
+ pretrained_model=self.current_model_path)
+ else:
+ self.current_model = model_name
+ if self.current_model == "cyto" or self.current_model == "nuclei":
+ self.current_model_path = models.model_path(self.current_model, 0)
+ else:
+ self.current_model_path = os.fspath(
+ models.MODEL_DIR.joinpath(self.current_model))
+
+ if self.current_model != "cyto3":
+ diam_mean = 17. if self.current_model == "nuclei" else 30.
+ self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
+ diam_mean=diam_mean,
+ model_type=self.current_model)
+ else:
+ self.model = models.Cellpose(gpu=self.useGPU.isChecked(),
+ model_type=self.current_model)
+
+ def add_model(self):
+ io._add_model(self)
+ return
+
+ def remove_model(self):
+ io._remove_model(self)
+ return
+
+ def new_model(self):
+ if self.NZ != 1:
+ print("ERROR: cannot train model on 3D data")
+ return
+
+ # train model
+ image_names = self.get_files()[0]
+ self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set(
+ image_names)
+ TW = guiparts.TrainWindow(self, models.MODEL_NAMES)
+ train = TW.exec_()
+ if train:
+ self.logger.info(
+ f"training with {[os.path.split(f)[1] for f in self.train_files]}")
+ self.train_model(restore=restore, normalize_params=normalize_params)
+ else:
+ print("GUI_INFO: training cancelled")
+
+ def train_model(self, restore=None, normalize_params=None):
+ from cellpose.models import normalize_default
+ if normalize_params is None:
+ normalize_params = copy.deepcopy(normalize_default)
+ if self.training_params["model_index"] < len(models.MODEL_NAMES):
+ model_type = models.MODEL_NAMES[self.training_params["model_index"]]
+ self.logger.info(f"training new model starting at model {model_type}")
+ else:
+ model_type = None
+ self.logger.info(f"training new model starting from scratch")
+ self.current_model = model_type
+ self.channels = self.training_params["channels"]
+
+ self.logger.info(
+ f"training with chan = {self.ChannelChoose[0].currentText()}, chan2 = {self.ChannelChoose[1].currentText()}"
+ )
+
+ self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
+ model_type=model_type)
+ self.SizeButton.setEnabled(False)
+ save_path = os.path.dirname(self.filename)
+
+ print("GUI_INFO: name of new model: " + self.training_params["model_name"])
+ print(f"GUI_INFO: SGD activated: {self.training_params['SGD']}")
+ self.new_model_path, train_losses = train.train_seg(
+ self.model.net, train_data=self.train_data, train_labels=self.train_labels,
+ channels=self.channels, normalize=normalize_params, min_train_masks=0,
+ save_path=save_path, nimg_per_epoch=max(8, len(self.train_data)),
+ learning_rate=self.training_params["learning_rate"],
+ weight_decay=self.training_params["weight_decay"],
+ n_epochs=self.training_params["n_epochs"],
+ SGD=self.training_params["SGD"],
+ model_name=self.training_params["model_name"])[:2]
+ # save train losses
+ np.save(str(self.new_model_path) + "_train_losses.npy", train_losses)
+ # run model on next image
+ io._add_model(self, self.new_model_path)
+ diam_labels = self.model.net.diam_labels.item() #.copy()
+ self.new_model_ind = len(self.model_strings)
+ self.autorun = True
+ channels = self.channels.copy()
+ self.clear_all()
+ # keep same channels
+ self.ChannelChoose[0].setCurrentIndex(channels[0])
+ self.ChannelChoose[1].setCurrentIndex(channels[1])
+ self.diameter = diam_labels
+ self.Diameter.setText("%0.2f" % self.diameter)
+ self.logger.info(f">>>> diameter set to diam_labels ( = {diam_labels: 0.3f} )")
+ self.restore = restore
+ self.set_normalize_params(normalize_params)
+ self.get_next_image(load_seg=False)
+
+ self.compute_segmentation(custom=True)
+ self.logger.info(
+ f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!"
+ )
+
+ def compute_restore(self):
+ if self.restore:
+ self.logger.info(f"running image restoration {self.restore}")
+ if self.restore != "filter":
+ rstr = self.restore.split("_")
+ model_type = rstr[0]
+ if len(rstr) > 1:
+ dset = rstr[1]
+ if dset == "cyto3":
+ self.DenoiseChoose.setCurrentIndex(0)
+ else:
+ self.DenoiseChoose.setCurrentIndex(1)
+ if "upsample" in self.restore:
+ i = self.DenoiseChoose.currentIndex()
+ diam_up = 30. if i==0 or i==1 else 17.
+ print(diam_up, self.ratio)
+ self.Diameter.setText(str(diam_up / self.ratio))
+ self.compute_denoise_model(model_type=model_type)
+ else:
+ self.compute_saturation()
+
+ def get_thresholds(self):
+ try:
+ flow_threshold = float(self.flow_threshold.text())
+ cellprob_threshold = float(self.cellprob_threshold.text())
+ if flow_threshold == 0.0 or self.NZ > 1:
+ flow_threshold = None
+ return flow_threshold, cellprob_threshold
+ except Exception as e:
+ print(
+ "flow threshold or cellprob threshold not a valid number, setting to defaults"
+ )
+ self.flow_threshold.setText("0.4")
+ self.cellprob_threshold.setText("0.0")
+ return 0.4, 0.0
+
+ def compute_cprob(self):
+ if self.recompute_masks:
+ flow_threshold, cellprob_threshold = self.get_thresholds()
+ if flow_threshold is None:
+ self.logger.info(
+ "computing masks with cell prob=%0.3f, no flow error threshold" %
+ (cellprob_threshold))
+ else:
+ self.logger.info(
+ "computing masks with cell prob=%0.3f, flow error threshold=%0.3f" %
+ (cellprob_threshold, flow_threshold))
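+ # recompute masks from the cached flows and cell probability (self.flows[4]) without re-running the network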
+ maski = dynamics.resize_and_compute_masks(
+ self.flows[4][:-1], self.flows[4][-1], p=self.flows[3].copy(),
+ cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold,
+ resize=self.cellpix.shape[-2:])[0]
+
+ self.masksOn = True
+ if not self.OCheckBox.isChecked():
+ self.MCheckBox.setChecked(True)
+ if maski.ndim < 3:
+ maski = maski[np.newaxis, ...]
+ self.logger.info("%d cells found" % (len(np.unique(maski)[1:])))
+ io._masks_to_gui(self, maski, outlines=None)
+ self.show()
+
+ def compute_denoise_model(self, model_type=None):
+ self.progress.setValue(0)
+ try:
+ tic = time.time()
+ nstr = self.DenoiseChoose.currentText()
+ nstr.replace("-", "")
+ self.clear_restore()
+ model_name = model_type + "_" + nstr
+ print(model_name)
+ # denoising model
+ self.denoise_model = denoise.DenoiseModel(gpu=self.useGPU.isChecked(),
+ model_type=model_name)
+ self.progress.setValue(10)
+ diam_up = 30. if "cyto" in model_name else 17.
+
+ # params
+ channels = self.get_channels()
+ self.diameter = float(self.Diameter.text())
+ normalize_params = self.get_normalize_params()
+ print("GUI_INFO: channels: ", channels)
+ print("GUI_INFO: normalize_params: ", normalize_params)
+ print("GUI_INFO: diameter (before upsampling): ", self.diameter)
+
+ data = self.stack.copy()
+ print(data.shape)
+ self.Ly, self.Lx = data.shape[-3:-1]
+ if "upsample" in model_name:
+ # get upsampling factor
+ if self.diameter >= diam_up:
+ print(
+ f"GUI_ERROR: cannot upsample, already set to pixel diameter >= {diam_up}"
+ )
+ self.progress.setValue(0)
+ return
+ self.ratio = diam_up / self.diameter
+ print(
+ "GUI_WARNING: upsampling image, this will also duplicate mask layer and resize it, will use more RAM"
+ )
+ print(
+ f"GUI_INFO: upsampling image to {diam_up} pixel diameter ({self.ratio:0.2f} times)"
+ )
+ self.Lyr, self.Lxr = int(self.Ly * self.ratio), int(self.Lx *
+ self.ratio)
+ self.Ly0, self.Lx0 = self.Ly, self.Lx
+ # moved resize into eval
+ #data = resize_image(data, Ly=self.Lyr, Lx=self.Lxr)
+ #self.diameter = diam_up
+ #self.Diameter.setText(str(diam_up))
+ else:
+ self.Lyr, self.Lxr = self.Ly, self.Lx
+ self.Ly0, self.Lx0 = self.Ly, self.Lx
+ diam_up = self.diameter
+
+ img_norm = self.denoise_model.eval(data, channels=channels, z_axis=0,
+ channel_axis=3, diameter=self.diameter,
+ normalize=normalize_params)
+ print(img_norm.shape)
+ self.diameter = diam_up
+ self.Diameter.setText(str(diam_up))
+
+ if img_norm.ndim == 2:
+ img_norm = img_norm[:, :, np.newaxis]
+ if img_norm.ndim == 3:
+ img_norm = img_norm[np.newaxis, ...]
+
+ self.progress.setValue(100)
+ self.logger.info(f"{model_name} finished in %0.3f sec" %
+ (time.time() - tic))
+
+ # compute saturation
+ percentile = normalize_params["percentile"]
+ img_norm_min = img_norm.min()
+ img_norm_max = img_norm.max()
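+ # chan maps each output channel of the denoising model back to the display channel it belongs to (grayscale -> channel 0)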
+ chan = [0] if channels[0] == 0 else [channels[0] - 1, channels[1] - 1]
+ self.saturation = [[], [], []]
+ for c in range(img_norm.shape[-1]):
+ if np.ptp(img_norm[..., c]) > 1e-3:
+ img_norm[..., c] -= img_norm_min
+ img_norm[..., c] /= (img_norm_max - img_norm_min)
+ for z in range(self.NZ):
+ x01 = np.percentile(img_norm[z, :, :, c], percentile[0]) * 255.
+ x99 = np.percentile(img_norm[z, :, :, c], percentile[1]) * 255.
+ self.saturation[chan[c]].append([x01, x99])
+ notchan = np.ones(3, "bool")
+ notchan[np.array(chan)] = False
+ notchan = np.nonzero(notchan)[0]
+ for c in notchan:
+ for z in range(self.NZ):
+ self.saturation[c].append([0, 255.])
+
+ img_norm *= 255.
+ self.autobtn.setChecked(True)
+
+ # assign to denoised channels
+ self.stack_filtered = np.zeros(
+ (self.NZ, self.Lyr, self.Lxr, self.stack.shape[-1]), "float32")
+ for i, c in enumerate(chan[:img_norm.shape[-1]]):
+ for z in range(self.NZ):
+ self.stack_filtered[z, :, :, c] = img_norm[z, :, :, i]
+
+ # make upsampled masks
+ if model_type == "upsample":
+ self.cellpix_orig = self.cellpix.copy()
+ self.outpix_orig = self.outpix.copy()
+ self.cellpix_resize = cv2.resize(
+ self.cellpix_orig[0], (self.Lxr, self.Lyr),
+ interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :]
+ outlines = masks_to_outlines(self.cellpix_resize[0])[np.newaxis, :, :]
+ self.outpix_resize = outlines * self.cellpix_resize
+
+ self.restore = model_name
+
+ # draw plot
+ if model_type == "upsample":
+ self.resize = True
+ else:
+ self.resize = False
+ self.draw_layer()
+ self.update_layer()
+ self.update_scale()
+ # if denoised in grayscale, show in grayscale
+ if channels[0] == 0:
+ self.RGBDropDown.setCurrentIndex(4)
+
+ self.ViewDropDown.model().item(self.ViewDropDown.count() -
+ 1).setEnabled(True)
+ self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1)
+
+ self.update_plot()
+
+ except Exception as e:
+ print("ERROR: %s" % e)
+
+ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
+ self.progress.setValue(0)
+ try:
+ tic = time.time()
+ self.clear_all()
+ self.flows = [[], [], []]
+ if load_model:
+ self.initialize_model(model_name=model_name, custom=custom)
+ self.progress.setValue(10)
+ do_3D = self.load_3D
+ stitch_threshold = float(self.stitch_threshold.text()) if not isinstance(
+ self.stitch_threshold, float) else self.stitch_threshold
+ anisotropy = float(self.anisotropy.text()) if not isinstance(
+ self.anisotropy, float) else self.anisotropy
+ flow3D_smooth = float(self.flow3D_smooth.text()) if not isinstance(
+ self.flow3D_smooth, float) else self.flow3D_smooth
+ min_size = int(self.min_size.text()) if not isinstance(
+ self.min_size, int) else self.min_size
+ resample = self.resample.isChecked() if not isinstance(
+ self.resample, bool) else self.resample
+
+ do_3D = False if stitch_threshold > 0. else do_3D
+
+ channels = self.get_channels()
+ if self.restore is not None and self.restore != "filter":
+ data = self.stack_filtered.copy().squeeze()
+ else:
+ data = self.stack.copy().squeeze()
+ flow_threshold, cellprob_threshold = self.get_thresholds()
+ self.diameter = float(self.Diameter.text())
+ niter = max(0, int(self.niter.text()))
+ niter = None if niter == 0 else niter
+ normalize_params = self.get_normalize_params()
+ print(normalize_params)
+ try:
+ masks, flows = self.model.eval(
+ data, channels=channels, diameter=self.diameter,
+ cellprob_threshold=cellprob_threshold,
+ flow_threshold=flow_threshold, do_3D=do_3D, niter=niter,
+ normalize=normalize_params, stitch_threshold=stitch_threshold,
+ anisotropy=anisotropy, resample=resample, flow3D_smooth=flow3D_smooth,
+ min_size=min_size,
+ progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2]
+ except Exception as e:
+ print("NET ERROR: %s" % e)
+ self.progress.setValue(0)
+ return
+
+ self.progress.setValue(75)
+
+ # convert flows to uint8 and resize to original image size
+ flows_new = []
+ flows_new.append(flows[0].copy()) # RGB flow
+ flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) *
+ 255).astype("uint8")) # cellprob
+ if self.load_3D:
+ if stitch_threshold == 0.:
+ flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8"))
+ else:
+ flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8"))
+
+ if not self.load_3D:
+ if self.restore and "upsample" in self.restore:
+ self.Ly, self.Lx = self.Lyr, self.Lxr
+
+ if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx):
+ self.flows = []
+ for j in range(len(flows_new)):
+ self.flows.append(
+ resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx,
+ interpolation=cv2.INTER_NEAREST))
+ else:
+ self.flows = flows_new
+ else:
+ if not resample:
+ self.flows = []
+ Lz, Ly, Lx = self.NZ, self.Ly, self.Lx
+ Lz0, Ly0, Lx0 = flows_new[0].shape[:3]
+ print("GUI_INFO: resizing flows to original image size")
+ for j in range(len(flows_new)):
+ flow0 = flows_new[j]
+ if Ly0 != Ly:
+ flow0 = resize_image(flow0, Ly=Ly, Lx=Lx,
+ no_channels=flow0.ndim==3,
+ interpolation=cv2.INTER_NEAREST)
+ if Lz0 != Lz:
+ flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1),
+ Ly=Lz, Lx=Lx,
+ no_channels=flow0.ndim==3,
+ interpolation=cv2.INTER_NEAREST), 0, 1)
+ self.flows.append(flow0)
+ else:
+ self.flows = flows_new
+
+ # add first axis
+ if self.NZ == 1:
+ masks = masks[np.newaxis, ...]
+ self.flows = [
+ self.flows[n][np.newaxis, ...] for n in range(len(self.flows))
+ ]
+
+ self.logger.info("%d cells found with model in %0.3f sec" %
+ (len(np.unique(masks)[1:]), time.time() - tic))
+ self.progress.setValue(80)
+ z = 0
+
+ io._masks_to_gui(self, masks, outlines=None)
+ self.masksOn = True
+ self.MCheckBox.setChecked(True)
+ self.progress.setValue(100)
+ if self.restore != "filter" and self.restore is not None:
+ self.compute_saturation()
+ if not do_3D and not stitch_threshold > 0:
+ self.recompute_masks = True
+ else:
+ self.recompute_masks = False
+ except Exception as e:
+ print("ERROR: %s" % e)
diff --git a/cellpose/gui/gui3d.py b/cellpose/gui/gui3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..956d896595636806e5759598a3337e21558914ec
--- /dev/null
+++ b/cellpose/gui/gui3d.py
@@ -0,0 +1,692 @@
+"""
+Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import sys, os, pathlib, warnings, datetime, time
+
+from qtpy import QtGui, QtCore
+from superqt import QRangeSlider
+from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox
+import pyqtgraph as pg
+
+import numpy as np
+from scipy.stats import mode
+import cv2
+
+from . import guiparts, menus, io
+from .. import models, core, dynamics, version
+from ..utils import download_url_to_file, masks_to_outlines, diameters
+from ..io import get_image_files, imsave, imread
+from ..transforms import resize_image, normalize99 #fixed import
+from ..plot import disk
+from ..transforms import normalize99_tile, smooth_sharpen_img
+from .gui import MainW
+
+try:
+ import matplotlib.pyplot as plt
+ MATPLOTLIB = True
+except:
+ MATPLOTLIB = False
+
+
+def avg3d(C):
+ """ smooth value of c across nearby points
+ (c is center of grid directly below point)
+ b -- a -- b
+ a -- c -- a
+ b -- a -- b
+ """
+ Ly, Lx = C.shape
+ # pad T by 2
+ T = np.zeros((Ly + 2, Lx + 2), "float32")
+ M = np.zeros((Ly, Lx), "float32")
+ T[1:-1, 1:-1] = C.copy()
+ y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
+ indexing="ij")
+ y += 1
+ x += 1
+ a = 1. / 2 #/(z**2 + 1)**0.5
+ b = 1. / (1 + 2**0.5) #(z**2 + 2)**0.5
+ c = 1.
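+ # weighted average over the 3x3 neighbourhood: orthogonal neighbours get weight a, diagonals weight b, the centre weight c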
+ M = (b * T[y - 1, x - 1] + a * T[y - 1, x] + b * T[y - 1, x + 1] + a * T[y, x - 1] +
+ c * T[y, x] + a * T[y, x + 1] + b * T[y + 1, x - 1] + a * T[y + 1, x] +
+ b * T[y + 1, x + 1])
+ M /= 4 * a + 4 * b + c
+ return M
+
+
+def interpZ(mask, zdraw):
+ """ find nearby planes and average their values using grid of points
+ zfill is in ascending order
+ """
+ ifill = np.ones(mask.shape[0], "bool")
+ zall = np.arange(0, mask.shape[0], 1, int)
+ ifill[zdraw] = False
+ zfill = zall[ifill]
+ zlower = zdraw[np.searchsorted(zdraw, zfill, side="left") - 1]
+ zupper = zdraw[np.searchsorted(zdraw, zfill, side="right")]
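+ # blend the smoothed masks of the nearest drawn planes below and above each unlabelled plane, weighted by distance in z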
+ for k, z in enumerate(zfill):
+ Z = zupper[k] - zlower[k]
+ zl = (z - zlower[k]) / Z
+ plower = avg3d(mask[zlower[k]]) * (1 - zl)
+ pupper = avg3d(mask[zupper[k]]) * zl
+ mask[z] = (plower + pupper) > 0.33
+ #Ml, norml = avg3d(mask[zlower[k]], zl)
+ #Mu, normu = avg3d(mask[zupper[k]], 1-zl)
+ #mask[z] = (Ml + Mu) / (norml + normu) > 0.5
+ return mask, zfill
+
+
+def run(image=None):
+ from ..io import logger_setup
+ logger, log_file = logger_setup()
+ # Always start by initializing Qt (only once per application)
+ warnings.filterwarnings("ignore")
+ app = QApplication(sys.argv)
+ icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
+ guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png")
+ style_path = pathlib.Path.home().joinpath(".cellpose", "style_choice.npy")
+ if not icon_path.is_file():
+ cp_dir = pathlib.Path.home().joinpath(".cellpose")
+ cp_dir.mkdir(exist_ok=True)
+ print("downloading logo")
+ download_url_to_file(
+ "https://www.cellpose.org/static/images/cellpose_transparent.png",
+ icon_path, progress=True)
+ if not guip_path.is_file():
+ print("downloading help window image")
+ download_url_to_file("https://www.cellpose.org/static/images/cellpose_gui.png",
+ guip_path, progress=True)
+ icon_path = str(icon_path.resolve())
+ app_icon = QtGui.QIcon()
+ app_icon.addFile(icon_path, QtCore.QSize(16, 16))
+ app_icon.addFile(icon_path, QtCore.QSize(24, 24))
+ app_icon.addFile(icon_path, QtCore.QSize(32, 32))
+ app_icon.addFile(icon_path, QtCore.QSize(48, 48))
+ app_icon.addFile(icon_path, QtCore.QSize(64, 64))
+ app_icon.addFile(icon_path, QtCore.QSize(256, 256))
+ app.setWindowIcon(app_icon)
+ app.setStyle("Fusion")
+ app.setPalette(guiparts.DarkPalette())
+ #app.setStyleSheet("QLineEdit { color: yellow }")
+
+ # models.download_model_weights() # does not exist
+ MainW_3d(image=image, logger=logger)
+ ret = app.exec_()
+ sys.exit(ret)
+
+
+class MainW_3d(MainW):
+
+ def __init__(self, image=None, logger=None):
+ # MainW init
+ MainW.__init__(self, image=image, logger=logger)
+
+ # add gradZ view
+ self.ViewDropDown.insertItem(3, "gradZ")
+
+ # turn off single stroke
+ self.SCheckBox.setChecked(False)
+
+ ### add orthoviews and z-bar
+ # ortho crosshair lines
+ self.vLine = pg.InfiniteLine(angle=90, movable=False)
+ self.hLine = pg.InfiniteLine(angle=0, movable=False)
+ self.vLineOrtho = [
+ pg.InfiniteLine(angle=90, movable=False),
+ pg.InfiniteLine(angle=90, movable=False)
+ ]
+ self.hLineOrtho = [
+ pg.InfiniteLine(angle=0, movable=False),
+ pg.InfiniteLine(angle=0, movable=False)
+ ]
+ self.make_orthoviews()
+
+ # z scrollbar underneath
+ self.scroll = QScrollBar(QtCore.Qt.Horizontal)
+ self.scroll.setMaximum(10)
+ self.scroll.valueChanged.connect(self.move_in_Z)
+ self.lmain.addWidget(self.scroll, 40, 9, 1, 30)
+
+ b = 22
+
+ label = QLabel("stitch threshold:")
+ label.setToolTip(
+ "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
+ )
+ label.setFont(self.medfont)
+ self.segBoxG.addWidget(label, b, 0, 1, 4)
+ self.stitch_threshold = QLineEdit()
+ self.stitch_threshold.setText("0.0")
+ self.stitch_threshold.setFixedWidth(30)
+ self.stitch_threshold.setFont(self.medfont)
+ self.stitch_threshold.setToolTip(
+ "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
+ )
+ self.segBoxG.addWidget(self.stitch_threshold, b, 4, 1, 1)
+
+ label = QLabel("flow3D_smooth:")
+ label.setToolTip(
+ "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)"
+ )
+ label.setFont(self.medfont)
+ self.segBoxG.addWidget(label, b, 5, 1, 3)
+ self.flow3D_smooth = QLineEdit()
+ self.flow3D_smooth.setText("0.0")
+ self.flow3D_smooth.setFixedWidth(30)
+ self.flow3D_smooth.setFont(self.medfont)
+ self.flow3D_smooth.setToolTip(
+ "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)"
+ )
+ self.segBoxG.addWidget(self.flow3D_smooth, b, 8, 1, 1)
+
+ b+=1
+ label = QLabel("anisotropy:")
+ label.setToolTip(
+ "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)"
+ )
+ label.setFont(self.medfont)
+ self.segBoxG.addWidget(label, b, 0, 1, 4)
+ self.anisotropy = QLineEdit()
+ self.anisotropy.setText("1.0")
+ self.anisotropy.setFixedWidth(30)
+ self.anisotropy.setFont(self.medfont)
+ self.anisotropy.setToolTip(
+ "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)"
+ )
+ self.segBoxG.addWidget(self.anisotropy, b, 4, 1, 1)
+
+ self.resample = QCheckBox("resample")
+ self.resample.setToolTip("reample before creating masks; if diameter > 30 resample will use more CPU+GPU memory (see docs for more details)")
+ self.resample.setFont(self.medfont)
+ self.resample.setChecked(True)
+ self.segBoxG.addWidget(self.resample, b, 5, 1, 4)
+
+ b+=1
+ label = QLabel("min_size:")
+ label.setToolTip(
+ "all masks less than this size in pixels (volume) will be removed"
+ )
+ label.setFont(self.medfont)
+ self.segBoxG.addWidget(label, b, 0, 1, 4)
+ self.min_size = QLineEdit()
+ self.min_size.setText("15")
+ self.min_size.setFixedWidth(50)
+ self.min_size.setFont(self.medfont)
+ self.min_size.setToolTip(
+ "all masks less than this size in pixels (volume) will be removed"
+ )
+ self.segBoxG.addWidget(self.min_size, b, 4, 1, 3)
+
+ b += 1
+ self.orthobtn = QCheckBox("ortho")
+ self.orthobtn.setToolTip("activate orthoviews with 3D image")
+ self.orthobtn.setFont(self.medfont)
+ self.orthobtn.setChecked(False)
+ self.l0.addWidget(self.orthobtn, b, 0, 1, 2)
+ self.orthobtn.toggled.connect(self.toggle_ortho)
+
+ label = QLabel("dz:")
+ label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ label.setFont(self.medfont)
+ self.l0.addWidget(label, b, 2, 1, 1)
+ self.dz = 10
+ self.dzedit = QLineEdit()
+ self.dzedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.dzedit.setText(str(self.dz))
+ self.dzedit.returnPressed.connect(self.update_ortho)
+ self.dzedit.setFixedWidth(40)
+ self.dzedit.setFont(self.medfont)
+ self.l0.addWidget(self.dzedit, b, 3, 1, 2)
+
+ label = QLabel("z-aspect:")
+ label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ label.setFont(self.medfont)
+ self.l0.addWidget(label, b, 5, 1, 2)
+ self.zaspect = 1.0
+ self.zaspectedit = QLineEdit()
+ self.zaspectedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.zaspectedit.setText(str(self.zaspect))
+ self.zaspectedit.returnPressed.connect(self.update_ortho)
+ self.zaspectedit.setFixedWidth(40)
+ self.zaspectedit.setFont(self.medfont)
+ self.l0.addWidget(self.zaspectedit, b, 7, 1, 2)
+
+ b += 1
+ # add z position underneath
+ self.currentZ = 0
+ label = QLabel("Z:")
+ label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.l0.addWidget(label, b, 5, 1, 2)
+ self.zpos = QLineEdit()
+ self.zpos.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.zpos.setText(str(self.currentZ))
+ self.zpos.returnPressed.connect(self.update_ztext)
+ self.zpos.setFixedWidth(40)
+ self.zpos.setFont(self.medfont)
+ self.l0.addWidget(self.zpos, b, 7, 1, 2)
+
+ # if called with image, load it
+ if image is not None:
+ self.filename = image
+ io._load_image(self, self.filename, load_3D=True)
+
+ self.load_3D = True
+
+ def add_mask(self, points=None, color=(100, 200, 50), dense=True):
+ # points is list of strokes
+
+ points_all = np.concatenate(points, axis=0)
+
+ # loop over z values
+ median = []
+ zdraw = np.unique(points_all[:, 0])
+ zrange = np.arange(zdraw.min(), zdraw.max() + 1, 1, int)
+ zmin = zdraw.min()
+ pix = np.zeros((2, 0), "uint16")
+ mall = np.zeros((len(zrange), self.Ly, self.Lx), "bool")
+ k = 0
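+ # rasterize the drawn strokes plane by plane; planes without strokes are filled in afterwards by interpZ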
+ for z in zdraw:
+ ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
+ 0, "int"), np.zeros(0, "int")
+ for stroke in points:
+ stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
+ iz = stroke[:, 0] == z
+ vr = stroke[iz, 1]
+ vc = stroke[iz, 2]
+ if iz.sum() > 0:
+ # get points inside drawn points
+ mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8")
+ pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
+ axis=-1)[:, np.newaxis, :]
+ mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
+ ar, ac = np.nonzero(mask)
+ ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
+ # get dense outline
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_NONE)
+ pvc, pvr = contours[-2][0].squeeze().T
+ vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
+ # concatenate all points
+ ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
+ # if these pixels are overlapping with another cell, reassign them
+ ioverlap = self.cellpix[z][ar, ac] > 0
+ if (~ioverlap).sum() < 8:
+ print("ERROR: cell too small without overlaps, not drawn")
+ return None
+ elif ioverlap.sum() > 0:
+ ar, ac = ar[~ioverlap], ac[~ioverlap]
+ # compute outline of new mask
+ mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8")
+ mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_NONE)
+ pvc, pvr = contours[-2][0].squeeze().T
+ vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
+ ars = np.concatenate((ars, ar), axis=0)
+ acs = np.concatenate((acs, ac), axis=0)
+ vrs = np.concatenate((vrs, vr), axis=0)
+ vcs = np.concatenate((vcs, vc), axis=0)
+ self.draw_mask(z, ars, acs, vrs, vcs, color)
+
+ median.append(np.array([np.median(ars), np.median(acs)]))
+ mall[z - zmin, ars, acs] = True
+ pix = np.append(pix, np.vstack((ars, acs)), axis=-1)
+
+ mall = mall[:, pix[0].min():pix[0].max() + 1,
+ pix[1].min():pix[1].max() + 1].astype("float32")
+ ymin, xmin = pix[0].min(), pix[1].min()
+ if len(zdraw) > 1:
+ mall, zfill = interpZ(mall, zdraw - zmin)
+ for z in zfill:
+ mask = mall[z].copy()
+ ar, ac = np.nonzero(mask)
+ ioverlap = self.cellpix[z + zmin][ar + ymin, ac + xmin] > 0
+ if (~ioverlap).sum() < 5:
+ print("WARNING: stroke on plane %d not included due to overlaps" %
+ z)
+ elif ioverlap.sum() > 0:
+ mask[ar[ioverlap], ac[ioverlap]] = 0
+ ar, ac = ar[~ioverlap], ac[~ioverlap]
+ # compute outline of mask
+ outlines = masks_to_outlines(mask)
+ vr, vc = np.nonzero(outlines)
+ vr, vc = vr + ymin, vc + xmin
+ ar, ac = ar + ymin, ac + xmin
+ self.draw_mask(z + zmin, ar, ac, vr, vc, color)
+
+ self.zdraw.append(zdraw)
+
+ return median
+
+ def move_in_Z(self):
+ if self.loaded:
+ self.currentZ = min(self.NZ, max(0, int(self.scroll.value())))
+ self.zpos.setText(str(self.currentZ))
+ self.update_plot()
+ self.draw_layer()
+ self.update_layer()
+
+ def make_orthoviews(self):
+ self.pOrtho, self.imgOrtho, self.layerOrtho = [], [], []
+ for j in range(2):
+ self.pOrtho.append(
+ pg.ViewBox(lockAspect=True, name=f"plotOrtho{j}",
+ border=[100, 100, 100], invertY=True, enableMouse=False))
+ self.pOrtho[j].setMenuEnabled(False)
+
+ self.imgOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self))
+ self.imgOrtho[j].autoDownsample = False
+
+ self.layerOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self))
+ self.layerOrtho[j].setLevels([0., 255.])
+
+ #self.pOrtho[j].scene().contextMenuItem = self.pOrtho[j]
+ self.pOrtho[j].addItem(self.imgOrtho[j])
+ self.pOrtho[j].addItem(self.layerOrtho[j])
+ self.pOrtho[j].addItem(self.vLineOrtho[j], ignoreBounds=False)
+ self.pOrtho[j].addItem(self.hLineOrtho[j], ignoreBounds=False)
+
+ self.pOrtho[0].linkView(self.pOrtho[0].YAxis, self.p0)
+ self.pOrtho[1].linkView(self.pOrtho[1].XAxis, self.p0)
+
+ def add_orthoviews(self):
+ self.yortho = self.Ly // 2
+ self.xortho = self.Lx // 2
+ if self.NZ > 1:
+ self.update_ortho()
+
+ self.win.addItem(self.pOrtho[0], 0, 1, rowspan=1, colspan=1)
+ self.win.addItem(self.pOrtho[1], 1, 0, rowspan=1, colspan=1)
+
+ qGraphicsGridLayout = self.win.ci.layout
+ qGraphicsGridLayout.setColumnStretchFactor(0, 2)
+ qGraphicsGridLayout.setColumnStretchFactor(1, 1)
+ qGraphicsGridLayout.setRowStretchFactor(0, 2)
+ qGraphicsGridLayout.setRowStretchFactor(1, 1)
+
+ #self.p0.linkView(self.p0.YAxis, self.pOrtho[0])
+ #self.p0.linkView(self.p0.XAxis, self.pOrtho[1])
+
+ self.pOrtho[0].setYRange(0, self.Lx)
+ self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
+ self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
+ self.pOrtho[1].setXRange(0, self.Ly)
+ #self.pOrtho[0].setLimits(minXRange=self.dz*2+self.dz/3*2)
+ #self.pOrtho[1].setLimits(minYRange=self.dz*2+self.dz/3*2)
+
+ self.p0.addItem(self.vLine, ignoreBounds=False)
+ self.p0.addItem(self.hLine, ignoreBounds=False)
+ self.p0.setYRange(0, self.Lx)
+ self.p0.setXRange(0, self.Ly)
+
+ self.win.show()
+ self.show()
+
+ #self.p0.linkView(self.p0.XAxis, self.pOrtho[1])
+
+ def remove_orthoviews(self):
+ self.win.removeItem(self.pOrtho[0])
+ self.win.removeItem(self.pOrtho[1])
+ self.p0.removeItem(self.vLine)
+ self.p0.removeItem(self.hLine)
+ self.win.show()
+ self.show()
+
+ def update_crosshairs(self):
+ self.yortho = min(self.Ly - 1, max(0, int(self.yortho)))
+ self.xortho = min(self.Lx - 1, max(0, int(self.xortho)))
+ self.vLine.setPos(self.xortho)
+ self.hLine.setPos(self.yortho)
+ self.vLineOrtho[1].setPos(self.xortho)
+ self.hLineOrtho[1].setPos(self.zc)
+ self.vLineOrtho[0].setPos(self.zc)
+ self.hLineOrtho[0].setPos(self.yortho)
+
+ def update_ortho(self):
+ if self.NZ > 1 and self.orthobtn.isChecked():
+ dzcurrent = self.dz
+ self.dz = min(100, max(3, int(self.dzedit.text())))
+ self.zaspect = max(0.01, min(100., float(self.zaspectedit.text())))
+ self.dzedit.setText(str(self.dz))
+ self.zaspectedit.setText(str(self.zaspect))
+ if self.dz != dzcurrent:
+ self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
+ self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
+ dztot = min(self.NZ, self.dz * 2)
+ y = self.yortho
+ x = self.xortho
+ z = self.currentZ
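+ # choose a window of up to 2*dz planes centred on the current plane, clamped to the stack boundaries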
+ if dztot == self.NZ:
+ zmin, zmax = 0, self.NZ
+ else:
+ if z - self.dz < 0:
+ zmin = 0
+ zmax = zmin + self.dz * 2
+ elif z + self.dz >= self.NZ:
+ zmax = self.NZ
+ zmin = zmax - self.dz * 2
+ else:
+ zmin, zmax = z - self.dz, z + self.dz
+ self.zc = z - zmin
+ self.update_crosshairs()
+ if self.view == 0 or self.view == 4:
+ for j in range(2):
+ if j == 0:
+ if self.view == 0:
+ image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy()
+ else:
+ image = self.stack_filtered[zmin:zmax, :,
+ x].transpose(1, 0, 2).copy()
+ else:
+ image = self.stack[
+ zmin:zmax,
+ y, :].copy() if self.view == 0 else self.stack_filtered[zmin:zmax,
+ y, :].copy()
+ if self.nchan == 1:
+ # show single channel
+ image = image[..., 0]
+ if self.color == 0:
+ self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
+ if self.nchan > 1:
+ levels = np.array([
+ self.saturation[0][self.currentZ],
+ self.saturation[1][self.currentZ],
+ self.saturation[2][self.currentZ]
+ ])
+ self.imgOrtho[j].setLevels(levels)
+ else:
+ self.imgOrtho[j].setLevels(
+ self.saturation[0][self.currentZ])
+ elif self.color > 0 and self.color < 4:
+ if self.nchan > 1:
+ image = image[..., self.color - 1]
+ self.imgOrtho[j].setImage(image, autoLevels=False,
+ lut=self.cmap[self.color])
+ if self.nchan > 1:
+ self.imgOrtho[j].setLevels(
+ self.saturation[self.color - 1][self.currentZ])
+ else:
+ self.imgOrtho[j].setLevels(
+ self.saturation[0][self.currentZ])
+ elif self.color == 4:
+ if image.ndim > 2:
+ image = image.astype("float32").mean(axis=2).astype("uint8")
+ self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
+ self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
+ elif self.color == 5:
+ if image.ndim > 2:
+ image = image.astype("float32").mean(axis=2).astype("uint8")
+ self.imgOrtho[j].setImage(image, autoLevels=False,
+ lut=self.cmap[0])
+ self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
+ self.pOrtho[0].setAspectLocked(lock=True, ratio=self.zaspect)
+ self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect)
+
+ else:
+ image = np.zeros((10, 10), "uint8")
+ self.imgOrtho[0].setImage(image, autoLevels=False, lut=None)
+ self.imgOrtho[0].setLevels([0.0, 255.0])
+ self.imgOrtho[1].setImage(image, autoLevels=False, lut=None)
+ self.imgOrtho[1].setLevels([0.0, 255.0])
+
+ zrange = zmax - zmin
+ self.layer_ortho = [
+ np.zeros((self.Ly, zrange, 4), "uint8"),
+ np.zeros((zrange, self.Lx, 4), "uint8")
+ ]
+ if self.masksOn:
+ for j in range(2):
+ if j == 0:
+ cp = self.cellpix[zmin:zmax, :, x].T
+ else:
+ cp = self.cellpix[zmin:zmax, y]
+ self.layer_ortho[j][..., :3] = self.cellcolors[cp, :]
+ self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8")
+ if self.selected > 0:
+ self.layer_ortho[j][cp == self.selected] = np.array(
+ [255, 255, 255, self.opacity])
+
+ if self.outlinesOn:
+ for j in range(2):
+ if j == 0:
+ op = self.outpix[zmin:zmax, :, x].T
+ else:
+ op = self.outpix[zmin:zmax, y]
+ self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8")
+
+ for j in range(2):
+ self.layerOrtho[j].setImage(self.layer_ortho[j])
+ self.win.show()
+ self.show()
+
+ def toggle_ortho(self):
+ if self.orthobtn.isChecked():
+ self.add_orthoviews()
+ else:
+ self.remove_orthoviews()
+
+ def plot_clicked(self, event):
+ if event.button()==QtCore.Qt.LeftButton \
+ and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\
+ and not self.removing_region:
+ if event.double():
+ try:
+ self.p0.setYRange(0, self.Ly + self.pr)
+ except:
+ self.p0.setYRange(0, self.Ly)
+ self.p0.setXRange(0, self.Lx)
+ elif self.loaded and not self.in_stroke:
+ if self.orthobtn.isChecked():
+ items = self.win.scene().items(event.scenePos())
+ for x in items:
+ if x == self.p0:
+ pos = self.p0.mapSceneToView(event.scenePos())
+ x = int(pos.x())
+ y = int(pos.y())
+ if y >= 0 and y < self.Ly and x >= 0 and x < self.Lx:
+ self.yortho = y
+ self.xortho = x
+ self.update_ortho()
+
+ def update_plot(self):
+ super().update_plot()
+ if self.NZ > 1 and self.orthobtn.isChecked():
+ self.update_ortho()
+ self.win.show()
+ self.show()
+
+ def keyPressEvent(self, event):
+ if self.loaded:
+ if not (event.modifiers() &
+ (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
+ QtCore.Qt.AltModifier) or self.in_stroke):
+ updated = False
+ if len(self.current_point_set) > 0:
+ if event.key() == QtCore.Qt.Key_Return:
+ self.add_set()
+ if self.NZ > 1:
+ if event.key() == QtCore.Qt.Key_Left:
+ self.currentZ = max(0, self.currentZ - 1)
+ self.scroll.setValue(self.currentZ)
+ updated = True
+ elif event.key() == QtCore.Qt.Key_Right:
+ self.currentZ = min(self.NZ - 1, self.currentZ + 1)
+ self.scroll.setValue(self.currentZ)
+ updated = True
+ else:
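+ # number of selectable views; the restored/filtered entry only counts once it has been enabled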
+ nviews = self.ViewDropDown.count() - 1
+ nviews += int(
+ self.ViewDropDown.model().item(self.ViewDropDown.count() -
+ 1).isEnabled())
+ if event.key() == QtCore.Qt.Key_X:
+ self.MCheckBox.toggle()
+ if event.key() == QtCore.Qt.Key_Z:
+ self.OCheckBox.toggle()
+ if event.key() == QtCore.Qt.Key_Left or event.key(
+ ) == QtCore.Qt.Key_A:
+ self.currentZ = max(0, self.currentZ - 1)
+ self.scroll.setValue(self.currentZ)
+ updated = True
+ elif event.key() == QtCore.Qt.Key_Right or event.key(
+ ) == QtCore.Qt.Key_D:
+ self.currentZ = min(self.NZ - 1, self.currentZ + 1)
+ self.scroll.setValue(self.currentZ)
+ updated = True
+ elif event.key() == QtCore.Qt.Key_PageDown:
+ self.view = (self.view + 1) % (nviews)
+ self.ViewDropDown.setCurrentIndex(self.view)
+ elif event.key() == QtCore.Qt.Key_PageUp:
+ self.view = (self.view - 1) % (nviews)
+ self.ViewDropDown.setCurrentIndex(self.view)
+
+ # can change background or stroke size if cell not finished
+ if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
+ self.color = (self.color - 1) % (6)
+ self.RGBDropDown.setCurrentIndex(self.color)
+ elif event.key() == QtCore.Qt.Key_Down or event.key(
+ ) == QtCore.Qt.Key_S:
+ self.color = (self.color + 1) % (6)
+ self.RGBDropDown.setCurrentIndex(self.color)
+ elif event.key() == QtCore.Qt.Key_R:
+ if self.color != 1:
+ self.color = 1
+ else:
+ self.color = 0
+ self.RGBDropDown.setCurrentIndex(self.color)
+ elif event.key() == QtCore.Qt.Key_G:
+ if self.color != 2:
+ self.color = 2
+ else:
+ self.color = 0
+ self.RGBDropDown.setCurrentIndex(self.color)
+ elif event.key() == QtCore.Qt.Key_B:
+ if self.color != 3:
+ self.color = 3
+ else:
+ self.color = 0
+ self.RGBDropDown.setCurrentIndex(self.color)
+ elif (event.key() == QtCore.Qt.Key_Comma or
+ event.key() == QtCore.Qt.Key_Period):
+ count = self.BrushChoose.count()
+ gci = self.BrushChoose.currentIndex()
+ if event.key() == QtCore.Qt.Key_Comma:
+ gci = max(0, gci - 1)
+ else:
+ gci = min(count - 1, gci + 1)
+ self.BrushChoose.setCurrentIndex(gci)
+ self.brush_choose()
+ if not updated:
+ self.update_plot()
+ if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
+ self.p0.keyPressEvent(event)
+
+ def update_ztext(self):
+ zpos = self.currentZ
+ try:
+ zpos = int(self.zpos.text())
+ except:
+ print("ERROR: zposition is not a number")
+ self.currentZ = max(0, min(self.NZ - 1, zpos))
+ self.zpos.setText(str(self.currentZ))
+ self.scroll.setValue(self.currentZ)
diff --git a/cellpose/gui/guihelpwindowtext.html b/cellpose/gui/guihelpwindowtext.html
new file mode 100644
index 0000000000000000000000000000000000000000..4ca5d209f1f39905ec60b049e288a3cf9fcff37b
--- /dev/null
+++ b/cellpose/gui/guihelpwindowtext.html
@@ -0,0 +1,151 @@
+
+
+ Main GUI mouse controls:
+
+
+ - Pan = left-click + drag
+ - Zoom = scroll wheel (or +/= and - buttons)
+ - Full view = double left-click
+ - Select mask = left-click on mask
+ - Delete mask = Ctrl (or COMMAND on Mac) +
+ left-click
+
+ - Merge masks = Alt + left-click (will merge
+ last two)
+
+ - Start draw mask = right-click
+ - End draw mask = right-click, or return to
+ circle at beginning
+
+
+ Overlaps in masks are NOT allowed. If you
+ draw a mask on top of another mask, it is cropped so that it doesn't overlap with the old mask. Masks in 2D
+ should be single strokes (single stroke is checked). If you want to draw masks in 3D (experimental), then
+ you can turn this option off and draw a stroke on each plane with the cell and then press ENTER. 3D
+ labelling will fill in planes that you have not labelled so that you do not have to label as densely.
+
+ !NOTE!: The GUI automatically saves after
+ you draw a mask in 2D but NOT after 3D mask drawing and NOT after segmentation. Save in the file menu or
+ with Ctrl+S. The output file is in the same folder as the loaded image with _seg.npy appended.
+
+
+ Bulk Mask Deletion
+ Clicking the 'delete multiple' button will allow you to select and delete multiple masks at once.
+ Masks can be deselected by clicking on them again. Once you have selected all the masks you want to delete,
+ click the 'done' button to delete them.
+
+
+ Alternatively, you can create a rectangular region to delete a region of masks by clicking the
+ 'delete multiple' button, and then moving and/or resizing the region to select the masks you want to delete.
+ Once you have selected the masks you want to delete, click the 'done' button to delete them.
+
+
+ At any point in the process, you can click the 'cancel' button to cancel the bulk deletion.
+
+
+
+
+
+ FYI there are tooltips throughout the GUI (hover over text to see)
+
+
+
+ | Keyboard shortcuts | Description |
+ | =/+ button // - button | zoom in // zoom out |
+ | CTRL+Z | undo previously drawn mask/stroke |
+ | CTRL+Y | undo remove mask |
+ | CTRL+0 | clear all masks |
+ | CTRL+L | load image (can alternatively drag and drop image) |
+ | CTRL+S | SAVE MASKS IN IMAGE to _seg.npy file |
+ | CTRL+T | train model using _seg.npy files in folder |
+ | CTRL+P | load _seg.npy file (note: it will load automatically with image if it exists) |
+ | CTRL+M | load masks file (must be same size as image with 0 for NO mask, and 1,2,3,... for masks) |
+ | CTRL+N | save masks as PNG |
+ | CTRL+R | save ROIs to native ImageJ ROI format |
+ | CTRL+F | save flows to image file |
+ | A/D or LEFT/RIGHT | cycle through images in current directory |
+ | W/S or UP/DOWN | change color (RGB/gray/red/green/blue) |
+ | R / G / B | toggle between RGB and Red or Green or Blue |
+ | PAGE-UP / PAGE-DOWN | change to flows and cell prob views (if segmentation computed) |
+ | X | turn masks ON or OFF |
+ | Z | toggle outlines ON or OFF |
+ | , / . | increase / decrease brush size for drawing masks |
+
+
+
+ Segmentation options (2D only)
+ SIZE: you can manually enter the approximate diameter for your cells, or press "calibrate" to let the
+ model estimate it. The size is represented by a disk at the bottom of the view window (you can turn this
+ disk off by unchecking "scale disk on").
+ use GPU: if you have specially installed the cuda version of mxnet, then you can activate this, but it
+ won't give huge speedups when running single 2D images in the GUI.
+ MODEL: there is a cytoplasm model and a nuclei model, choose what you want to segment
+ CHAN TO SEG: this is the channel in which the cytoplasm or nuclei exist
+ CHAN2 (OPT): if cytoplasm model is chosen, then choose the nuclear channel for this option
+
diff --git a/cellpose/gui/guiparts.py b/cellpose/gui/guiparts.py
new file mode 100644
index 0000000000000000000000000000000000000000..f84896828e4b913f412b4715c80d00c55476b67c
--- /dev/null
+++ b/cellpose/gui/guiparts.py
@@ -0,0 +1,591 @@
+"""
+Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+from qtpy import QtGui, QtCore, QtWidgets
+from qtpy.QtGui import QPainter, QPixmap
+from qtpy.QtWidgets import QApplication, QRadioButton, QWidget, QDialog, QButtonGroup, QSlider, QStyle, QStyleOptionSlider, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox
+import pyqtgraph as pg
+from pyqtgraph import functions as fn
+from pyqtgraph import Point
+import numpy as np
+import pathlib, os
+
+
+def stylesheet():
+ return """
+ QToolTip {
+ background-color: black;
+ color: white;
+ border: black solid 1px
+ }
+ QComboBox {color: white;
+ background-color: rgb(40,40,40);}
+ QComboBox::item:enabled { color: white;
+ background-color: rgb(40,40,40);
+ selection-color: white;
+ selection-background-color: rgb(50,100,50);}
+ QComboBox::item:!enabled {
+ background-color: rgb(40,40,40);
+ color: rgb(100,100,100);
+ }
+ QScrollArea > QWidget > QWidget
+ {
+ background: transparent;
+ border: none;
+ margin: 0px 0px 0px 0px;
+ }
+
+ QGroupBox
+ { border: 1px solid white; color: rgb(255,255,255);
+ border-radius: 6px;
+ margin-top: 8px;
+ padding: 0px 0px;}
+
+ QPushButton:pressed {Text-align: center;
+ background-color: rgb(150,50,150);
+ border-color: white;
+ color:white;}
+ QToolTip {
+ background-color: black;
+ color: white;
+ border: black solid 1px
+ }
+ QPushButton:!pressed {Text-align: center;
+ background-color: rgb(50,50,50);
+ border-color: white;
+ color:white;}
+ QToolTip {
+ background-color: black;
+ color: white;
+ border: black solid 1px
+ }
+ QPushButton:disabled {Text-align: center;
+ background-color: rgb(30,30,30);
+ border-color: white;
+ color:rgb(80,80,80);}
+ QToolTip {
+ background-color: black;
+ color: white;
+ border: black solid 1px
+ }
+
+ """
+
+
+class DarkPalette(QtGui.QPalette):
+ """Class that inherits from pyqtgraph.QtGui.QPalette and renders dark colours for the application.
+ (from pykilosort/kilosort4)
+ """
+
+ def __init__(self):
+ QtGui.QPalette.__init__(self)
+ self.setup()
+
+ def setup(self):
+ self.setColor(QtGui.QPalette.Window, QtGui.QColor(40, 40, 40))
+ self.setColor(QtGui.QPalette.WindowText, QtGui.QColor(255, 255, 255))
+ self.setColor(QtGui.QPalette.Base, QtGui.QColor(34, 27, 24))
+ self.setColor(QtGui.QPalette.AlternateBase, QtGui.QColor(53, 50, 47))
+ self.setColor(QtGui.QPalette.ToolTipBase, QtGui.QColor(255, 255, 255))
+ self.setColor(QtGui.QPalette.ToolTipText, QtGui.QColor(255, 255, 255))
+ self.setColor(QtGui.QPalette.Text, QtGui.QColor(255, 255, 255))
+ self.setColor(QtGui.QPalette.Button, QtGui.QColor(53, 50, 47))
+ self.setColor(QtGui.QPalette.ButtonText, QtGui.QColor(255, 255, 255))
+ self.setColor(QtGui.QPalette.BrightText, QtGui.QColor(255, 0, 0))
+ self.setColor(QtGui.QPalette.Link, QtGui.QColor(42, 130, 218))
+ self.setColor(QtGui.QPalette.Highlight, QtGui.QColor(42, 130, 218))
+ self.setColor(QtGui.QPalette.HighlightedText, QtGui.QColor(0, 0, 0))
+ self.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.Text,
+ QtGui.QColor(128, 128, 128))
+ self.setColor(
+ QtGui.QPalette.Disabled,
+ QtGui.QPalette.ButtonText,
+ QtGui.QColor(128, 128, 128),
+ )
+ self.setColor(
+ QtGui.QPalette.Disabled,
+ QtGui.QPalette.WindowText,
+ QtGui.QColor(128, 128, 128),
+ )
+
+
+def create_channel_choose():
+ # choose channel
+ ChannelChoose = [QComboBox(), QComboBox()]
+ ChannelLabels = []
+ ChannelChoose[0].addItems(["gray", "red", "green", "blue"])
+ ChannelChoose[1].addItems(["none", "red", "green", "blue"])
+ cstr = ["chan to segment:", "chan2 (optional): "]
+ for i in range(2):
+ ChannelLabels.append(QLabel(cstr[i]))
+ if i == 0:
+ ChannelLabels[i].setToolTip(
+ "this is the channel in which the cytoplasm or nuclei exist \
+ that you want to segment")
+ ChannelChoose[i].setToolTip(
+ "this is the channel in which the cytoplasm or nuclei exist \
+ that you want to segment")
+ else:
+ ChannelLabels[i].setToolTip(
+ "if cytoplasm model is chosen, and you also have a \
+ nuclear channel, then choose the nuclear channel for this option")
+ ChannelChoose[i].setToolTip(
+ "if cytoplasm model is chosen, and you also have a \
+ nuclear channel, then choose the nuclear channel for this option")
+
+ return ChannelChoose, ChannelLabels
+
+
+class ModelButton(QPushButton):
+
+ def __init__(self, parent, model_name, text):
+ super().__init__()
+ self.setEnabled(False)
+ self.setText(text)
+ self.setFont(parent.boldfont)
+ self.clicked.connect(lambda: self.press(parent))
+ self.model_name = model_name if "cyto3" not in model_name else "cyto3"
+
+ def press(self, parent):
+ parent.compute_segmentation(model_name=self.model_name)
+
+
+class DenoiseButton(QPushButton):
+
+ def __init__(self, parent, text):
+ super().__init__()
+ self.setEnabled(False)
+ self.model_type = text
+ self.setText(text)
+ self.setFont(parent.medfont)
+ self.clicked.connect(lambda: self.press(parent))
+
+ def press(self, parent):
+ if self.model_type == "filter":
+ parent.restore = "filter"
+ normalize_params = parent.get_normalize_params()
+ if (normalize_params["sharpen_radius"] == 0 and
+ normalize_params["smooth_radius"] == 0 and
+ normalize_params["tile_norm_blocksize"] == 0):
+ print(
+ "GUI_ERROR: no filtering settings on (use custom filter settings)")
+ parent.restore = None
+ return
+ parent.restore = self.model_type
+ parent.compute_saturation()
+ elif self.model_type != "none":
+ parent.compute_denoise_model(model_type=self.model_type)
+ else:
+ parent.clear_restore()
+ parent.set_restore_button()
+
+
+class TrainWindow(QDialog):
+
+ def __init__(self, parent, model_strings):
+ super().__init__(parent)
+ self.setGeometry(100, 100, 900, 550)
+ self.setWindowTitle("train settings")
+ self.win = QWidget(self)
+ self.l0 = QGridLayout()
+ self.win.setLayout(self.l0)
+
+ yoff = 0
+ qlabel = QLabel("train model w/ images + _seg.npy in current folder >>")
+ qlabel.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
+
+ qlabel.setAlignment(QtCore.Qt.AlignVCenter)
+ self.l0.addWidget(qlabel, yoff, 0, 1, 2)
+
+ # choose initial model
+ yoff += 1
+ self.ModelChoose = QComboBox()
+ self.ModelChoose.addItems(model_strings)
+ self.ModelChoose.addItems(["scratch"])
+ self.ModelChoose.setFixedWidth(150)
+ self.ModelChoose.setCurrentIndex(parent.training_params["model_index"])
+ self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1)
+ qlabel = QLabel("initial model: ")
+ qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.l0.addWidget(qlabel, yoff, 0, 1, 1)
+
+ # choose channels
+ self.ChannelChoose, self.ChannelLabels = create_channel_choose()
+ for i in range(2):
+ yoff += 1
+ self.ChannelChoose[i].setFixedWidth(150)
+ self.ChannelChoose[i].setCurrentIndex(
+ parent.ChannelChoose[i].currentIndex())
+ self.l0.addWidget(self.ChannelLabels[i], yoff, 0, 1, 1)
+ self.l0.addWidget(self.ChannelChoose[i], yoff, 1, 1, 1)
+
+ # choose parameters
+ labels = ["learning_rate", "weight_decay", "n_epochs", "model_name"]
+ self.edits = []
+ yoff += 1
+ for i, label in enumerate(labels):
+ qlabel = QLabel(label)
+ qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.l0.addWidget(qlabel, i + yoff, 0, 1, 1)
+ self.edits.append(QLineEdit())
+ self.edits[-1].setText(str(parent.training_params[label]))
+ self.edits[-1].setFixedWidth(200)
+ self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1)
+
+ yoff += 1
+ use_SGD = "SGD"
+ self.useSGD = QCheckBox(f"{use_SGD}")
+ self.useSGD.setToolTip("use SGD, if unchecked uses AdamW (recommended learning_rate then 0.001)")
+ self.useSGD.setChecked(True)
+ self.l0.addWidget(self.useSGD, i+yoff, 1, 1, 1)
+
+ yoff += len(labels)
+
+ yoff += 1
+ self.use_norm = QCheckBox(f"use restored/filtered image")
+ self.use_norm.setChecked(True)
+ #self.l0.addWidget(self.use_norm, yoff, 0, 2, 4)
+
+ yoff += 2
+ qlabel = QLabel(
+ "(to remove files, click cancel then remove \nfrom folder and reopen train window)"
+ )
+ self.l0.addWidget(qlabel, yoff, 0, 2, 4)
+
+ # click button
+ yoff += 3
+ QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
+ self.buttonBox = QDialogButtonBox(QBtn)
+ self.buttonBox.accepted.connect(lambda: self.accept(parent))
+ self.buttonBox.rejected.connect(self.reject)
+ self.l0.addWidget(self.buttonBox, yoff, 0, 1, 4)
+
+ # list files in folder
+ qlabel = QLabel("filenames")
+ qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
+ self.l0.addWidget(qlabel, 0, 4, 1, 1)
+ qlabel = QLabel("# of masks")
+ qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
+ self.l0.addWidget(qlabel, 0, 5, 1, 1)
+
+ for i in range(10):
+ if i > len(parent.train_files) - 1:
+ break
+ elif i == 9 and len(parent.train_files) > 10:
+ label = "..."
+ nmasks = "..."
+ else:
+ label = os.path.split(parent.train_files[i])[-1]
+ nmasks = str(parent.train_labels[i].max())
+ qlabel = QLabel(label)
+ self.l0.addWidget(qlabel, i + 1, 4, 1, 1)
+ qlabel = QLabel(nmasks)
+ qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.l0.addWidget(qlabel, i + 1, 5, 1, 1)
+
+ def accept(self, parent):
+ # set training params
+ parent.training_params = {
+ "model_index": self.ModelChoose.currentIndex(),
+ "learning_rate": float(self.edits[0].text()),
+ "weight_decay": float(self.edits[1].text()),
+ "n_epochs": int(self.edits[2].text()),
+ "model_name": self.edits[3].text(),
+ "SGD": True if self.useSGD.isChecked() else False,
+ "channels": [self.ChannelChoose[0].currentIndex(),
+ self.ChannelChoose[1].currentIndex()],
+ #"use_norm": True if self.use_norm.isChecked() else False,
+ }
+ self.done(1)
+
+
+class ExampleGUI(QDialog):
+
+ def __init__(self, parent=None):
+ super(ExampleGUI, self).__init__(parent)
+ self.setGeometry(100, 100, 1300, 900)
+ self.setWindowTitle("GUI layout")
+ self.win = QWidget(self)
+ layout = QGridLayout()
+ self.win.setLayout(layout)
+ guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png")
+ guip_path = str(guip_path.resolve())
+ pixmap = QPixmap(guip_path)
+ label = QLabel(self)
+ label.setPixmap(pixmap)
+ pixmap.scaled
+ layout.addWidget(label, 0, 0, 1, 1)
+
+
+class HelpWindow(QDialog):
+
+ def __init__(self, parent=None):
+ super(HelpWindow, self).__init__(parent)
+ self.setGeometry(100, 50, 700, 1000)
+ self.setWindowTitle("cellpose help")
+ self.win = QWidget(self)
+ layout = QGridLayout()
+ self.win.setLayout(layout)
+
+ text_file = pathlib.Path(__file__).parent.joinpath("guihelpwindowtext.html")
+ with open(str(text_file.resolve()), "r") as f:
+ text = f.read()
+
+ label = QLabel(text)
+ label.setFont(QtGui.QFont("Arial", 8))
+ label.setWordWrap(True)
+ layout.addWidget(label, 0, 0, 1, 1)
+ self.show()
+
+
+class TrainHelpWindow(QDialog):
+
+ def __init__(self, parent=None):
+ super(TrainHelpWindow, self).__init__(parent)
+ self.setGeometry(100, 50, 700, 300)
+ self.setWindowTitle("training instructions")
+ self.win = QWidget(self)
+ layout = QGridLayout()
+ self.win.setLayout(layout)
+
+ text_file = pathlib.Path(__file__).parent.joinpath(
+ "guitrainhelpwindowtext.html")
+ with open(str(text_file.resolve()), "r") as f:
+ text = f.read()
+
+ label = QLabel(text)
+ label.setFont(QtGui.QFont("Arial", 8))
+ label.setWordWrap(True)
+ layout.addWidget(label, 0, 0, 1, 1)
+ self.show()
+
+
+class ViewBoxNoRightDrag(pg.ViewBox):
+
+ def __init__(self, parent=None, border=None, lockAspect=False, enableMouse=True,
+ invertY=False, enableMenu=True, name=None, invertX=False):
+ pg.ViewBox.__init__(self, None, border, lockAspect, enableMouse, invertY,
+ enableMenu, name, invertX)
+ self.parent = parent
+ self.axHistoryPointer = -1
+
+ def keyPressEvent(self, ev):
+ """
+ This routine should capture key presses in the current view box.
+ The following events are implemented:
+ +/= : moves forward in the zooming stack (if it exists)
+ - : moves backward in the zooming stack (if it exists)
+
+ """
+ ev.accept()
+ if ev.text() == "-":
+ self.scaleBy([1.1, 1.1])
+ elif ev.text() in ["+", "="]:
+ self.scaleBy([0.9, 0.9])
+ else:
+ ev.ignore()
+
+
+class ImageDraw(pg.ImageItem):
+ """
+ **Bases:** :class:`GraphicsObject `
+ GraphicsObject displaying an image. Optimized for rapid update (ie video display).
+ This item displays either a 2D numpy array (height, width) or
+ a 3D array (height, width, RGBa). This array is optionally scaled (see
+ :func:`setLevels `) and/or colored
+ with a lookup table (see :func:`setLookupTable `)
+ before being displayed.
+ ImageItem is frequently used in conjunction with
+ :class:`HistogramLUTItem ` or
+ :class:`HistogramLUTWidget ` to provide a GUI
+ for controlling the levels and lookup table used to display the image.
+ """
+
+ sigImageChanged = QtCore.Signal()
+
+ def __init__(self, image=None, viewbox=None, parent=None, **kargs):
+ super(ImageDraw, self).__init__()
+ #self.image=None
+ #self.viewbox=viewbox
+ self.levels = np.array([0, 255])
+ self.lut = None
+ self.autoDownsample = False
+ self.axisOrder = "row-major"
+ self.removable = False
+
+ self.parent = parent
+ #kernel[1,1] = 1
+ self.setDrawKernel(kernel_size=self.parent.brush_size)
+ self.parent.current_stroke = []
+ self.parent.in_stroke = False
+
+ def mouseClickEvent(self, ev):
+ if (self.parent.masksOn or
+ self.parent.outlinesOn) and not self.parent.removing_region:
+ is_right_click = ev.button() == QtCore.Qt.RightButton
+ if self.parent.loaded \
+ and (is_right_click or ev.modifiers() & QtCore.Qt.ShiftModifier and not ev.double())\
+ and not self.parent.deleting_multiple:
+ if not self.parent.in_stroke:
+ ev.accept()
+ self.create_start(ev.pos())
+ self.parent.stroke_appended = False
+ self.parent.in_stroke = True
+ self.drawAt(ev.pos(), ev)
+ else:
+ ev.accept()
+ self.end_stroke()
+ self.parent.in_stroke = False
+ elif not self.parent.in_stroke:
+ y, x = int(ev.pos().y()), int(ev.pos().x())
+ if y >= 0 and y < self.parent.Ly and x >= 0 and x < self.parent.Lx:
+ if ev.button() == QtCore.Qt.LeftButton and not ev.double():
+ idx = self.parent.cellpix[self.parent.currentZ][y, x]
+ if idx > 0:
+ if ev.modifiers() & QtCore.Qt.ControlModifier:
+ # delete mask selected
+ self.parent.remove_cell(idx)
+ elif ev.modifiers() & QtCore.Qt.AltModifier:
+ self.parent.merge_cells(idx)
+ elif self.parent.masksOn and not self.parent.deleting_multiple:
+ self.parent.unselect_cell()
+ self.parent.select_cell(idx)
+ elif self.parent.deleting_multiple:
+ if idx in self.parent.removing_cells_list:
+ self.parent.unselect_cell_multi(idx)
+ self.parent.removing_cells_list.remove(idx)
+ else:
+ self.parent.select_cell_multi(idx)
+ self.parent.removing_cells_list.append(idx)
+
+ elif self.parent.masksOn and not self.parent.deleting_multiple:
+ self.parent.unselect_cell()
+
+ def mouseDragEvent(self, ev):
+ ev.ignore()
+ return
+
+ def hoverEvent(self, ev):
+ #QtWidgets.QApplication.setOverrideCursor(QtCore.Qt.CrossCursor)
+ if self.parent.in_stroke:
+ if self.parent.in_stroke:
+ # continue stroke if not at start
+ self.drawAt(ev.pos())
+ if self.is_at_start(ev.pos()):
+ #self.parent.in_stroke = False
+ self.end_stroke()
+ else:
+ ev.acceptClicks(QtCore.Qt.RightButton)
+ #ev.acceptClicks(QtCore.Qt.LeftButton)
+
+ def create_start(self, pos):
+ self.scatter = pg.ScatterPlotItem([pos.x()], [pos.y()], pxMode=False,
+ pen=pg.mkPen(color=(255, 0, 0),
+ width=self.parent.brush_size),
+ size=max(3 * 2,
+ self.parent.brush_size * 1.8 * 2),
+ brush=None)
+ self.parent.p0.addItem(self.scatter)
+
+ def is_at_start(self, pos):
+ thresh_out = max(6, self.parent.brush_size * 3)
+ thresh_in = max(3, self.parent.brush_size * 1.8)
+ # first check if you ever left the start
+ if len(self.parent.current_stroke) > 3:
+ stroke = np.array(self.parent.current_stroke)
+ dist = (((stroke[1:, 1:] -
+ stroke[:1, 1:][np.newaxis, :, :])**2).sum(axis=-1))**0.5
+ dist = dist.flatten()
+ #print(dist)
+ has_left = (dist > thresh_out).nonzero()[0]
+ if len(has_left) > 0:
+ first_left = np.sort(has_left)[0]
+ has_returned = (dist[max(4, first_left + 1):] < thresh_in).sum()
+ if has_returned > 0:
+ return True
+ else:
+ return False
+ else:
+ return False
+
+ def end_stroke(self):
+ self.parent.p0.removeItem(self.scatter)
+ if not self.parent.stroke_appended:
+ self.parent.strokes.append(self.parent.current_stroke)
+ self.parent.stroke_appended = True
+ self.parent.current_stroke = np.array(self.parent.current_stroke)
+ ioutline = self.parent.current_stroke[:, 3] == 1
+ self.parent.current_point_set.append(
+ list(self.parent.current_stroke[ioutline]))
+ self.parent.current_stroke = []
+ if self.parent.autosave:
+ self.parent.add_set()
+ if len(self.parent.current_point_set) and len(
+ self.parent.current_point_set[0]) > 0 and self.parent.autosave:
+ self.parent.add_set()
+ self.parent.in_stroke = False
+
+ def tabletEvent(self, ev):
+ pass
+ #print(ev.device())
+ #print(ev.pointerType())
+ #print(ev.pressure())
+
+ def drawAt(self, pos, ev=None):
+ mask = self.strokemask
+ stroke = self.parent.current_stroke
+ pos = [int(pos.y()), int(pos.x())]
+ dk = self.drawKernel
+ kc = self.drawKernelCenter
+ sx = [0, dk.shape[0]]
+ sy = [0, dk.shape[1]]
+ tx = [pos[0] - kc[0], pos[0] - kc[0] + dk.shape[0]]
+ ty = [pos[1] - kc[1], pos[1] - kc[1] + dk.shape[1]]
+ kcent = kc.copy()
+ if tx[0] <= 0:
+ sx[0] = 0
+ sx[1] = kc[0] + 1
+ tx = sx
+ kcent[0] = 0
+ if ty[0] <= 0:
+ sy[0] = 0
+ sy[1] = kc[1] + 1
+ ty = sy
+ kcent[1] = 0
+ if tx[1] >= self.parent.Ly - 1:
+ sx[0] = dk.shape[0] - kc[0] - 1
+ sx[1] = dk.shape[0]
+ tx[0] = self.parent.Ly - kc[0] - 1
+ tx[1] = self.parent.Ly
+ kcent[0] = tx[1] - tx[0] - 1
+ if ty[1] >= self.parent.Lx - 1:
+ sy[0] = dk.shape[1] - kc[1] - 1
+ sy[1] = dk.shape[1]
+ ty[0] = self.parent.Lx - kc[1] - 1
+ ty[1] = self.parent.Lx
+ kcent[1] = ty[1] - ty[0] - 1
+
+ ts = (slice(tx[0], tx[1]), slice(ty[0], ty[1]))
+ ss = (slice(sx[0], sx[1]), slice(sy[0], sy[1]))
+ self.image[ts] = mask[ss]
+
+ for ky, y in enumerate(np.arange(ty[0], ty[1], 1, int)):
+ for kx, x in enumerate(np.arange(tx[0], tx[1], 1, int)):
+ iscent = np.logical_and(kx == kcent[0], ky == kcent[1])
+ stroke.append([self.parent.currentZ, x, y, iscent])
+ self.updateImage()
+
+ def setDrawKernel(self, kernel_size=3):
+ bs = kernel_size
+ kernel = np.ones((bs, bs), np.uint8)
+ self.drawKernel = kernel
+ self.drawKernelCenter = [
+ int(np.floor(kernel.shape[0] / 2)),
+ int(np.floor(kernel.shape[1] / 2))
+ ]
+ onmask = 255 * kernel[:, :, np.newaxis]
+ offmask = np.zeros((bs, bs, 1))
+ opamask = 100 * kernel[:, :, np.newaxis]
+ self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1)
+ self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1)
diff --git a/cellpose/gui/guitrainhelpwindowtext.html b/cellpose/gui/guitrainhelpwindowtext.html
new file mode 100644
index 0000000000000000000000000000000000000000..f198359113ba0e549f5174e564e549034fbacfb0
--- /dev/null
+++ b/cellpose/gui/guitrainhelpwindowtext.html
@@ -0,0 +1,25 @@
+
+ Check out this video to learn the process.
+
+ - Drag and drop an image from a folder of images with a similar style (like similar cell types).
+ - Run the built-in models on one of the images using the "model zoo" and find the one that works best for your
+ data. Make sure that if you have a nuclear channel you have selected it for CHAN2.
+
+ - Fix the labelling by drawing new ROIs (right-click) and deleting incorrect ones (CTRL+click). The GUI
+ autosaves any manual changes (but does not autosave after running the model; for that, click CTRL+S). The
+ segmentation is saved in a "_seg.npy" file (see the snippet after these instructions for how to inspect it).
+
+ - Go to the "Models" menu in the menu bar at the top and click "Train new model..." or use the shortcut CTRL+T.
+
+ - Choose the pretrained model to start the training from (the model you used in #2), and type in the model
+ name that you want to use. The other parameters should work well in general for most data types. Then click
+ OK.
+
+ - The model will train (much faster if you have a GPU) and then auto-run on the next image in the folder.
+ You can then repeat steps #3-#5 as many times as necessary.
+
+ - The trained model is saved in your image folder and remains available in the GUI's "custom model" section
+ for future use.
+
+
+
\ No newline at end of file
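Between training rounds, it can help to check what the GUI has saved. The short sketch below (the file name is a placeholder) loads a `_seg.npy` file the way `cellpose/gui/io.py` does and reports how many ROIs it contains; the keys shown are the ones written by `_save_sets` in that module.

```python
# Sketch: inspect a _seg.npy file saved by the GUI (file name is a placeholder).
import numpy as np

dat = np.load("example_seg.npy", allow_pickle=True).item()
masks = dat["masks"]                          # 0 = background, 1..N = ROIs
print("number of ROIs:", int(masks.max()))
print("available keys:", sorted(dat.keys()))  # e.g. 'outlines', 'colors', 'flows', 'filename'
```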
diff --git a/cellpose/gui/io.py b/cellpose/gui/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dc2e0cd2bc2ba02634c134638735042f411a71d
--- /dev/null
+++ b/cellpose/gui/io.py
@@ -0,0 +1,711 @@
+"""
+Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import os, datetime, gc, warnings, glob, shutil, copy
+from natsort import natsorted
+import numpy as np
+import cv2
+import tifffile
+import logging
+import fastremap
+
+from ..io import imread, imsave, outlines_to_text, add_model, remove_model, save_rois
+from ..models import normalize_default, MODEL_DIR, MODEL_LIST_PATH, get_user_models
+from ..utils import masks_to_outlines, outlines_list
+
+try:
+ import qtpy
+ from qtpy.QtWidgets import QFileDialog
+ GUI = True
+except:
+ GUI = False
+
+try:
+ import matplotlib.pyplot as plt
+ MATPLOTLIB = True
+except:
+ MATPLOTLIB = False
+
+
+def _init_model_list(parent):
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
+ parent.model_list_path = MODEL_LIST_PATH
+ parent.model_strings = get_user_models()
+
+
+def _add_model(parent, filename=None, load_model=True):
+ if filename is None:
+ name = QFileDialog.getOpenFileName(parent, "Add model to GUI")
+ filename = name[0]
+ add_model(filename)
+ fname = os.path.split(filename)[-1]
+ parent.ModelChooseC.addItems([fname])
+ parent.model_strings.append(fname)
+
+ for ind, model_string in enumerate(parent.model_strings[:-1]):
+ if model_string == fname:
+ _remove_model(parent, ind=ind + 1, verbose=False)
+
+ parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
+ if load_model:
+ parent.model_choose(custom=True)
+
+
+def _remove_model(parent, ind=None, verbose=True):
+ if ind is None:
+ ind = parent.ModelChooseC.currentIndex()
+ if ind > 0:
+ ind -= 1
+ parent.ModelChooseC.removeItem(ind + 1)
+ del parent.model_strings[ind]
+ # remove model from txt path
+ modelstr = parent.ModelChooseC.currentText()
+ remove_model(modelstr)
+ if len(parent.model_strings) > 0:
+ parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
+ else:
+ parent.ModelChooseC.setCurrentIndex(0)
+ else:
+ print("ERROR: no model selected to delete")
+
+
+def _get_train_set(image_names):
+ """ get training data and labels for images in current folder image_names"""
+ train_data, train_labels, train_files = [], [], []
+ restore = None
+ normalize_params = normalize_default
+ for image_name_full in image_names:
+ image_name = os.path.splitext(image_name_full)[0]
+ label_name = None
+ if os.path.exists(image_name + "_seg.npy"):
+ dat = np.load(image_name + "_seg.npy", allow_pickle=True).item()
+ masks = dat["masks"].squeeze()
+ if masks.ndim == 2:
+ fastremap.renumber(masks, in_place=True)
+ label_name = image_name + "_seg.npy"
+ else:
+ print(f"GUI_INFO: _seg.npy found for {image_name} but masks.ndim!=2")
+ if "img_restore" in dat:
+ data = dat["img_restore"].squeeze()
+ restore = dat["restore"]
+ else:
+ data = imread(image_name_full)
+ normalize_params = dat[
+ "normalize_params"] if "normalize_params" in dat else normalize_default
+ if label_name is not None:
+ train_files.append(image_name_full)
+ train_data.append(data)
+ train_labels.append(masks)
+ if restore:
+ print(f"GUI_INFO: using {restore} images (dat['img_restore'])")
+ return train_data, train_labels, train_files, restore, normalize_params
+
+
+def _load_image(parent, filename=None, load_seg=True, load_3D=False):
+ """ load image with filename; if None, open QFileDialog """
+ if filename is None:
+ name = QFileDialog.getOpenFileName(parent, "Load image")
+ filename = name[0]
+ if filename == "":
+ return
+ manual_file = os.path.splitext(filename)[0] + "_seg.npy"
+ load_mask = False
+ if load_seg:
+ if os.path.isfile(manual_file) and not parent.autoloadMasks.isChecked():
+ _load_seg(parent, manual_file, image=imread(filename), image_file=filename,
+ load_3D=load_3D)
+ return
+ elif parent.autoloadMasks.isChecked():
+ mask_file = os.path.splitext(filename)[0] + "_masks" + os.path.splitext(
+ filename)[-1]
+ mask_file = os.path.splitext(filename)[
+ 0] + "_masks.tif" if not os.path.isfile(mask_file) else mask_file
+ load_mask = True if os.path.isfile(mask_file) else False
+ try:
+ print(f"GUI_INFO: loading image: {filename}")
+ image = imread(filename)
+ parent.loaded = True
+ except Exception as e:
+ print("ERROR: images not compatible")
+ print(f"ERROR: {e}")
+
+ if parent.loaded:
+ parent.reset()
+ parent.filename = filename
+ filename = os.path.split(parent.filename)[-1]
+ _initialize_images(parent, image, load_3D=load_3D)
+ parent.loaded = True
+ parent.enable_buttons()
+ if load_mask:
+ _load_masks(parent, filename=mask_file)
+
+
+def _initialize_images(parent, image, load_3D=False):
+ """ format image for GUI
+
+ assumes image is Z x channels x W x H
+
+ """
+ load_3D = parent.load_3D if load_3D is False else load_3D
+ parent.nchan = 3
+ if image.ndim > 4:
+ image = image.squeeze()
+ if image.ndim > 4:
+ raise ValueError("cannot load 4D stack, reduce dimensions")
+ elif image.ndim == 1:
+ raise ValueError("cannot load 1D stack, increase dimensions")
+ if image.ndim == 4:
+ if not load_3D:
+ raise ValueError(
+ "cannot load 3D stack, run 'python -m cellpose --Zstack' for 3D GUI")
+ else:
+ # check if tiff is channels first
+ if image.shape[0] < 4 and image.shape[0] == min(image.shape) and image.shape[0] < image.shape[1]:
+ # tiff is channels x Z x W x H => Z x channels x W x H
+ image = image.transpose((1, 0, 2, 3))
+ image = np.transpose(image, (0, 2, 3, 1))
+ elif image.ndim == 3:
+ if not load_3D:
+ # assume smallest dimension is channels and put last
+ c = np.array(image.shape).argmin()
+ image = image.transpose(((c + 1) % 3, (c + 2) % 3, c))
+ elif load_3D:
+ # assume smallest dimension is Z and put first if <3x max dim
+ shape = np.array(image.shape)
+ z = shape.argmin()
+ if shape[z] < shape.max() / 3:
+ image = image.transpose((z, (z + 1) % 3, (z + 2) % 3))
+ image = image[..., np.newaxis]
+ elif image.ndim == 2:
+ if not load_3D:
+ image = image[..., np.newaxis]
+ else:
+ raise ValueError(
+ "cannot load 2D stack in 3D mode, run 'python -m cellpose' for 2D GUI")
+ if image.shape[-1] > 3:
+ print("WARNING: image has more than 3 channels, keeping only first 3")
+ image = image[..., :3]
+ elif image.shape[-1] == 2:
+ # fill in with blank channels to make 3 channels
+ shape = image.shape
+ image = np.concatenate(
+ (image, np.zeros((*shape[:-1], 3 - shape[-1]), dtype=np.uint8)), axis=-1)
+ parent.nchan = 2
+ elif image.shape[-1] == 1:
+ parent.nchan = 1
+
+ parent.stack = image
+ print(f"GUI_INFO: image shape: {image.shape}")
+ if load_3D:
+ parent.NZ = len(parent.stack)
+ parent.scroll.setMaximum(parent.NZ - 1)
+ else:
+ parent.NZ = 1
+ parent.stack = parent.stack[np.newaxis, ...]
+
+ img_min = image.min()
+ img_max = image.max()
+ parent.stack = parent.stack.astype(np.float32)
+ parent.stack -= img_min
+ if img_max > img_min + 1e-3:
+ parent.stack /= (img_max - img_min)
+ parent.stack *= 255
+
+ if load_3D:
+ print("GUI_INFO: converted to float and normalized values to 0.0->255.0")
+
+ del image
+ gc.collect()
+
+ parent.imask = 0
+ parent.Ly, parent.Lx = parent.stack.shape[-3:-1]
+ parent.Ly0, parent.Lx0 = parent.stack.shape[-3:-1]
+ parent.layerz = 255 * np.ones((parent.Ly, parent.Lx, 4), "uint8")
+ if hasattr(parent, "stack_filtered"):
+ parent.Lyr, parent.Lxr = parent.stack_filtered.shape[-3:-1]
+ elif parent.restore and "upsample" in parent.restore:
+ parent.Lyr, parent.Lxr = int(parent.Ly * parent.ratio), int(parent.Lx *
+ parent.ratio)
+ else:
+ parent.Lyr, parent.Lxr = parent.Ly, parent.Lx
+ parent.clear_all()
+
+ if not hasattr(parent, "stack_filtered") and parent.restore:
+ print("GUI_INFO: no 'img_restore' found, applying current settings")
+ parent.compute_restore()
+
+ if parent.autobtn.isChecked():
+ if parent.restore is None or parent.restore != "filter":
+ print(
+ "GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)"
+ )
+ parent.compute_saturation()
+ elif len(parent.saturation) != parent.NZ:
+ parent.saturation = []
+ for r in range(3):
+ parent.saturation.append([])
+ for n in range(parent.NZ):
+ parent.saturation[-1].append([0, 255])
+ parent.sliders[r].setValue([0, 255])
+ parent.compute_scale()
+ parent.track_changes = []
+
+ if load_3D:
+ parent.currentZ = int(np.floor(parent.NZ / 2))
+ parent.scroll.setValue(parent.currentZ)
+ parent.zpos.setText(str(parent.currentZ))
+ else:
+ parent.currentZ = 0
+
+
+def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False):
+ """ load *_seg.npy with filename; if None, open QFileDialog """
+ if filename is None:
+ name = QFileDialog.getOpenFileName(parent, "Load labelled data", filter="*.npy")
+ filename = name[0]
+ try:
+ dat = np.load(filename, allow_pickle=True).item()
+ # check if there are keys in filename
+ dat["outlines"]
+ parent.loaded = True
+ except:
+ parent.loaded = False
+ print("ERROR: not NPY")
+ return
+
+ parent.reset()
+ if image is None:
+ found_image = False
+ if "filename" in dat:
+ parent.filename = dat["filename"]
+ if os.path.isfile(parent.filename):
+ parent.filename = dat["filename"]
+ found_image = True
+ else:
+ imgname = os.path.split(parent.filename)[1]
+ root = os.path.split(filename)[0]
+ parent.filename = root + "/" + imgname
+ if os.path.isfile(parent.filename):
+ found_image = True
+ if found_image:
+ try:
+ print(parent.filename)
+ image = imread(parent.filename)
+ except:
+ parent.loaded = False
+ found_image = False
+ print("ERROR: cannot find image file, loading from npy")
+ if not found_image:
+ parent.filename = filename[:-8]
+ print(parent.filename)
+ if "img" in dat:
+ image = dat["img"]
+ else:
+ print("ERROR: no image file found and no image in npy")
+ return
+ else:
+ parent.filename = image_file
+
+ parent.restore = None
+ parent.ratio = 1.
+
+ if "normalize_params" in dat:
+ parent.restore = None if "restore" not in dat else dat["restore"]
+ print(f"GUI_INFO: restore: {parent.restore}")
+ parent.set_normalize_params(dat["normalize_params"])
+ parent.set_restore_button()
+
+ if "img_restore" in dat:
+ img = dat["img_restore"]
+ img_min = img.min()
+ img_max = img.max()
+ parent.stack_filtered = img.astype("float32")
+ parent.stack_filtered -= img_min
+ if img_max > img_min + 1e-3:
+ parent.stack_filtered /= (img_max - img_min)
+ parent.stack_filtered *= 255
+ if parent.stack_filtered.ndim < 4:
+ parent.stack_filtered = parent.stack_filtered[np.newaxis, ...]
+ if parent.stack_filtered.ndim < 4:
+ parent.stack_filtered = parent.stack_filtered[..., np.newaxis]
+ shape = parent.stack_filtered.shape
+ if shape[-1] == 2:
+ if "chan_choose" in dat:
+ channels = np.array(dat["chan_choose"]) - 1
+ img = np.zeros((*shape[:-1], 3), dtype="float32")
+ img[..., channels] = parent.stack_filtered
+ parent.stack_filtered = img
+ else:
+ parent.stack_filtered = np.concatenate(
+ (parent.stack_filtered, np.zeros(
+ (*shape[:-1], 1), dtype="float32")), axis=-1)
+ elif shape[-1] > 3:
+ parent.stack_filtered = parent.stack_filtered[..., :3]
+
+ parent.restore = dat["restore"]
+ parent.ViewDropDown.model().item(parent.ViewDropDown.count() -
+ 1).setEnabled(True)
+ parent.view = parent.ViewDropDown.count() - 1
+ if parent.restore and "upsample" in parent.restore:
+ print(parent.stack_filtered.shape, image.shape)
+ parent.ratio = dat["ratio"]
+
+ parent.set_restore_button()
+
+ _initialize_images(parent, image, load_3D=load_3D)
+ print(parent.stack.shape)
+ if "chan_choose" in dat:
+ parent.ChannelChoose[0].setCurrentIndex(dat["chan_choose"][0])
+ parent.ChannelChoose[1].setCurrentIndex(dat["chan_choose"][1])
+
+ if "outlines" in dat:
+ if isinstance(dat["outlines"], list):
+ # old way of saving files
+ dat["outlines"] = dat["outlines"][::-1]
+ for k, outline in enumerate(dat["outlines"]):
+ if "colors" in dat:
+ color = dat["colors"][k]
+ else:
+ col_rand = np.random.randint(1000)
+ color = parent.colormap[col_rand, :3]
+ median = parent.add_mask(points=outline, color=color)
+ if median is not None:
+ parent.cellcolors = np.append(parent.cellcolors,
+ color[np.newaxis, :], axis=0)
+ parent.ncells += 1
+ else:
+ if dat["masks"].min() == -1:
+ dat["masks"] += 1
+ dat["outlines"] += 1
+ parent.ncells = dat["masks"].max()
+ if "colors" in dat and len(dat["colors"]) == dat["masks"].max():
+ colors = dat["colors"]
+ else:
+ colors = parent.colormap[:parent.ncells, :3]
+
+ _masks_to_gui(parent, dat["masks"], outlines=dat["outlines"], colors=colors)
+
+ parent.draw_layer()
+ if "est_diam" in dat:
+ parent.Diameter.setText("%0.1f" % dat["est_diam"])
+ parent.diameter = dat["est_diam"]
+ parent.compute_scale()
+
+ if "manual_changes" in dat:
+ parent.track_changes = dat["manual_changes"]
+ print("GUI_INFO: loaded in previous changes")
+ if "zdraw" in dat:
+ parent.zdraw = dat["zdraw"]
+ else:
+ parent.zdraw = [None for n in range(parent.ncells)]
+ parent.loaded = True
+ #print(f"GUI_INFO: {parent.ncells} masks found in {filename}")
+ else:
+ parent.clear_all()
+
+ parent.ismanual = np.zeros(parent.ncells, bool)
+ if "ismanual" in dat:
+ if len(dat["ismanual"]) == parent.ncells:
+ parent.ismanual = dat["ismanual"]
+
+ if "current_channel" in dat:
+ parent.color = (dat["current_channel"] + 2) % 5
+ parent.RGBDropDown.setCurrentIndex(parent.color)
+
+ if "flows" in dat:
+ parent.flows = dat["flows"]
+ try:
+ if parent.flows[0].shape[-3] != dat["masks"].shape[-2]:
+ Ly, Lx = dat["masks"].shape[-2:]
+ for i in range(len(parent.flows)):
+ parent.flows[i] = cv2.resize(
+ parent.flows[i].squeeze(), (Lx, Ly),
+ interpolation=cv2.INTER_NEAREST)[np.newaxis, ...]
+ if parent.NZ == 1:
+ parent.recompute_masks = True
+ else:
+ parent.recompute_masks = False
+
+ except:
+ try:
+ if len(parent.flows[0]) > 0:
+ parent.flows = parent.flows[0]
+ except:
+ parent.flows = [[], [], [], [], [[]]]
+ parent.recompute_masks = False
+
+ parent.enable_buttons()
+ parent.update_layer()
+ del dat
+ gc.collect()
+
+
+def _load_masks(parent, filename=None):
+ """ load zeros-based masks (0=no cell, 1=cell 1, ...) """
+ if filename is None:
+ name = QFileDialog.getOpenFileName(parent, "Load masks (PNG or TIFF)")
+ filename = name[0]
+ print(f"GUI_INFO: loading masks: {filename}")
+ masks = imread(filename)
+ outlines = None
+ if masks.ndim > 3:
+ # Z x nchannels x Ly x Lx
+ if masks.shape[-1] > 5:
+ parent.flows = list(np.transpose(masks[:, :, :, 2:], (3, 0, 1, 2)))
+ outlines = masks[..., 1]
+ masks = masks[..., 0]
+ else:
+ parent.flows = list(np.transpose(masks[:, :, :, 1:], (3, 0, 1, 2)))
+ masks = masks[..., 0]
+ elif masks.ndim == 3:
+ if masks.shape[-1] < 5:
+ masks = masks[np.newaxis, :, :, 0]
+ elif masks.ndim < 3:
+ masks = masks[np.newaxis, :, :]
+ # masks should be Z x Ly x Lx
+ if masks.shape[0] != parent.NZ:
+ print("ERROR: masks are not same depth (number of planes) as image stack")
+ return
+
+ _masks_to_gui(parent, masks, outlines)
+ if parent.ncells > 0:
+ parent.draw_layer()
+ parent.toggle_mask_ops()
+ del masks
+ gc.collect()
+ parent.update_layer()
+ parent.update_plot()
+
+
+def _masks_to_gui(parent, masks, outlines=None, colors=None):
+ """ masks loaded into GUI """
+ # get unique values
+ shape = masks.shape
+ if len(fastremap.unique(masks)) != masks.max() + 1:
+ print("GUI_INFO: renumbering masks")
+ fastremap.renumber(masks, in_place=True)
+ outlines = None
+ masks = masks.reshape(shape)
+ if masks.ndim == 2:
+ outlines = None
+ masks = masks.astype(np.uint16) if masks.max() < 2**16 - 1 else masks.astype(
+ np.uint32)
+ if parent.restore and "upsample" in parent.restore:
+ parent.cellpix_resize = masks.copy()
+ parent.cellpix = parent.cellpix_resize.copy()
+ parent.cellpix_orig = cv2.resize(
+ masks.squeeze(), (parent.Lx0, parent.Ly0),
+ interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :]
+ parent.resize = True
+ else:
+ parent.cellpix = masks
+ if parent.cellpix.ndim == 2:
+ parent.cellpix = parent.cellpix[np.newaxis, :, :]
+ if parent.restore and "upsample" in parent.restore:
+ if parent.cellpix_resize.ndim == 2:
+ parent.cellpix_resize = parent.cellpix_resize[np.newaxis, :, :]
+ if parent.cellpix_orig.ndim == 2:
+ parent.cellpix_orig = parent.cellpix_orig[np.newaxis, :, :]
+
+ print(f"GUI_INFO: {masks.max()} masks found")
+
+ # get outlines
+ if outlines is None: # parent.outlinesOn
+ parent.outpix = np.zeros_like(parent.cellpix)
+ if parent.restore and "upsample" in parent.restore:
+ parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
+ for z in range(parent.NZ):
+ outlines = masks_to_outlines(parent.cellpix[z])
+ parent.outpix[z] = outlines * parent.cellpix[z]
+ if parent.restore and "upsample" in parent.restore:
+ outlines = masks_to_outlines(parent.cellpix_orig[z])
+ parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
+ if z % 50 == 0 and parent.NZ > 1:
+ print("GUI_INFO: plane %d outlines processed" % z)
+ if parent.restore and "upsample" in parent.restore:
+ parent.outpix_resize = parent.outpix.copy()
+ else:
+ parent.outpix = outlines
+ if parent.restore and "upsample" in parent.restore:
+ parent.outpix_resize = parent.outpix.copy()
+ parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
+ for z in range(parent.NZ):
+ outlines = masks_to_outlines(parent.cellpix_orig[z])
+ parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
+ if z % 50 == 0 and parent.NZ > 1:
+ print("GUI_INFO: plane %d outlines processed" % z)
+
+ if parent.outpix.ndim == 2:
+ parent.outpix = parent.outpix[np.newaxis, :, :]
+ if parent.restore and "upsample" in parent.restore:
+ if parent.outpix_resize.ndim == 2:
+ parent.outpix_resize = parent.outpix_resize[np.newaxis, :, :]
+ if parent.outpix_orig.ndim == 2:
+ parent.outpix_orig = parent.outpix_orig[np.newaxis, :, :]
+
+ parent.ncells = parent.cellpix.max()
+ colors = parent.colormap[:parent.ncells, :3] if colors is None else colors
+ print("GUI_INFO: creating cellcolors and drawing masks")
+ parent.cellcolors = np.concatenate((np.array([[255, 255, 255]]), colors),
+ axis=0).astype(np.uint8)
+ if parent.ncells > 0:
+ parent.draw_layer()
+ parent.toggle_mask_ops()
+ parent.ismanual = np.zeros(parent.ncells, bool)
+ parent.zdraw = list(-1 * np.ones(parent.ncells, np.int16))
+
+ if hasattr(parent, "stack_filtered"):
+ parent.ViewDropDown.setCurrentIndex(parent.ViewDropDown.count() - 1)
+ print("set denoised/filtered view")
+ else:
+ parent.ViewDropDown.setCurrentIndex(0)
+
+
+def _save_png(parent):
+ """ save masks to png or tiff (if 3D) """
+ filename = parent.filename
+ base = os.path.splitext(filename)[0]
+ if parent.NZ == 1:
+ if parent.cellpix[0].max() > 65534:
+ print("GUI_INFO: saving 2D masks to tif (too many masks for PNG)")
+ imsave(base + "_cp_masks.tif", parent.cellpix[0])
+ else:
+ print("GUI_INFO: saving 2D masks to png")
+ imsave(base + "_cp_masks.png", parent.cellpix[0].astype(np.uint16))
+ else:
+ print("GUI_INFO: saving 3D masks to tiff")
+ imsave(base + "_cp_masks.tif", parent.cellpix)
+
+
+def _save_flows(parent):
+ """ save flows and cellprob to tiff """
+ filename = parent.filename
+ base = os.path.splitext(filename)[0]
+ print("GUI_INFO: saving flows and cellprob to tiff")
+ if len(parent.flows) > 0:
+ imsave(base + "_cp_cellprob.tif", parent.flows[1])
+ for i in range(3):
+ imsave(base + f"_cp_flows_{i}.tif", parent.flows[0][..., i])
+ if len(parent.flows) > 2:
+ imsave(base + "_cp_flows.tif", parent.flows[2])
+ print("GUI_INFO: saved flows and cellprob")
+ else:
+ print("ERROR: no flows or cellprob found")
+
+def _save_rois(parent):
+ """ save masks as rois in .zip file for ImageJ """
+ filename = parent.filename
+ if parent.NZ == 1:
+ print(
+ f"GUI_INFO: saving {parent.cellpix[0].max()} ImageJ ROIs to .zip archive.")
+ save_rois(parent.cellpix[0], parent.filename)
+ else:
+ print("ERROR: cannot save 3D outlines")
+
+
+def _save_outlines(parent):
+ filename = parent.filename
+ base = os.path.splitext(filename)[0]
+ if parent.NZ == 1:
+ print(
+ "GUI_INFO: saving 2D outlines to text file, see docs for info to load into ImageJ"
+ )
+ outlines = outlines_list(parent.cellpix[0])
+ outlines_to_text(base, outlines)
+ else:
+ print("ERROR: cannot save 3D outlines")
+
+
+def _save_sets_with_check(parent):
+ """ Save masks and update *_seg.npy file. Use this function when saving should be optional
+ based on the disableAutosave checkbox. Otherwise, use _save_sets """
+ if not parent.disableAutosave.isChecked():
+ _save_sets(parent)
+
+
+def _save_sets(parent):
+ """ save masks to *_seg.npy. This function should be used when saving
+ is forced, e.g. when clicking the save button. Otherwise, use _save_sets_with_check
+ """
+ filename = parent.filename
+ base = os.path.splitext(filename)[0]
+ flow_threshold, cellprob_threshold = parent.get_thresholds()
+ if parent.NZ > 1:
+ dat = {
+ "outlines":
+ parent.outpix,
+ "colors":
+ parent.cellcolors[1:],
+ "masks":
+ parent.cellpix,
+ "current_channel": (parent.color - 2) % 5,
+ "filename":
+ parent.filename,
+ "flows":
+ parent.flows,
+ "zdraw":
+ parent.zdraw,
+ "model_path":
+ parent.current_model_path
+ if hasattr(parent, "current_model_path") else 0,
+ "flow_threshold":
+ flow_threshold,
+ "cellprob_threshold":
+ cellprob_threshold,
+ "normalize_params":
+ parent.get_normalize_params(),
+ "restore":
+ parent.restore,
+ "ratio":
+ parent.ratio,
+ "diameter":
+ parent.diameter
+ }
+ if parent.restore is not None:
+ dat["img_restore"] = parent.stack_filtered
+ np.save(base + "_seg.npy", dat)
+ else:
+ dat = {
+ "outlines":
+ parent.outpix.squeeze() if parent.restore is None or
+ not "upsample" in parent.restore else parent.outpix_resize.squeeze(),
+ "colors":
+ parent.cellcolors[1:],
+ "masks":
+ parent.cellpix.squeeze() if parent.restore is None or
+ not "upsample" in parent.restore else parent.cellpix_resize.squeeze(),
+ "chan_choose": [
+ parent.ChannelChoose[0].currentIndex(),
+ parent.ChannelChoose[1].currentIndex()
+ ],
+ "filename":
+ parent.filename,
+ "flows":
+ parent.flows,
+ "ismanual":
+ parent.ismanual,
+ "manual_changes":
+ parent.track_changes,
+ "model_path":
+ parent.current_model_path
+ if hasattr(parent, "current_model_path") else 0,
+ "flow_threshold":
+ flow_threshold,
+ "cellprob_threshold":
+ cellprob_threshold,
+ "normalize_params":
+ parent.get_normalize_params(),
+ "restore":
+ parent.restore,
+ "ratio":
+ parent.ratio,
+ "diameter":
+ parent.diameter
+ }
+ if parent.restore is not None:
+ dat["img_restore"] = parent.stack_filtered
+ np.save(base + "_seg.npy", dat)
+ del dat
+ #print(parent.point_sets)
+ print("GUI_INFO: %d ROIs saved to %s" % (parent.ncells, base + "_seg.npy"))
diff --git a/cellpose/gui/make_train.py b/cellpose/gui/make_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..20cb9fc36381c1a133e860ca50db6d9c0cb54d78
--- /dev/null
+++ b/cellpose/gui/make_train.py
@@ -0,0 +1,104 @@
+import sys, os, argparse, glob, pathlib, time
+import numpy as np
+from natsort import natsorted
+from tqdm import tqdm
+from cellpose import utils, models, io, core, version_str, transforms
+
+
+def main():
+ parser = argparse.ArgumentParser(description='cellpose parameters')
+
+ input_img_args = parser.add_argument_group("input image arguments")
+ input_img_args.add_argument('--dir', default=[], type=str,
+ help='folder containing data to run or train on.')
+ input_img_args.add_argument(
+ '--image_path', default=[], type=str, help=
+ 'if given and --dir not given, run on single image instead of folder (cannot train with this option)'
+ )
+ input_img_args.add_argument(
+ '--look_one_level_down', action='store_true',
+ help='run processing on all subdirectories of current folder')
+ input_img_args.add_argument('--img_filter', default=[], type=str,
+ help='end string for images to run on')
+ input_img_args.add_argument(
+ '--channel_axis', default=None, type=int,
+ help='axis of image which corresponds to image channels')
+ input_img_args.add_argument('--z_axis', default=None, type=int,
+ help='axis of image which corresponds to Z dimension')
+ input_img_args.add_argument(
+ '--chan', default=0, type=int, help=
+ 'channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s')
+ input_img_args.add_argument(
+ '--chan2', default=0, type=int, help=
+ 'nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s'
+ )
+ input_img_args.add_argument('--invert', action='store_true',
+ help='invert grayscale channel')
+ input_img_args.add_argument(
+ '--all_channels', action='store_true', help=
+ 'use all channels in image if using own model and images with special channels')
+ input_img_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
+ help="anisotropy of volume in 3D")
+
+
+ # algorithm settings
+ algorithm_args = parser.add_argument_group("algorithm arguments")
+ algorithm_args.add_argument('--sharpen_radius', required=False, default=0.0,
+ type=float, help='high-pass filtering radius. Default: %(default)s')
+ algorithm_args.add_argument('--tile_norm', required=False, default=0, type=int,
+ help='tile normalization block size. Default: %(default)s')
+ algorithm_args.add_argument('--nimg_per_tif', required=False, default=10, type=int,
+ help='number of crops in XY to save per tiff. Default: %(default)s')
+ algorithm_args.add_argument('--crop_size', required=False, default=512, type=int,
+ help='size of random crop to save. Default: %(default)s')
+
+ args = parser.parse_args()
+
+ # find images
+ if len(args.img_filter) > 0:
+ imf = args.img_filter
+ else:
+ imf = None
+
+ if len(args.dir) > 0:
+ image_names = io.get_image_files(args.dir, "_masks", imf=imf,
+ look_one_level_down=args.look_one_level_down)
+ dirname = args.dir
+ else:
+ if os.path.exists(args.image_path):
+ image_names = [args.image_path]
+ dirname = os.path.split(args.image_path)[0]
+ else:
+ raise ValueError(f"ERROR: no file found at {args.image_path}")
+
+ np.random.seed(0)
+ nimg_per_tif = args.nimg_per_tif
+ crop_size = args.crop_size
+ os.makedirs(os.path.join(dirname, 'train/'), exist_ok=True)
+ pm = [(0, 1, 2, 3), (2, 0, 1, 3), (1, 0, 2, 3)]
+ npm = ["YX", "ZY", "ZX"]
+ for name in image_names:
+ name0 = os.path.splitext(os.path.split(name)[-1])[0]
+ img0 = io.imread(name)
+ img0 = transforms.convert_image(img0, channels=[args.chan, args.chan2], channel_axis=args.channel_axis, z_axis=args.z_axis)
+ for p in range(3):
+ img = img0.transpose(pm[p]).copy()
+ print(npm[p], img[0].shape)
+ Ly, Lx = img.shape[1:3]
+ imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]]
+ if args.anisotropy > 1.0 and p > 0:
+ imgs = transforms.resize_image(imgs, Ly=int(args.anisotropy * Ly), Lx=Lx)
+ for k, img in enumerate(imgs):
+ if args.tile_norm:
+ img = transforms.normalize99_tile(img, blocksize=args.tile_norm)
+ if args.sharpen_radius:
+ img = transforms.smooth_sharpen_img(img,
+ sharpen_radius=args.sharpen_radius)
+ ly = 0 if Ly - crop_size <= 0 else np.random.randint(0, Ly - crop_size)
+ lx = 0 if Lx - crop_size <= 0 else np.random.randint(0, Lx - crop_size)
+ io.imsave(os.path.join(dirname, f'train/{name0}_{npm[p]}_{k}.tif'),
+ img[ly:ly + args.crop_size, lx:lx + args.crop_size].squeeze())
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cellpose/gui/menus.py b/cellpose/gui/menus.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d7994fcbd35e21596c7f59cea5ceeabfd7ffe43
--- /dev/null
+++ b/cellpose/gui/menus.py
@@ -0,0 +1,148 @@
+"""
+Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import qtpy
+from qtpy.QtWidgets import QAction
+from . import io
+from .. import models
+
+
+def mainmenu(parent):
+ main_menu = parent.menuBar()
+ file_menu = main_menu.addMenu("&File")
+ # load processed data
+ loadImg = QAction("&Load image (*.tif, *.png, *.jpg)", parent)
+ loadImg.setShortcut("Ctrl+L")
+ loadImg.triggered.connect(lambda: io._load_image(parent))
+ file_menu.addAction(loadImg)
+
+ parent.autoloadMasks = QAction("Autoload masks from _masks.tif file", parent,
+ checkable=True)
+ parent.autoloadMasks.setChecked(False)
+ file_menu.addAction(parent.autoloadMasks)
+
+ parent.disableAutosave = QAction("Disable autosave _seg.npy file", parent,
+ checkable=True)
+ parent.disableAutosave.setChecked(False)
+ file_menu.addAction(parent.disableAutosave)
+
+ parent.loadMasks = QAction("Load &masks (*.tif, *.png, *.jpg)", parent)
+ parent.loadMasks.setShortcut("Ctrl+M")
+ parent.loadMasks.triggered.connect(lambda: io._load_masks(parent))
+ file_menu.addAction(parent.loadMasks)
+ parent.loadMasks.setEnabled(False)
+
+ loadManual = QAction("Load &processed/labelled image (*_seg.npy)", parent)
+ loadManual.setShortcut("Ctrl+P")
+ loadManual.triggered.connect(lambda: io._load_seg(parent))
+ file_menu.addAction(loadManual)
+
+ parent.saveSet = QAction("&Save masks and image (as *_seg.npy)", parent)
+ parent.saveSet.setShortcut("Ctrl+S")
+ parent.saveSet.triggered.connect(lambda: io._save_sets(parent))
+ file_menu.addAction(parent.saveSet)
+ parent.saveSet.setEnabled(False)
+
+ parent.savePNG = QAction("Save masks as P&NG/tif", parent)
+ parent.savePNG.setShortcut("Ctrl+N")
+ parent.savePNG.triggered.connect(lambda: io._save_png(parent))
+ file_menu.addAction(parent.savePNG)
+ parent.savePNG.setEnabled(False)
+
+ parent.saveOutlines = QAction("Save &Outlines as text for imageJ", parent)
+ parent.saveOutlines.setShortcut("Ctrl+O")
+ parent.saveOutlines.triggered.connect(lambda: io._save_outlines(parent))
+ file_menu.addAction(parent.saveOutlines)
+ parent.saveOutlines.setEnabled(False)
+
+ parent.saveROIs = QAction("Save outlines as .zip archive of &ROI files for ImageJ",
+ parent)
+ parent.saveROIs.setShortcut("Ctrl+R")
+ parent.saveROIs.triggered.connect(lambda: io._save_rois(parent))
+ file_menu.addAction(parent.saveROIs)
+ parent.saveROIs.setEnabled(False)
+
+ parent.saveFlows = QAction("Save &Flows and cellprob as tif", parent)
+ parent.saveFlows.setShortcut("Ctrl+F")
+ parent.saveFlows.triggered.connect(lambda: io._save_flows(parent))
+ file_menu.addAction(parent.saveFlows)
+ parent.saveFlows.setEnabled(False)
+
+
+def editmenu(parent):
+ main_menu = parent.menuBar()
+ edit_menu = main_menu.addMenu("&Edit")
+ parent.undo = QAction("Undo previous mask/trace", parent)
+ parent.undo.setShortcut("Ctrl+Z")
+ parent.undo.triggered.connect(parent.undo_action)
+ parent.undo.setEnabled(False)
+ edit_menu.addAction(parent.undo)
+
+ parent.redo = QAction("Undo remove mask", parent)
+ parent.redo.setShortcut("Ctrl+Y")
+ parent.redo.triggered.connect(parent.undo_remove_action)
+ parent.redo.setEnabled(False)
+ edit_menu.addAction(parent.redo)
+
+ parent.ClearButton = QAction("Clear all masks", parent)
+ parent.ClearButton.setShortcut("Ctrl+0")
+ parent.ClearButton.triggered.connect(parent.clear_all)
+ parent.ClearButton.setEnabled(False)
+ edit_menu.addAction(parent.ClearButton)
+
+ parent.remcell = QAction("Remove selected cell (Ctrl+CLICK)", parent)
+ parent.remcell.setShortcut("Ctrl+Click")
+ parent.remcell.triggered.connect(parent.remove_action)
+ parent.remcell.setEnabled(False)
+ edit_menu.addAction(parent.remcell)
+
+ parent.mergecell = QAction("FYI: Merge cells by Alt+Click", parent)
+ parent.mergecell.setEnabled(False)
+ edit_menu.addAction(parent.mergecell)
+
+
+def modelmenu(parent):
+ main_menu = parent.menuBar()
+ io._init_model_list(parent)
+ model_menu = main_menu.addMenu("&Models")
+ parent.addmodel = QAction("Add custom torch model to GUI", parent)
+ #parent.addmodel.setShortcut("Ctrl+A")
+ parent.addmodel.triggered.connect(parent.add_model)
+ parent.addmodel.setEnabled(True)
+ model_menu.addAction(parent.addmodel)
+
+ parent.removemodel = QAction("Remove selected custom model from GUI", parent)
+ #parent.removemodel.setShortcut("Ctrl+R")
+ parent.removemodel.triggered.connect(parent.remove_model)
+ parent.removemodel.setEnabled(True)
+ model_menu.addAction(parent.removemodel)
+
+ parent.newmodel = QAction("&Train new model with image+masks in folder", parent)
+ parent.newmodel.setShortcut("Ctrl+T")
+ parent.newmodel.triggered.connect(parent.new_model)
+ parent.newmodel.setEnabled(False)
+ model_menu.addAction(parent.newmodel)
+
+ openTrainHelp = QAction("Training instructions", parent)
+ openTrainHelp.triggered.connect(parent.train_help_window)
+ model_menu.addAction(openTrainHelp)
+
+
+def helpmenu(parent):
+ main_menu = parent.menuBar()
+ help_menu = main_menu.addMenu("&Help")
+
+ openHelp = QAction("&Help with GUI", parent)
+ openHelp.setShortcut("Ctrl+H")
+ openHelp.triggered.connect(parent.help_window)
+ help_menu.addAction(openHelp)
+
+ openGUI = QAction("&GUI layout", parent)
+ openGUI.setShortcut("Ctrl+G")
+ openGUI.triggered.connect(parent.gui_window)
+ help_menu.addAction(openGUI)
+
+ openTrainHelp = QAction("Training instructions", parent)
+ openTrainHelp.triggered.connect(parent.train_help_window)
+ help_menu.addAction(openTrainHelp)
diff --git a/cellpose/io.py b/cellpose/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5f30f5cee3a34cded07a2a14be66ae7813829d3
--- /dev/null
+++ b/cellpose/io.py
@@ -0,0 +1,756 @@
+"""
+Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import os, datetime, gc, warnings, glob, shutil
+from natsort import natsorted
+import numpy as np
+import cv2
+import tifffile
+import logging, pathlib, sys
+from tqdm import tqdm
+from pathlib import Path
+import re
+from . import version_str
+from roifile import ImagejRoi, roiwrite
+
+try:
+ from qtpy import QtGui, QtCore, Qt, QtWidgets
+ from qtpy.QtWidgets import QMessageBox
+ GUI = True
+except:
+ GUI = False
+
+try:
+ import matplotlib.pyplot as plt
+ MATPLOTLIB = True
+except:
+ MATPLOTLIB = False
+
+try:
+ import nd2
+ ND2 = True
+except:
+ ND2 = False
+
+try:
+ import nrrd
+ NRRD = True
+except:
+ NRRD = False
+
+try:
+ from google.cloud import storage
+ SERVER_UPLOAD = True
+except:
+ SERVER_UPLOAD = False
+
+io_logger = logging.getLogger(__name__)
+
+def logger_setup(cp_path=".cellpose", logfile_name="run.log", stdout_file_replacement=None):
+ cp_dir = pathlib.Path.home().joinpath(cp_path)
+ cp_dir.mkdir(exist_ok=True)
+ log_file = cp_dir.joinpath(logfile_name)
+ try:
+ log_file.unlink()
+ except:
+ print('creating new log file')
+ handlers = [logging.FileHandler(log_file),]
+ if stdout_file_replacement is not None:
+ handlers.append(logging.FileHandler(stdout_file_replacement))
+ else:
+ handlers.append(logging.StreamHandler(sys.stdout))
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(levelname)s] %(message)s",
+ handlers=handlers,
+ )
+ logger = logging.getLogger(__name__)
+ logger.info(f"WRITING LOG OUTPUT TO {log_file}")
+ logger.info(version_str)
+ #logger.handlers[1].stream = sys.stdout
+
+ return logger, log_file
+
+
+from . import utils, plot, transforms
+
+# helper function to check for a path; if it doesn't exist, make it
+def check_dir(path):
+ if not os.path.isdir(path):
+ os.mkdir(path)
+
+
+def outlines_to_text(base, outlines):
+ with open(base + "_cp_outlines.txt", "w") as f:
+ for o in outlines:
+ xy = list(o.flatten())
+ xy_str = ",".join(map(str, xy))
+ f.write(xy_str)
+ f.write("\n")
+
+
+def load_dax(filename):
+ ### modified from ZhuangLab github:
+ ### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156
+
+ inf_filename = os.path.splitext(filename)[0] + ".inf"
+ if not os.path.exists(inf_filename):
+ io_logger.critical(
+ f"ERROR: no inf file found for dax file {filename}, cannot load dax without it"
+ )
+ return None
+
+ ### get metadata
+ image_height, image_width = None, None
+ # extract the movie information from the associated inf file
+ size_re = re.compile(r"frame dimensions = ([\d]+) x ([\d]+)")
+ length_re = re.compile(r"number of frames = ([\d]+)")
+ endian_re = re.compile(r" (big|little) endian")
+
+ with open(inf_filename, "r") as inf_file:
+ lines = inf_file.read().split("\n")
+ for line in lines:
+ m = size_re.match(line)
+ if m:
+ image_height = int(m.group(2))
+ image_width = int(m.group(1))
+ m = length_re.match(line)
+ if m:
+ number_frames = int(m.group(1))
+ m = endian_re.search(line)
+ if m:
+ if m.group(1) == "big":
+ bigendian = 1
+ else:
+ bigendian = 0
+ # set defaults, warn the user that they couldn"t be determined from the inf file.
+ if not image_height:
+ io_logger.warning("could not determine dax image size, assuming 256x256")
+ image_height = 256
+ image_width = 256
+
+ ### load image
+ img = np.memmap(filename, dtype="uint16",
+ shape=(number_frames, image_height, image_width))
+ if bigendian:
+ img = img.byteswap()
+ img = np.array(img)
+
+ return img
+
+
+def imread(filename):
+ """
+ Read in an image file with tif or image file type supported by cv2.
+
+ Args:
+ filename (str): The path to the image file.
+
+ Returns:
+ numpy.ndarray: The image data as a NumPy array.
+
+ Raises:
+ None
+
+ Raises an error if the image file format is not supported.
+
+ Examples:
+ >>> img = imread("image.tif")
+ """
+ # ensure that extension check is not case sensitive
+ ext = os.path.splitext(filename)[-1].lower()
+ if ext == ".tif" or ext == ".tiff" or ext == ".flex":
+ with tifffile.TiffFile(filename) as tif:
+ ltif = len(tif.pages)
+ try:
+ full_shape = tif.shaped_metadata[0]["shape"]
+ except:
+ try:
+ page = tif.series[0][0]
+ full_shape = tif.series[0].shape
+ except:
+ ltif = 0
+ if ltif < 10:
+ img = tif.asarray()
+ else:
+ page = tif.series[0][0]
+ shape, dtype = page.shape, page.dtype
+ ltif = int(np.prod(full_shape) / np.prod(shape))
+ io_logger.info(f"reading tiff with {ltif} planes")
+ img = np.zeros((ltif, *shape), dtype=dtype)
+ for i, page in enumerate(tqdm(tif.series[0])):
+ img[i] = page.asarray()
+ img = img.reshape(full_shape)
+ return img
+ elif ext == ".dax":
+ img = load_dax(filename)
+ return img
+    elif ext == ".nd2":
+        if not ND2:
+            io_logger.critical("ERROR: need to 'pip install nd2' to load in .nd2 file")
+            return None
+        else:
+            # read with the nd2 package (imported above when available)
+            img = nd2.imread(filename)
+            return img
+ elif ext == ".nrrd":
+ if not NRRD:
+ io_logger.critical(
+ "ERROR: need to 'pip install pynrrd' to load in .nrrd file")
+ return None
+ else:
+ img, metadata = nrrd.read(filename)
+ if img.ndim == 3:
+ img = img.transpose(2, 0, 1)
+ return img
+ elif ext != ".npy":
+ try:
+ img = cv2.imread(filename, -1) #cv2.LOAD_IMAGE_ANYDEPTH)
+ if img.ndim > 2:
+ img = img[..., [2, 1, 0]]
+ return img
+ except Exception as e:
+ io_logger.critical("ERROR: could not read file, %s" % e)
+ return None
+ else:
+ try:
+ dat = np.load(filename, allow_pickle=True).item()
+ masks = dat["masks"]
+ return masks
+ except Exception as e:
+ io_logger.critical("ERROR: could not read masks from file, %s" % e)
+ return None
+
+
+def remove_model(filename, delete=False):
+    """ remove model from .cellpose custom model list """
+    filename = os.path.split(filename)[-1]
+    from . import models
+    model_strings = models.get_user_models()
+    if filename in model_strings:
+        model_strings.remove(filename)
+    # rewrite the list without the removed model (the file is emptied if none remain)
+    with open(models.MODEL_LIST_PATH, "w") as textfile:
+        for fname in model_strings:
+            textfile.write(fname + "\n")
+    print(f"{filename} removed from custom model list")
+    if delete:
+        os.remove(os.fspath(models.MODEL_DIR.joinpath(filename)))
+        print("model deleted")
+
+
+def add_model(filename):
+ """ add model to .cellpose models folder to use with GUI or CLI """
+ from . import models
+ fname = os.path.split(filename)[-1]
+ try:
+ shutil.copyfile(filename, os.fspath(models.MODEL_DIR.joinpath(fname)))
+ except shutil.SameFileError:
+ pass
+ print(f"{filename} copied to models folder {os.fspath(models.MODEL_DIR)}")
+ if fname not in models.get_user_models():
+ with open(models.MODEL_LIST_PATH, "a") as textfile:
+ textfile.write(fname + "\n")
+
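+# Illustrative usage sketch (hypothetical path): register a custom model so the GUI/CLI
+# can find it, and later unregister (and optionally delete) it again:
+#   add_model("/path/to/my_custom_model")
+#   remove_model("my_custom_model", delete=True)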
+
+def imsave(filename, arr):
+ """
+ Saves an image array to a file.
+
+ Args:
+ filename (str): The name of the file to save the image to.
+ arr (numpy.ndarray): The image array to be saved.
+
+ Returns:
+ None
+ """
+ ext = os.path.splitext(filename)[-1].lower()
+ if ext == ".tif" or ext == ".tiff":
+ tifffile.imwrite(filename, arr)
+ else:
+ if len(arr.shape) > 2:
+ arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
+ cv2.imwrite(filename, arr)
+
+
+def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):
+ """
+ Finds all images in a folder and its subfolders (if specified) with the given file extensions.
+
+ Args:
+ folder (str): The path to the folder to search for images.
+ mask_filter (str): The filter for mask files.
+ imf (str, optional): The additional filter for image files. Defaults to None.
+ look_one_level_down (bool, optional): Whether to search for images in subfolders. Defaults to False.
+
+ Returns:
+ list: A list of image file paths.
+
+ Raises:
+ ValueError: If no files are found in the specified folder.
+ ValueError: If no images are found in the specified folder with the supported file extensions.
+ ValueError: If no images are found in the specified folder without the mask or flow file endings.
+ """
+ mask_filters = ["_cp_output", "_flows", "_flows_0", "_flows_1",
+ "_flows_2", "_cellprob", "_masks", mask_filter]
+ image_names = []
+ if imf is None:
+ imf = ""
+
+ folders = []
+ if look_one_level_down:
+ folders = natsorted(glob.glob(os.path.join(folder, "*/")))
+ folders.append(folder)
+ exts = [".png", ".jpg", ".jpeg", ".tif", ".tiff", ".flex", ".dax", ".nd2", ".nrrd"]
+ l0 = 0
+ al = 0
+ for folder in folders:
+ all_files = glob.glob(folder + "/*")
+ al += len(all_files)
+ for ext in exts:
+ image_names.extend(glob.glob(folder + f"/*{imf}{ext}"))
+ image_names.extend(glob.glob(folder + f"/*{imf}{ext.upper()}"))
+ l0 += len(image_names)
+
+ # return error if no files found
+ if al == 0:
+ raise ValueError("ERROR: no files in --dir folder ")
+ elif l0 == 0:
+ raise ValueError(
+            "ERROR: no images in --dir folder with extensions .png, .jpg, .jpeg, .tif, .tiff, .flex, .dax, .nd2, .nrrd"
+ )
+
+ image_names = natsorted(image_names)
+ imn = []
+ for im in image_names:
+ imfile = os.path.splitext(im)[0]
+ igood = all([(len(imfile) > len(mask_filter) and
+ imfile[-len(mask_filter):] != mask_filter) or
+ len(imfile) <= len(mask_filter) for mask_filter in mask_filters])
+ if len(imf) > 0:
+ igood &= imfile[-len(imf):] == imf
+ if igood:
+ imn.append(im)
+
+ image_names = imn
+
+ # remove duplicates
+ image_names = [*set(image_names)]
+ image_names = natsorted(image_names)
+
+ if len(image_names) == 0:
+ raise ValueError(
+ "ERROR: no images in --dir folder without _masks or _flows or _cellprob ending")
+
+ return image_names
+
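+# Illustrative usage sketch (hypothetical folder): collect all training images while
+# skipping files ending in "_masks", "_flows", etc.:
+#   image_names = get_image_files("/data/train", "_masks", look_one_level_down=True)
+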
+def get_label_files(image_names, mask_filter, imf=None):
+ """
+ Get the label files corresponding to the given image names and mask filter.
+
+ Args:
+ image_names (list): List of image names.
+ mask_filter (str): Mask filter to be applied.
+ imf (str, optional): Image file extension. Defaults to None.
+
+ Returns:
+ tuple: A tuple containing the label file names and flow file names (if present).
+ """
+ nimg = len(image_names)
+ label_names0 = [os.path.splitext(image_names[n])[0] for n in range(nimg)]
+
+ if imf is not None and len(imf) > 0:
+ label_names = [label_names0[n][:-len(imf)] for n in range(nimg)]
+ else:
+ label_names = label_names0
+
+ # check for flows
+ if os.path.exists(label_names0[0] + "_flows.tif"):
+ flow_names = [label_names0[n] + "_flows.tif" for n in range(nimg)]
+ else:
+ flow_names = [label_names[n] + "_flows.tif" for n in range(nimg)]
+ if not all([os.path.exists(flow) for flow in flow_names]):
+ io_logger.info(
+ "not all flows are present, running flow generation for all images")
+ flow_names = None
+
+ # check for masks
+ if mask_filter == "_seg.npy":
+ label_names = [label_names[n] + mask_filter for n in range(nimg)]
+ return label_names, None
+
+ if os.path.exists(label_names[0] + mask_filter + ".tif"):
+ label_names = [label_names[n] + mask_filter + ".tif" for n in range(nimg)]
+ elif os.path.exists(label_names[0] + mask_filter + ".tiff"):
+ label_names = [label_names[n] + mask_filter + ".tiff" for n in range(nimg)]
+ elif os.path.exists(label_names[0] + mask_filter + ".png"):
+ label_names = [label_names[n] + mask_filter + ".png" for n in range(nimg)]
+ # todo, allow _seg.npy
+ #elif os.path.exists(label_names[0] + "_seg.npy"):
+ # io_logger.info("labels found as _seg.npy files, converting to tif")
+ else:
+ if not flow_names:
+ raise ValueError("labels not provided with correct --mask_filter")
+ else:
+ label_names = None
+    if label_names is not None and not all([os.path.exists(label) for label in label_names]):
+ if not flow_names:
+ raise ValueError(
+ "labels not provided for all images in train and/or test set")
+ else:
+ label_names = None
+
+ return label_names, flow_names
+
+
+def load_images_labels(tdir, mask_filter="_masks", image_filter=None,
+ look_one_level_down=False):
+ """
+ Loads images and corresponding labels from a directory.
+
+ Args:
+ tdir (str): The directory path.
+ mask_filter (str, optional): The filter for mask files. Defaults to "_masks".
+ image_filter (str, optional): The filter for image files. Defaults to None.
+ look_one_level_down (bool, optional): Whether to look for files one level down. Defaults to False.
+
+ Returns:
+ tuple: A tuple containing a list of images, a list of labels, and a list of image names.
+ """
+ image_names = get_image_files(tdir, mask_filter, image_filter, look_one_level_down)
+ nimg = len(image_names)
+
+ # training data
+ label_names, flow_names = get_label_files(image_names, mask_filter,
+ imf=image_filter)
+
+ images = []
+ labels = []
+ k = 0
+ for n in range(nimg):
+ if (os.path.isfile(label_names[n]) or
+ (flow_names is not None and os.path.isfile(flow_names[0]))):
+ image = imread(image_names[n])
+ if label_names is not None:
+ label = imread(label_names[n])
+ if flow_names is not None:
+ flow = imread(flow_names[n])
+ if flow.shape[0] < 4:
+ label = np.concatenate((label[np.newaxis, :, :], flow), axis=0)
+ else:
+ label = flow
+ images.append(image)
+ labels.append(label)
+ k += 1
+ io_logger.info(f"{k} / {nimg} images in {tdir} folder have labels")
+ return images, labels, image_names
+
+def load_train_test_data(train_dir, test_dir=None, image_filter=None,
+ mask_filter="_masks", look_one_level_down=False):
+ """
+ Loads training and testing data for a Cellpose model.
+
+ Args:
+ train_dir (str): The directory path containing the training data.
+ test_dir (str, optional): The directory path containing the testing data. Defaults to None.
+ image_filter (str, optional): The filter for selecting image files. Defaults to None.
+ mask_filter (str, optional): The filter for selecting mask files. Defaults to "_masks".
+ look_one_level_down (bool, optional): Whether to look for data in subdirectories of train_dir and test_dir. Defaults to False.
+
+ Returns:
+ images, labels, image_names, test_images, test_labels, test_image_names
+
+ """
+ images, labels, image_names = load_images_labels(train_dir, mask_filter,
+ image_filter, look_one_level_down)
+ # testing data
+ test_images, test_labels, test_image_names = None, None, None
+ if test_dir is not None:
+ test_images, test_labels, test_image_names = load_images_labels(
+ test_dir, mask_filter, image_filter, look_one_level_down)
+
+ return images, labels, image_names, test_images, test_labels, test_image_names
+
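+# Illustrative usage sketch (hypothetical folders): load images/labels for training and
+# an optional held-out test set:
+#   out = load_train_test_data("/data/train", test_dir="/data/test",
+#                              image_filter="_img", mask_filter="_masks")
+#   images, labels, image_names, test_images, test_labels, test_image_names = out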
+
+def masks_flows_to_seg(images, masks, flows, file_names, diams=30., channels=None,
+ imgs_restore=None, restore_type=None, ratio=1.):
+ """Save output of model eval to be loaded in GUI.
+
+ Can be list output (run on multiple images) or single output (run on single image).
+
+ Saved to file_names[k]+"_seg.npy".
+
+ Args:
+ images (list): Images input into cellpose.
+ masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
+ flows (list): Flows output from Cellpose.eval.
+ file_names (list, str): Names of files of images.
+        diams (float or np.ndarray, optional): Diameters used to run Cellpose. Defaults to 30.
+        channels (list, int, optional): Channels used to run Cellpose. Defaults to None.
+        imgs_restore (list or np.ndarray, optional): Restored/denoised versions of the images, stored in the output for the GUI. Defaults to None.
+        restore_type (str, optional): Name of the restoration applied to produce imgs_restore. Defaults to None.
+        ratio (float, optional): Scaling ratio associated with the restored images, stored in the output. Defaults to 1.
+
+ Returns:
+ None
+ """
+
+ if channels is None:
+ channels = [0, 0]
+
+ if isinstance(masks, list):
+ if not isinstance(diams, (list, np.ndarray)):
+ diams = diams * np.ones(len(masks), np.float32)
+ if imgs_restore is None:
+ imgs_restore = [None] * len(masks)
+ if isinstance(file_names, str):
+ file_names = [file_names] * len(masks)
+ for k, [image, mask, flow, diam, file_name, img_restore
+ ] in enumerate(zip(images, masks, flows, diams, file_names,
+ imgs_restore)):
+ channels_img = channels
+ if channels_img is not None and len(channels) > 2:
+ channels_img = channels[k]
+ masks_flows_to_seg(image, mask, flow, file_name, diams=diam,
+ channels=channels_img, imgs_restore=img_restore,
+ restore_type=restore_type, ratio=ratio)
+ return
+
+ if len(channels) == 1:
+ channels = channels[0]
+
+ flowi = []
+ if flows[0].ndim == 3:
+ Ly, Lx = masks.shape[-2:]
+ flowi.append(
+ cv2.resize(flows[0], (Lx, Ly), interpolation=cv2.INTER_NEAREST)[np.newaxis,
+ ...])
+ else:
+ flowi.append(flows[0])
+
+ if flows[0].ndim == 3:
+ cellprob = (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(
+ np.uint8)
+ cellprob = cv2.resize(cellprob, (Lx, Ly), interpolation=cv2.INTER_NEAREST)
+ flowi.append(cellprob[np.newaxis, ...])
+ flowi.append(np.zeros(flows[0].shape, dtype=np.uint8))
+ flowi[-1] = flowi[-1][np.newaxis, ...]
+ else:
+ flowi.append(
+ (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(np.uint8))
+ flowi.append((flows[1][0] / 10 * 127 + 127).astype(np.uint8))
+ if len(flows) > 2:
+ if len(flows) > 3:
+ flowi.append(flows[3])
+ else:
+ flowi.append([])
+ flowi.append(np.concatenate((flows[1], flows[2][np.newaxis, ...]), axis=0))
+ outlines = masks * utils.masks_to_outlines(masks)
+ base = os.path.splitext(file_names)[0]
+
+ dat = {
+ "outlines":
+ outlines.astype(np.uint16) if outlines.max() < 2**16 -
+ 1 else outlines.astype(np.uint32),
+ "masks":
+ masks.astype(np.uint16) if outlines.max() < 2**16 -
+ 1 else masks.astype(np.uint32),
+ "chan_choose":
+ channels,
+ "ismanual":
+ np.zeros(masks.max(), bool),
+ "filename":
+ file_names,
+ "flows":
+ flowi,
+ "diameter":
+ diams
+ }
+ if restore_type is not None and imgs_restore is not None:
+ dat["restore"] = restore_type
+ dat["ratio"] = ratio
+ dat["img_restore"] = imgs_restore
+
+ np.save(base + "_seg.npy", dat)
+
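+# Illustrative usage sketch: persist eval results so they can be reopened in the GUI
+# (assumes `model` is a Cellpose instance and `imgs`/`file_names` are matching lists):
+#   masks, flows, styles, diams = model.eval(imgs, channels=[0, 0])
+#   masks_flows_to_seg(imgs, masks, flows, file_names, diams=diams, channels=[0, 0])
+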
+def save_to_png(images, masks, flows, file_names):
+ """ deprecated (runs io.save_masks with png=True)
+
+ does not work for 3D images
+
+ """
+ save_masks(images, masks, flows, file_names, png=True)
+
+
+def save_rois(masks, file_name, multiprocessing=None):
+ """ save masks to .roi files in .zip archive for ImageJ/Fiji
+
+ Args:
+ masks (np.ndarray): masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels
+ file_name (str): name to save the .zip file to
+
+ Returns:
+ None
+ """
+ outlines = utils.outlines_list(masks, multiprocessing=multiprocessing)
+ nonempty_outlines = [outline for outline in outlines if len(outline)!=0]
+ if len(outlines)!=len(nonempty_outlines):
+ print(f"empty outlines found, saving {len(nonempty_outlines)} ImageJ ROIs to .zip archive.")
+ rois = [ImagejRoi.frompoints(outline) for outline in nonempty_outlines]
+ file_name = os.path.splitext(file_name)[0] + '_rois.zip'
+
+
+ # Delete file if it exists; the roifile lib appends to existing zip files.
+ # If the user removed a mask it will still be in the zip file
+ if os.path.exists(file_name):
+ os.remove(file_name)
+
+ roiwrite(file_name, rois)
+
+
+def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[0, 0],
+ suffix="_cp_masks", save_flows=False, save_outlines=False, dir_above=False,
+ in_folders=False, savedir=None, save_txt=False, save_mpl=False):
+ """ Save masks + nicely plotted segmentation image to png and/or tiff.
+
+ Can save masks, flows to different directories, if in_folders is True.
+
+ If png, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.png".
+
+ If tif, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.tif".
+
+ If png and matplotlib installed, full segmentation figure is saved to file_names[k]+"_cp.png".
+
+ Only tif option works for 3D data, and only tif option works for empty masks.
+
+ Args:
+ images (list): Images input into cellpose.
+ masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
+ flows (list): Flows output from Cellpose.eval.
+ file_names (list, str): Names of files of images.
+ png (bool, optional): Save masks to PNG. Defaults to True.
+ tif (bool, optional): Save masks to TIF. Defaults to False.
+ channels (list, int, optional): Channels used to run Cellpose. Defaults to [0,0].
+ suffix (str, optional): Add name to saved masks. Defaults to "_cp_masks".
+ save_flows (bool, optional): Save flows output from Cellpose.eval. Defaults to False.
+ save_outlines (bool, optional): Save outlines of masks. Defaults to False.
+ dir_above (bool, optional): Save masks/flows in directory above. Defaults to False.
+ in_folders (bool, optional): Save masks/flows in separate folders. Defaults to False.
+ savedir (str, optional): Absolute path where images will be saved. If None, saves to image directory. Defaults to None.
+ save_txt (bool, optional): Save masks as list of outlines for ImageJ. Defaults to False.
+ save_mpl (bool, optional): If True, saves a matplotlib figure of the original image/segmentation/flows. Does not work for 3D.
+ This takes a long time for large images. Defaults to False.
+
+ Returns:
+ None
+ """
+
+ if isinstance(masks, list):
+ for image, mask, flow, file_name in zip(images, masks, flows, file_names):
+ save_masks(image, mask, flow, file_name, png=png, tif=tif, suffix=suffix,
+ dir_above=dir_above, save_flows=save_flows,
+ save_outlines=save_outlines, savedir=savedir, save_txt=save_txt,
+ in_folders=in_folders, save_mpl=save_mpl)
+ return
+
+ if masks.ndim > 2 and not tif:
+ raise ValueError("cannot save 3D outputs as PNG, use tif option instead")
+
+ if masks.max() == 0:
+ io_logger.warning("no masks found, will not save PNG or outlines")
+ if not tif:
+ return
+ else:
+ png = False
+ save_outlines = False
+ save_flows = False
+ save_txt = False
+
+ if savedir is None:
+ if dir_above:
+ savedir = Path(file_names).parent.parent.absolute(
+ ) #go up a level to save in its own folder
+ else:
+ savedir = Path(file_names).parent.absolute()
+
+ check_dir(savedir)
+
+ basename = os.path.splitext(os.path.basename(file_names))[0]
+ if in_folders:
+ maskdir = os.path.join(savedir, "masks")
+ outlinedir = os.path.join(savedir, "outlines")
+ txtdir = os.path.join(savedir, "txt_outlines")
+ flowdir = os.path.join(savedir, "flows")
+ else:
+ maskdir = savedir
+ outlinedir = savedir
+ txtdir = savedir
+ flowdir = savedir
+
+ check_dir(maskdir)
+
+ exts = []
+ if masks.ndim > 2:
+ png = False
+ tif = True
+ if png:
+ if masks.max() < 2**16:
+ masks = masks.astype(np.uint16)
+ exts.append(".png")
+ else:
+ png = False
+ tif = True
+ io_logger.warning(
+ "found more than 65535 masks in each image, cannot save PNG, saving as TIF"
+ )
+ if tif:
+ exts.append(".tif")
+
+ # save masks
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ for ext in exts:
+ imsave(os.path.join(maskdir, basename + suffix + ext), masks)
+
+ if save_mpl and png and MATPLOTLIB and not min(images.shape) > 3:
+ # Make and save original/segmentation/flows image
+
+ img = images.copy()
+ if img.ndim < 3:
+ img = img[:, :, np.newaxis]
+ elif img.shape[0] < 8:
+            img = np.transpose(img, (1, 2, 0))
+
+ fig = plt.figure(figsize=(12, 3))
+ plot.show_segmentation(fig, img, masks, flows[0])
+ fig.savefig(os.path.join(savedir, basename + "_cp_output" + suffix + ".png"),
+ dpi=300)
+ plt.close(fig)
+
+ # ImageJ txt outline files
+ if masks.ndim < 3 and save_txt:
+ check_dir(txtdir)
+ outlines = utils.outlines_list(masks)
+ outlines_to_text(os.path.join(txtdir, basename), outlines)
+
+ # RGB outline images
+ if masks.ndim < 3 and save_outlines:
+ check_dir(outlinedir)
+ outlines = utils.masks_to_outlines(masks)
+ outX, outY = np.nonzero(outlines)
+ img0 = transforms.normalize99(images)
+ if img0.shape[0] < 4:
+ img0 = np.transpose(img0, (1, 2, 0))
+ if img0.shape[-1] < 3 or img0.ndim < 3:
+ img0 = plot.image_to_rgb(img0, channels=channels)
+ else:
+ if img0.max() <= 50.0:
+                img0 = np.uint8(np.clip(img0 * 255, 0, 255))
+ imgout = img0.copy()
+ imgout[outX, outY] = np.array([255, 0, 0]) #pure red
+ imsave(os.path.join(outlinedir, basename + "_outlines" + suffix + ".png"),
+ imgout)
+
+ # save RGB flow picture
+ if masks.ndim < 3 and save_flows:
+ check_dir(flowdir)
+ imsave(os.path.join(flowdir, basename + "_flows" + suffix + ".tif"),
+ (flows[0] * (2**16 - 1)).astype(np.uint16))
+ #save full flow data
+ imsave(os.path.join(flowdir, basename + '_dP' + suffix + '.tif'), flows[1])
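+
+# Illustrative usage sketch: write PNG masks, ImageJ outlines, and flow images next to
+# each input file (assumes `imgs`, `masks`, `flows`, `file_names` are matching lists
+# from Cellpose.eval):
+#   save_masks(imgs, masks, flows, file_names, png=True,
+#              save_outlines=True, save_txt=True, save_flows=True)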
diff --git a/cellpose/logo/cellpose.ico b/cellpose/logo/cellpose.ico
new file mode 100644
index 0000000000000000000000000000000000000000..01344e7d286956763ba64ebf5f5032e70657f402
Binary files /dev/null and b/cellpose/logo/cellpose.ico differ
diff --git a/cellpose/logo/logo.png b/cellpose/logo/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..3f4a324287ef48d9f6ef8b8f9d942ef37cca64b4
--- /dev/null
+++ b/cellpose/logo/logo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11fd5caf7ce1f3c0f9f4b283ac5f34981e604e8b68a7ede70457dbe8611aa328
+size 29050
diff --git a/cellpose/metrics.py b/cellpose/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..46e656a16a936a7ea2774dde31f1c784a222b7d4
--- /dev/null
+++ b/cellpose/metrics.py
@@ -0,0 +1,264 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+import numpy as np
+from . import utils, dynamics
+from numba import jit
+from scipy.optimize import linear_sum_assignment
+from scipy.ndimage import convolve, mean
+
+
+def mask_ious(masks_true, masks_pred):
+    """Return the best-match IoU for each true mask and the matched prediction labels."""
+ iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
+ n_min = min(iou.shape[0], iou.shape[1])
+ costs = -(iou >= 0.5).astype(float) - iou / (2 * n_min)
+ true_ind, pred_ind = linear_sum_assignment(costs)
+ iout = np.zeros(masks_true.max())
+ iout[true_ind] = iou[true_ind, pred_ind]
+ preds = np.zeros(masks_true.max(), "int")
+ preds[true_ind] = pred_ind + 1
+ return iout, preds
+
+
+def boundary_scores(masks_true, masks_pred, scales):
+ """
+ Calculate boundary precision, recall, and F-score.
+
+ Args:
+ masks_true (list): List of true masks.
+ masks_pred (list): List of predicted masks.
+ scales (list): List of scales.
+
+ Returns:
+ tuple: A tuple containing precision, recall, and F-score arrays.
+ """
+ diams = [utils.diameters(lbl)[0] for lbl in masks_true]
+ precision = np.zeros((len(scales), len(masks_true)))
+ recall = np.zeros((len(scales), len(masks_true)))
+ fscore = np.zeros((len(scales), len(masks_true)))
+ for j, scale in enumerate(scales):
+ for n in range(len(masks_true)):
+ diam = max(1, scale * diams[n])
+ rs, ys, xs = utils.circleMask([int(np.ceil(diam)), int(np.ceil(diam))])
+ filt = (rs <= diam).astype(np.float32)
+ otrue = utils.masks_to_outlines(masks_true[n])
+ otrue = convolve(otrue, filt)
+ opred = utils.masks_to_outlines(masks_pred[n])
+ opred = convolve(opred, filt)
+ tp = np.logical_and(otrue == 1, opred == 1).sum()
+ fp = np.logical_and(otrue == 0, opred == 1).sum()
+ fn = np.logical_and(otrue == 1, opred == 0).sum()
+ precision[j, n] = tp / (tp + fp)
+ recall[j, n] = tp / (tp + fn)
+ fscore[j] = 2 * precision[j] * recall[j] / (precision[j] + recall[j])
+ return precision, recall, fscore
+
+
+def aggregated_jaccard_index(masks_true, masks_pred):
+ """
+ AJI = intersection of all matched masks / union of all masks
+
+ Args:
+ masks_true (list of np.ndarrays (int) or np.ndarray (int)):
+ where 0=NO masks; 1,2... are mask labels
+ masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
+ np.ndarray (int) where 0=NO masks; 1,2... are mask labels
+
+ Returns:
+ aji (float): aggregated jaccard index for each set of masks
+ """
+ aji = np.zeros(len(masks_true))
+ for n in range(len(masks_true)):
+ iout, preds = mask_ious(masks_true[n], masks_pred[n])
+ inds = np.arange(0, masks_true[n].max(), 1, int)
+ overlap = _label_overlap(masks_true[n], masks_pred[n])
+ union = np.logical_or(masks_true[n] > 0, masks_pred[n] > 0).sum()
+ overlap = overlap[inds[preds > 0] + 1, preds[preds > 0].astype(int)]
+ aji[n] = overlap.sum() / union
+ return aji
+
+
+def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
+ """
+ Average precision estimation: AP = TP / (TP + FP + FN)
+
+ This function is based heavily on the *fast* stardist matching functions
+ (https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py)
+
+ Args:
+ masks_true (list of np.ndarrays (int) or np.ndarray (int)):
+ where 0=NO masks; 1,2... are mask labels
+ masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
+ np.ndarray (int) where 0=NO masks; 1,2... are mask labels
+
+ Returns:
+ ap (array [len(masks_true) x len(threshold)]):
+ average precision at thresholds
+ tp (array [len(masks_true) x len(threshold)]):
+ number of true positives at thresholds
+ fp (array [len(masks_true) x len(threshold)]):
+ number of false positives at thresholds
+ fn (array [len(masks_true) x len(threshold)]):
+ number of false negatives at thresholds
+ """
+ not_list = False
+ if not isinstance(masks_true, list):
+ masks_true = [masks_true]
+ masks_pred = [masks_pred]
+ not_list = True
+ if not isinstance(threshold, list) and not isinstance(threshold, np.ndarray):
+ threshold = [threshold]
+
+ if len(masks_true) != len(masks_pred):
+ raise ValueError(
+ "metrics.average_precision requires len(masks_true)==len(masks_pred)")
+
+ ap = np.zeros((len(masks_true), len(threshold)), np.float32)
+ tp = np.zeros((len(masks_true), len(threshold)), np.float32)
+ fp = np.zeros((len(masks_true), len(threshold)), np.float32)
+ fn = np.zeros((len(masks_true), len(threshold)), np.float32)
+ n_true = np.array(list(map(np.max, masks_true)))
+ n_pred = np.array(list(map(np.max, masks_pred)))
+
+ for n in range(len(masks_true)):
+ #_,mt = np.reshape(np.unique(masks_true[n], return_index=True), masks_pred[n].shape)
+ if n_pred[n] > 0:
+ iou = _intersection_over_union(masks_true[n], masks_pred[n])[1:, 1:]
+ for k, th in enumerate(threshold):
+ tp[n, k] = _true_positive(iou, th)
+ fp[n] = n_pred[n] - tp[n]
+ fn[n] = n_true[n] - tp[n]
+ ap[n] = tp[n] / (tp[n] + fp[n] + fn[n])
+
+ if not_list:
+ ap, tp, fp, fn = ap[0], tp[0], fp[0], fn[0]
+ return ap, tp, fp, fn
+
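+# Worked example (illustrative): a perfect prediction scores AP = 1 at every threshold.
+#   y_true = np.zeros((4, 4), int); y_true[1:3, 1:3] = 1
+#   ap, tp, fp, fn = average_precision(y_true, y_true.copy())
+#   # ap -> [1., 1., 1.]   tp -> [1., 1., 1.]   fp, fn -> [0., 0., 0.]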
+
+@jit(nopython=True)
+def _label_overlap(x, y):
+ """Fast function to get pixel overlaps between masks in x and y.
+
+ Args:
+ x (np.ndarray, int): Where 0=NO masks; 1,2... are mask labels.
+ y (np.ndarray, int): Where 0=NO masks; 1,2... are mask labels.
+
+ Returns:
+ overlap (np.ndarray, int): Matrix of pixel overlaps of size [x.max()+1, y.max()+1].
+ """
+ # put label arrays into standard form then flatten them
+ # x = (utils.format_labels(x)).ravel()
+ # y = (utils.format_labels(y)).ravel()
+ x = x.ravel()
+ y = y.ravel()
+
+ # preallocate a "contact map" matrix
+ overlap = np.zeros((1 + x.max(), 1 + y.max()), dtype=np.uint)
+
+ # loop over the labels in x and add to the corresponding
+ # overlap entry. If label A in x and label B in y share P
+ # pixels, then the resulting overlap is P
+ # len(x)=len(y), the number of pixels in the whole image
+ for i in range(len(x)):
+ overlap[x[i], y[i]] += 1
+ return overlap
+
+
+def _intersection_over_union(masks_true, masks_pred):
+ """Calculate the intersection over union of all mask pairs.
+
+ Parameters:
+ masks_true (np.ndarray, int): Ground truth masks, where 0=NO masks; 1,2... are mask labels.
+ masks_pred (np.ndarray, int): Predicted masks, where 0=NO masks; 1,2... are mask labels.
+
+ Returns:
+ iou (np.ndarray, float): Matrix of IOU pairs of size [x.max()+1, y.max()+1].
+
+ How it works:
+ The overlap matrix is a lookup table of the area of intersection
+ between each set of labels (true and predicted). The true labels
+ are taken to be along axis 0, and the predicted labels are taken
+ to be along axis 1. The sum of the overlaps along axis 0 is thus
+ an array giving the total overlap of the true labels with each of
+ the predicted labels, and likewise the sum over axis 1 is the
+ total overlap of the predicted labels with each of the true labels.
+ Because the label 0 (background) is included, this sum is guaranteed
+ to reconstruct the total area of each label. Adding this row and
+ column vectors gives a 2D array with the areas of every label pair
+ added together. This is equivalent to the union of the label areas
+ except for the duplicated overlap area, so the overlap matrix is
+ subtracted to find the union matrix.
+ """
+ overlap = _label_overlap(masks_true, masks_pred)
+ n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
+ n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
+ iou = overlap / (n_pixels_pred + n_pixels_true - overlap)
+ iou[np.isnan(iou)] = 0.0
+ return iou
+
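+# Worked example (illustrative): two 2-pixel labels sharing one pixel.
+#   x = np.array([[0, 1, 1, 0]]); y = np.array([[0, 0, 1, 1]])
+#   _intersection_over_union(x, y)[1, 1]  # -> 1/3 (1 shared pixel, union of 3 pixels)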
+
+def _true_positive(iou, th):
+ """Calculate the true positive at threshold th.
+
+ Args:
+ iou (float, np.ndarray): Array of IOU pairs.
+ th (float): Threshold on IOU for positive label.
+
+ Returns:
+ tp (float): Number of true positives at threshold.
+
+ How it works:
+ (1) Find minimum number of masks.
+        (2) Define the cost matrix: pairs with IoU at or above the threshold get a cost
+            of -1, and every pair gets a further small negative term proportional to its
+            IoU (scaled by the constant 1/(2*n_min)), so higher IoU means lower cost.
+ (3) Solve the linear sum assignment problem. The costs array defines the cost
+ of matching a true label with a predicted label, so the problem is to
+ find the set of pairings that minimizes this cost. The scipy.optimize
+ function gives the ordered lists of corresponding true and predicted labels.
+ (4) Extract the IoUs from these pairings and then threshold to get a boolean array
+ whose sum is the number of true positives that is returned.
+ """
+ n_min = min(iou.shape[0], iou.shape[1])
+ costs = -(iou >= th).astype(float) - iou / (2 * n_min)
+ true_ind, pred_ind = linear_sum_assignment(costs)
+ match_ok = iou[true_ind, pred_ind] >= th
+ tp = match_ok.sum()
+ return tp
+
+
+def flow_error(maski, dP_net, device=None):
+ """Error in flows from predicted masks vs flows predicted by network run on image.
+
+ This function serves to benchmark the quality of masks. It works as follows:
+ 1. The predicted masks are used to create a flow diagram.
+ 2. The mask-flows are compared to the flows that the network predicted.
+
+ If there is a discrepancy between the flows, it suggests that the mask is incorrect.
+ Masks with flow_errors greater than 0.4 are discarded by default. This setting can be
+ changed in Cellpose.eval or CellposeModel.eval.
+
+ Args:
+ maski (np.ndarray, int): Masks produced from running dynamics on dP_net, where 0=NO masks; 1,2... are mask labels.
+ dP_net (np.ndarray, float): ND flows where dP_net.shape[1:] = maski.shape.
+
+ Returns:
+ A tuple containing (flow_errors, dP_masks): flow_errors (np.ndarray, float): Mean squared error between predicted flows and flows from masks;
+ dP_masks (np.ndarray, float): ND flows produced from the predicted masks.
+ """
+ if dP_net.shape[1:] != maski.shape:
+ print("ERROR: net flow is not same size as predicted masks")
+ return
+
+ # flows predicted from estimated masks
+ dP_masks = dynamics.masks_to_flows(maski, device=device)
+ # difference between predicted flows vs mask flows
+ flow_errors = np.zeros(maski.max())
+ for i in range(dP_masks.shape[0]):
+ flow_errors += mean((dP_masks[i] - dP_net[i] / 5.)**2, maski,
+ index=np.arange(1,
+ maski.max() + 1))
+
+ return flow_errors, dP_masks
diff --git a/cellpose/models.py b/cellpose/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..40447ebe15b0f8ae4b9ac64825a134ee04774f04
--- /dev/null
+++ b/cellpose/models.py
@@ -0,0 +1,797 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import os, sys, time, shutil, tempfile, datetime, pathlib, subprocess
+from pathlib import Path
+import numpy as np
+from tqdm import trange, tqdm
+from urllib.parse import urlparse
+import torch
+from scipy.ndimage import gaussian_filter
+#import cv2
+import gc
+
+import logging
+
+models_logger = logging.getLogger(__name__)
+
+from . import transforms, dynamics, utils, plot
+from .resnet_torch import CPnet
+from .core import assign_device, check_mkl, run_net, run_3D
+
+_MODEL_URL = "https://www.cellpose.org/models"
+_MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
+_MODEL_DIR_DEFAULT = pathlib.Path.home().joinpath(".cellpose", "models")
+MODEL_DIR = pathlib.Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT
+
+MODEL_NAMES = [
+ "cyto3", "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3",
+    "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "cyto2", "cyto",
+ "transformer_cp3", "neurips_cellpose_default", "neurips_cellpose_transformer",
+ "neurips_grayscale_cyto2",
+ "CP", "CPx", "TN1", "TN2", "TN3", "LC1", "LC2", "LC3", "LC4"
+]
+
+MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt"))
+
+normalize_default = {
+ "lowhigh": None,
+ "percentile": None,
+ "normalize": True,
+ "norm3D": True,
+ "sharpen_radius": 0,
+ "smooth_radius": 0,
+ "tile_norm_blocksize": 0,
+ "tile_norm_smooth3D": 1,
+ "invert": False
+}
+
+
+def model_path(model_type, model_index=0):
+ torch_str = "torch"
+ if model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei":
+ basename = "%s%s_%d" % (model_type, torch_str, model_index)
+ else:
+ basename = model_type
+ return cache_model_path(basename)
+
+
+def size_model_path(model_type):
+ torch_str = "torch"
+ if (model_type == "cyto" or model_type == "nuclei" or
+ model_type == "cyto2" or model_type == "cyto3"):
+ if model_type == "cyto3":
+ basename = "size_%s.npy" % model_type
+ else:
+ basename = "size_%s%s_0.npy" % (model_type, torch_str)
+ return cache_model_path(basename)
+ else:
+ if os.path.exists(model_type) and os.path.exists(model_type + "_size.npy"):
+ return model_type + "_size.npy"
+ else:
+ raise FileNotFoundError(f"size model not found ({model_type + '_size.npy'})")
+
+
+def cache_model_path(basename):
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
+ url = f"{_MODEL_URL}/{basename}"
+ cached_file = os.fspath(MODEL_DIR.joinpath(basename))
+ if not os.path.exists(cached_file):
+ models_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
+ utils.download_url_to_file(url, cached_file, progress=True)
+ return cached_file
+
+
+def get_user_models():
+ model_strings = []
+ if os.path.exists(MODEL_LIST_PATH):
+ with open(MODEL_LIST_PATH, "r") as textfile:
+ lines = [line.rstrip() for line in textfile]
+ if len(lines) > 0:
+ model_strings.extend(lines)
+ return model_strings
+
+
+class Cellpose():
+ """Main model which combines SizeModel and CellposeModel.
+
+ Args:
+ gpu (bool, optional): Whether or not to use GPU, will check if GPU available. Defaults to False.
+ model_type (str, optional): Model type. "cyto"=cytoplasm model; "nuclei"=nucleus model;
+ "cyto2"=cytoplasm model with additional user images;
+ "cyto3"=super-generalist model; Defaults to "cyto3".
+ device (torch device, optional): Device used for model running / training. Overrides gpu input. Recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")). Defaults to None.
+
+ Attributes:
+ device (torch device): Device used for model running / training.
+ gpu (bool): Flag indicating if GPU is used.
+ diam_mean (float): Mean diameter for cytoplasm model.
+ cp (CellposeModel): CellposeModel instance.
+ pretrained_size (str): Pretrained size model path.
+ sz (SizeModel): SizeModel instance.
+
+ """
+
+ def __init__(self, gpu=False, model_type="cyto3", nchan=2, device=None,
+ backbone="default"):
+ super(Cellpose, self).__init__()
+
+ # assign device (GPU or CPU)
+ sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
+ self.device = device if device is not None else sdevice
+ self.gpu = gpu
+ self.backbone = backbone
+
+ model_type = "cyto3" if model_type is None else model_type
+
+ self.diam_mean = 30. #default for any cyto model
+ nuclear = "nuclei" in model_type
+ if nuclear:
+ self.diam_mean = 17.
+
+ if model_type in ["cyto", "nuclei", "cyto2", "cyto3"] and nchan != 2:
+ nchan = 2
+ models_logger.warning(
+ f"cannot set nchan to other value for {model_type} model")
+ self.nchan = nchan
+
+ self.cp = CellposeModel(device=self.device, gpu=self.gpu, model_type=model_type,
+ diam_mean=self.diam_mean, nchan=self.nchan,
+ backbone=self.backbone)
+ self.cp.model_type = model_type
+
+ # size model not used for bacterial model
+ self.pretrained_size = size_model_path(model_type)
+ self.sz = SizeModel(device=self.device, pretrained_size=self.pretrained_size,
+ cp_model=self.cp)
+ self.sz.model_type = model_type
+
+ def eval(self, x, batch_size=8, channels=[0, 0], channel_axis=None, invert=False,
+ normalize=True, diameter=30., do_3D=False, **kwargs):
+ """Run cellpose size model and mask model and get masks.
+
+ Args:
+ x (list or array): List or array of images. Can be list of 2D/3D images, or array of 2D/3D images, or 4D image array.
+ batch_size (int, optional): Number of 224x224 patches to run simultaneously on the GPU. Can make smaller or bigger depending on GPU memory usage. Defaults to 8.
+ channels (list, optional): List of channels, either of length 2 or of length number of images by 2. First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). For instance, to segment grayscale images, input [0,0]. To segment images with cells in green and nuclei in blue, input [2,3]. To segment one grayscale image and one image with cells in green and nuclei in blue, input [[0,0], [2,3]]. Defaults to [0,0].
+ channel_axis (int, optional): If None, channels dimension is attempted to be automatically determined. Defaults to None.
+ invert (bool, optional): Invert image pixel intensity before running network (if True, image is also normalized). Defaults to False.
+ normalize (bool, optional): If True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel; can also pass dictionary of parameters (see CellposeModel for details). Defaults to True.
+ diameter (float, optional): If set to None, then diameter is automatically estimated if size model is loaded. Defaults to 30..
+ do_3D (bool, optional): Set to True to run 3D segmentation on 4D image input. Defaults to False.
+
+ Returns:
+ A tuple containing (masks, flows, styles, diams): masks (list of 2D arrays or single 3D array): Labelled image, where 0=no masks; 1,2,...=mask labels;
+ flows (list of lists 2D arrays or list of 3D arrays): flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY flows at each pixel;
+ flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics);
+ flows[k][3] = final pixel locations after Euler integration;
+ styles (list of 1D arrays of length 256 or single 1D array): Style vector summarizing each image, also used to estimate size of objects in image;
+ diams (list of diameters or float): List of diameters or float (if do_3D=True).
+
+ """
+
+ tic0 = time.time()
+ models_logger.info(f"channels set to {channels}")
+
+ diam0 = diameter[0] if isinstance(diameter, (np.ndarray, list)) else diameter
+ estimate_size = True if (diameter is None or diam0 == 0) else False
+
+ if estimate_size and self.pretrained_size is not None and not do_3D and x[
+ 0].ndim < 4:
+ tic = time.time()
+ models_logger.info("~~~ ESTIMATING CELL DIAMETER(S) ~~~")
+ diams, _ = self.sz.eval(x, channels=channels, channel_axis=channel_axis,
+ batch_size=batch_size, normalize=normalize,
+ invert=invert)
+ diameter = None
+ models_logger.info("estimated cell diameter(s) in %0.2f sec" %
+ (time.time() - tic))
+ models_logger.info(">>> diameter(s) = ")
+ if isinstance(diams, list) or isinstance(diams, np.ndarray):
+ diam_string = "[" + "".join(["%0.2f, " % d for d in diams]) + "]"
+ else:
+ diam_string = "[ %0.2f ]" % diams
+ models_logger.info(diam_string)
+ elif estimate_size:
+ if self.pretrained_size is None:
+ reason = "no pretrained size model specified in model Cellpose"
+ else:
+ reason = "does not work on non-2D images"
+ models_logger.warning(f"could not estimate diameter, {reason}")
+ diams = self.diam_mean
+ else:
+ diams = diameter
+
+ models_logger.info("~~~ FINDING MASKS ~~~")
+ masks, flows, styles = self.cp.eval(x, channels=channels,
+ channel_axis=channel_axis,
+ batch_size=batch_size, normalize=normalize,
+ invert=invert, diameter=diams, do_3D=do_3D,
+ **kwargs)
+ models_logger.info(">>>> TOTAL TIME %0.2f sec" % (time.time() - tic0))
+
+ return masks, flows, styles, diams
+
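+# Illustrative usage sketch: segment a list of 2D images with automatic diameter
+# estimation (diameter=None triggers the SizeModel):
+#   model = Cellpose(gpu=True, model_type="cyto3")
+#   masks, flows, styles, diams = model.eval(imgs, diameter=None, channels=[0, 0])
+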
+def get_model_params(pretrained_model, model_type, pretrained_model_ortho, default_model="cyto3"):
+ """ return pretrained_model path, diam_mean and if model is builtin """
+ builtin = False
+ use_default = False
+ diam_mean = None
+ model_strings = get_user_models()
+ all_models = MODEL_NAMES.copy()
+ all_models.extend(model_strings)
+
+ # check if pretrained_model is builtin or custom user model saved in .cellpose/models
+ # if yes, then set to model_type
+ if (pretrained_model and not Path(pretrained_model).exists() and
+ np.any([pretrained_model == s for s in all_models])):
+ model_type = pretrained_model
+
+ # check if model_type is builtin or custom user model saved in .cellpose/models
+ if model_type is not None and np.any([model_type == s for s in all_models]):
+ if np.any([model_type == s for s in MODEL_NAMES]):
+ builtin = True
+ models_logger.info(f">> {model_type} << model set to be used")
+ if model_type == "nuclei":
+ diam_mean = 17.
+ pretrained_model = model_path(model_type)
+    # model_type not in the model lists: use it as a path if it exists, otherwise use the default model
+ elif model_type is not None:
+ if Path(model_type).exists():
+ pretrained_model = model_type
+ else:
+ models_logger.warning("model_type does not exist, using default model")
+ use_default = True
+ # if model_type is None...
+ else:
+ # if pretrained_model does not exist, use default model
+ if pretrained_model and not Path(pretrained_model).exists():
+ models_logger.warning(
+ "pretrained_model path does not exist, using default model")
+ use_default = True
+ elif pretrained_model:
+ if pretrained_model[-13:] == "nucleitorch_0":
+ builtin = True
+ diam_mean = 17.
+
+ if pretrained_model_ortho:
+ if pretrained_model_ortho in all_models:
+ pretrained_model_ortho = model_path(pretrained_model_ortho)
+ elif Path(pretrained_model_ortho).exists():
+ pass
+ else:
+ pretrained_model_ortho = None
+
+ pretrained_model = model_path(default_model) if use_default else pretrained_model
+ builtin = True if use_default else builtin
+ return pretrained_model, diam_mean, builtin, pretrained_model_ortho
+
+
+class CellposeModel():
+ """
+ Class representing a Cellpose model.
+
+ Attributes:
+ diam_mean (float): Mean "diameter" value for the model.
+ builtin (bool): Whether the model is a built-in model or not.
+ device (torch device): Device used for model running / training.
+ mkldnn (None or bool): MKLDNN flag for the model.
+ nchan (int): Number of channels used as input to the network.
+ nclasses (int): Number of classes in the model.
+ nbase (list): List of base values for the model.
+ net (CPnet): Cellpose network.
+ pretrained_model (str): Path to pretrained cellpose model.
+ pretrained_model_ortho (str): Path or model_name for pretrained cellpose model for ortho views in 3D.
+ backbone (str): Type of network ("default" is the standard res-unet, "transformer" for the segformer).
+
+ Methods:
+ __init__(self, gpu=False, pretrained_model=False, model_type=None, diam_mean=30., device=None, nchan=2):
+ Initialize the CellposeModel.
+
+ eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, z_axis=None, normalize=True, invert=False, rescale=None, diameter=None, flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0, min_size=15, niter=None, augment=False, tile_overlap=0.1, bsize=224, interp=True, compute_masks=True, progress=None):
+ Segment list of images x, or 4D array - Z x nchan x Y x X.
+
+ """
+
+ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
+ mkldnn=True, diam_mean=30., device=None, nchan=2,
+ pretrained_model_ortho=None, backbone="default"):
+ """
+ Initialize the CellposeModel.
+
+ Parameters:
+ gpu (bool, optional): Whether or not to save model to GPU, will check if GPU available.
+ pretrained_model (str or list of strings, optional): Full path to pretrained cellpose model(s), if None or False, no model loaded.
+ model_type (str, optional): Any model that is available in the GUI, use name in GUI e.g. "livecell" (can be user-trained or model zoo).
+ mkldnn (bool, optional): Use MKLDNN for CPU inference, faster but not always supported.
+ diam_mean (float, optional): Mean "diameter", 30. is built-in value for "cyto" model; 17. is built-in value for "nuclei" model; if saved in custom model file (cellpose>=2.0) then it will be loaded automatically and overwrite this value.
+ device (torch device, optional): Device used for model running / training (torch.device("cuda") or torch.device("cpu")), overrides gpu input, recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
+ nchan (int, optional): Number of channels to use as input to network, default is 2 (cyto + nuclei) or (nuclei + zeros).
+ """
+ self.diam_mean = diam_mean
+
+ ### set model path
+ default_model = "cyto3" if backbone == "default" else "transformer_cp3"
+ pretrained_model, diam_mean, builtin, pretrained_model_ortho = get_model_params(
+ pretrained_model,
+ model_type,
+ pretrained_model_ortho,
+ default_model)
+ self.diam_mean = diam_mean if diam_mean is not None else self.diam_mean
+
+ ### assign model device
+ self.mkldnn = None
+ self.device = assign_device(gpu=gpu)[0] if device is None else device
+ if torch.cuda.is_available():
+ device_gpu = self.device.type == "cuda"
+ elif torch.backends.mps.is_available():
+ device_gpu = self.device.type == "mps"
+ else:
+ device_gpu = False
+ self.gpu = device_gpu
+ if not self.gpu:
+ self.mkldnn = check_mkl(True) if mkldnn else False
+
+ ### create neural network
+ self.nchan = nchan
+ self.nclasses = 3
+ nbase = [32, 64, 128, 256]
+ self.nbase = [nchan, *nbase]
+ self.pretrained_model = pretrained_model
+ if backbone == "default":
+ self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn,
+ max_pool=True, diam_mean=self.diam_mean).to(self.device)
+ else:
+ from .segformer import Transformer
+ self.net = Transformer(
+ encoder_weights="imagenet" if not self.pretrained_model else None,
+ diam_mean=self.diam_mean).to(self.device)
+
+ ### load model weights
+ if self.pretrained_model:
+ models_logger.info(f">>>> loading model {pretrained_model}")
+ self.net.load_model(self.pretrained_model, device=self.device)
+ if not builtin:
+ self.diam_mean = self.net.diam_mean.data.cpu().numpy()[0]
+ self.diam_labels = self.net.diam_labels.data.cpu().numpy()[0]
+ models_logger.info(
+ f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)"
+ )
+ if not builtin:
+ models_logger.info(
+ f">>>> model diam_labels = {self.diam_labels: .3f} (mean diameter of training ROIs)"
+ )
+ if pretrained_model_ortho is not None:
+ models_logger.info(f">>>> loading ortho model {pretrained_model_ortho}")
+ self.net_ortho = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn,
+ max_pool=True, diam_mean=self.diam_mean).to(self.device)
+ self.net_ortho.load_model(pretrained_model_ortho, device=self.device)
+ else:
+ self.net_ortho = None
+ else:
+ models_logger.info(f">>>> no model weights loaded")
+ self.diam_labels = self.diam_mean
+
+ self.net_type = f"cellpose_{backbone}"
+
+ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
+ z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
+ flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None,
+ flow3D_smooth=0, stitch_threshold=0.0,
+ min_size=15, max_size_fraction=0.4, niter=None,
+ augment=False, tile_overlap=0.1, bsize=224,
+ interp=True, compute_masks=True, progress=None):
+ """ segment list of images x, or 4D array - Z x nchan x Y x X
+
+ Args:
+            x (list, np.ndarray): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
+ batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
+ (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
+ resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
+ channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
+ First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
+ Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
+ For instance, to segment grayscale images, input [0,0]. To segment images with cells
+ in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
+ image with cells in green and nuclei in blue, input [[0,0], [2,3]].
+ Defaults to None.
+ channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
+ if None, channels dimension is attempted to be automatically determined. Defaults to None.
+ z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
+ if None, z dimension is attempted to be automatically determined. Defaults to None.
+            normalize (bool or dict, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
+                can also pass a dictionary of parameters (all keys optional, defaults as in `normalize_default`):
+                    - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
+                    - "sharpen_radius"=0 : sharpen image with high-pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
+                    - "smooth_radius"=0 : smooth image with gaussian filter of this radius in pixels
+                    - "normalize"=True : run normalization (if False, all following parameters ignored)
+                    - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
+                    - "tile_norm_blocksize"=0 : compute normalization in tiles across the image to brighten dark areas; to turn on, set to the window size in pixels (e.g. 100)
+                    - "norm3D"=True : compute normalization across the entire z-stack rather than plane-by-plane in stitching mode.
+                Defaults to True.
+ invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
+ rescale (float, optional): resize factor for each image, if None, set to 1.0;
+ (only used if diameter is None). Defaults to None.
+ diameter (float, optional): diameter for each image,
+ if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
+ flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
+ cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
+ do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False.
+ flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0.
+ anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
+ stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
+ min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
+            max_size_fraction (float, optional): Masks larger than max_size_fraction of the
+                total image size are removed. Defaults to 0.4.
+ niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None.
+ augment (bool, optional): tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
+ tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
+ bsize (int, optional): block size for tiles, recommended to keep at 224, like in training. Defaults to 224.
+            interp (bool, optional): interpolate during 2D dynamics (not available in 3D). Defaults to True.
+ compute_masks (bool, optional): Whether or not to compute dynamics and return masks. This is set to False when retrieving the styles for the size model. Defaults to True.
+ progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
+
+ Returns:
+ A tuple containing (masks, flows, styles, diams):
+ masks (list of 2D arrays or single 3D array): Labelled image, where 0=no masks; 1,2,...=mask labels;
+ flows (list of lists 2D arrays or list of 3D arrays): flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY flows at each pixel;
+ flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics);
+ flows[k][3] = final pixel locations after Euler integration;
+ styles (list of 1D arrays of length 256 or single 1D array): Style vector summarizing each image, also used to estimate size of objects in image.
+
+ """
+ if isinstance(x, list) or x.squeeze().ndim == 5:
+ self.timing = []
+ masks, styles, flows = [], [], []
+ tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
+ nimg = len(x)
+ iterator = trange(nimg, file=tqdm_out,
+ mininterval=30) if nimg > 1 else range(nimg)
+ for i in iterator:
+ tic = time.time()
+ maski, flowi, stylei = self.eval(
+ x[i], batch_size=batch_size,
+ channels=channels[i] if channels is not None and
+ ((len(channels) == len(x) and
+ (isinstance(channels[i], list) or
+ isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2))
+ else channels, channel_axis=channel_axis, z_axis=z_axis,
+ normalize=normalize, invert=invert,
+ rescale=rescale[i] if isinstance(rescale, list) or
+ isinstance(rescale, np.ndarray) else rescale,
+ diameter=diameter[i] if isinstance(diameter, list) or
+ isinstance(diameter, np.ndarray) else diameter, do_3D=do_3D,
+ anisotropy=anisotropy, augment=augment,
+ tile_overlap=tile_overlap, bsize=bsize, resample=resample,
+ interp=interp, flow_threshold=flow_threshold,
+ cellprob_threshold=cellprob_threshold, compute_masks=compute_masks,
+ min_size=min_size, max_size_fraction=max_size_fraction,
+ stitch_threshold=stitch_threshold, flow3D_smooth=flow3D_smooth,
+ progress=progress, niter=niter
+ )
+ masks.append(maski)
+ flows.append(flowi)
+ styles.append(stylei)
+ self.timing.append(time.time() - tic)
+ return masks, flows, styles
+
+ else:
+ # reshape image
+ x = transforms.convert_image(x, channels, channel_axis=channel_axis,
+ z_axis=z_axis, do_3D=(do_3D or
+ stitch_threshold > 0),
+ nchan=self.nchan)
+ if x.ndim < 4:
+ x = x[np.newaxis, ...]
+ nimg = x.shape[0]
+
+ if diameter is not None and diameter > 0:
+ rescale = self.diam_mean / diameter
+ elif rescale is None:
+ rescale = self.diam_mean / self.diam_labels
+
+ # normalize image
+ normalize_params = normalize_default
+ if isinstance(normalize, dict):
+ normalize_params = {**normalize_params, **normalize}
+ elif not isinstance(normalize, bool):
+ raise ValueError("normalize parameter must be a bool or a dict")
+ else:
+ normalize_params["normalize"] = normalize
+ normalize_params["invert"] = invert
+
+ # pre-normalize if 3D stack for stitching or do_3D
+ do_normalization = True if normalize_params["normalize"] else False
+ x = np.asarray(x)
+ if nimg > 1 and do_normalization and (stitch_threshold or do_3D):
+ normalize_params["norm3D"] = True if do_3D else normalize_params["norm3D"]
+ x = transforms.normalize_img(x, **normalize_params)
+ do_normalization = False # do not normalize again
+ else:
+ if normalize_params["norm3D"] and nimg > 1:
+ models_logger.warning(
+ "normalize_params['norm3D'] is True but do_3D is False and stitch_threshold=0, so setting to False"
+ )
+ normalize_params["norm3D"] = False
+ if do_normalization:
+ x = transforms.normalize_img(x, **normalize_params)
+
+ dP, cellprob, styles = self._run_net(
+ x, rescale=rescale, augment=augment,
+ batch_size=batch_size, tile_overlap=tile_overlap, bsize=bsize,
+ resample=resample, do_3D=do_3D, anisotropy=anisotropy)
+
+ if do_3D:
+ if flow3D_smooth > 0:
+ models_logger.info(f"smoothing flows with sigma={flow3D_smooth}")
+ dP = gaussian_filter(dP, (0, flow3D_smooth, flow3D_smooth, flow3D_smooth))
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ if compute_masks:
+ niter0 = 200 if not resample else (1 / rescale * 200)
+ niter = niter0 if niter is None or niter == 0 else niter
+ masks = self._compute_masks(x.shape, dP, cellprob, flow_threshold=flow_threshold,
+ cellprob_threshold=cellprob_threshold, interp=interp, min_size=min_size,
+ max_size_fraction=max_size_fraction, niter=niter,
+ stitch_threshold=stitch_threshold, do_3D=do_3D)
+ else:
+ masks = np.zeros(0) #pass back zeros if not compute_masks
+
+ masks, dP, cellprob = masks.squeeze(), dP.squeeze(), cellprob.squeeze()
+
+ return masks, [plot.dx_to_circ(dP), dP, cellprob], styles
+
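+    # Illustrative usage sketch: normalization can be customized per call by passing a
+    # dict (keys as in `normalize_default`) instead of a bool, e.g. fixed percentiles
+    # plus tiled normalization:
+    #   cp = CellposeModel(gpu=True, model_type="cyto3")
+    #   masks, flows, styles = cp.eval(img, channels=[0, 0],
+    #       normalize={"percentile": [1, 99], "tile_norm_blocksize": 100})
+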
+ def _run_net(self, x, rescale=1.0, resample=True, augment=False,
+ batch_size=8, tile_overlap=0.1,
+ bsize=224, anisotropy=1.0, do_3D=False):
+ """ run network on image x """
+ tic = time.time()
+ shape = x.shape
+ nimg = shape[0]
+
+ if do_3D:
+ Lz, Ly, Lx = shape[:-1]
+ if rescale != 1.0 or (anisotropy is not None and anisotropy != 1.0):
+ models_logger.info(f"resizing 3D image with rescale={rescale:.2f} and anisotropy={anisotropy}")
+ anisotropy = 1.0 if anisotropy is None else anisotropy
+ if rescale != 1.0:
+ x = transforms.resize_image(x, Ly=int(Ly*rescale),
+ Lx=int(Lx*rescale))
+ x = transforms.resize_image(x.transpose(1,0,2,3),
+ Ly=int(Lz*anisotropy*rescale),
+ Lx=int(Lx*rescale)).transpose(1,0,2,3)
+ yf, styles = run_3D(self.net, x,
+ batch_size=batch_size, augment=augment,
+ tile_overlap=tile_overlap, net_ortho=self.net_ortho)
+ if resample:
+ if rescale != 1.0 or Lz != yf.shape[0]:
+ models_logger.info("resizing 3D flows and cellprob to original image size")
+ if rescale != 1.0:
+ yf = transforms.resize_image(yf, Ly=Ly, Lx=Lx)
+ if Lz != yf.shape[0]:
+ yf = transforms.resize_image(yf.transpose(1,0,2,3),
+ Ly=Lz, Lx=Lx).transpose(1,0,2,3)
+ cellprob = yf[..., -1]
+ dP = yf[..., :-1].transpose((3, 0, 1, 2))
+ else:
+ yf, styles = run_net(self.net, x, bsize=bsize, augment=augment,
+ batch_size=batch_size,
+ tile_overlap=tile_overlap,
+ rsz=rescale if rescale!=1.0 else None)
+ if resample:
+ if rescale != 1.0:
+ yf = transforms.resize_image(yf, shape[1], shape[2])
+ cellprob = yf[..., 2]
+ dP = yf[..., :2].transpose((3, 0, 1, 2))
+
+ styles = styles.squeeze()
+
+ net_time = time.time() - tic
+ if nimg > 1:
+ models_logger.info("network run in %2.2fs" % (net_time))
+
+ return dP, cellprob, styles
+
+ def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_threshold=0.0,
+ interp=True, min_size=15, max_size_fraction=0.4, niter=None,
+ do_3D=False, stitch_threshold=0.0):
+ """ compute masks from flows and cell probability """
+ Lz, Ly, Lx = shape[:3]
+ tic = time.time()
+ if do_3D:
+ masks = dynamics.resize_and_compute_masks(
+ dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
+ flow_threshold=flow_threshold, interp=interp, do_3D=do_3D,
+ min_size=min_size, max_size_fraction=max_size_fraction,
+ resize=shape[:3] if (np.array(dP.shape[-3:])!=np.array(shape[:3])).sum()
+ else None,
+ device=self.device)
+ else:
+ nimg = shape[0]
+ Ly0, Lx0 = cellprob[0].shape
+ resize = None if Ly0==Ly and Lx0==Lx else [Ly, Lx]
+ tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
+ iterator = trange(nimg, file=tqdm_out,
+ mininterval=30) if nimg > 1 else range(nimg)
+ for i in iterator:
+ # turn off min_size for 3D stitching
+ min_size0 = min_size if stitch_threshold == 0 or nimg == 1 else -1
+ outputs = dynamics.resize_and_compute_masks(
+ dP[:, i], cellprob[i],
+ niter=niter, cellprob_threshold=cellprob_threshold,
+ flow_threshold=flow_threshold, interp=interp, resize=resize,
+ min_size=min_size0, max_size_fraction=max_size_fraction,
+ device=self.device)
+ if i==0 and nimg > 1:
+ masks = np.zeros((nimg, shape[1], shape[2]), outputs.dtype)
+ if nimg > 1:
+ masks[i] = outputs
+ else:
+ masks = outputs
+
+ if stitch_threshold > 0 and nimg > 1:
+ models_logger.info(
+ f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks"
+ )
+ masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
+ masks = utils.fill_holes_and_remove_small_masks(
+ masks, min_size=min_size)
+ elif nimg > 1:
+ models_logger.warning(
+ "3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only"
+ )
+
+ flow_time = time.time() - tic
+ if shape[0] > 1:
+ models_logger.info("masks created in %2.2fs" % (flow_time))
+
+ return masks
+
+class SizeModel():
+ """
+ Linear regression model for determining the size of objects in image
+ used to rescale before input to cp_model.
+ Uses styles from cp_model.
+
+ Attributes:
+ pretrained_size (str): Path to pretrained size model.
+ cp (UnetModel or CellposeModel): Model from which to get styles.
+ device (torch device): Device used for model running / training
+ (torch.device("cuda") or torch.device("cpu")), overrides gpu input,
+ recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
+ diam_mean (float): Mean diameter of objects.
+
+ Methods:
+ eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False,
+ augment=False, batch_size=8, progress=None, interp=True):
+ Use images x to produce style or use style input to predict size of objects in image.
+
+ Raises:
+ ValueError: If no pretrained cellpose model is specified, cannot compute size.
+ """
+
+ def __init__(self, cp_model, device=None, pretrained_size=None, **kwargs):
+        """
+        Initialize size model.
+
+        Args:
+            cp_model (UnetModel or CellposeModel): Model from which to get styles.
+            device (torch device, optional): Device used for model running / training
+                (torch.device("cuda") or torch.device("cpu")), overrides gpu input,
+                recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
+            pretrained_size (str): Path to pretrained size model.
+        """
+        super(SizeModel, self).__init__(**kwargs)
+
+ self.pretrained_size = pretrained_size
+ self.cp = cp_model
+ self.device = self.cp.device
+ self.diam_mean = self.cp.diam_mean
+ if pretrained_size is not None:
+ self.params = np.load(self.pretrained_size, allow_pickle=True).item()
+ self.diam_mean = self.params["diam_mean"]
+ if not hasattr(self.cp, "pretrained_model"):
+ error_message = "no pretrained cellpose model specified, cannot compute size"
+ models_logger.critical(error_message)
+ raise ValueError(error_message)
+
+ def eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False,
+ augment=False, batch_size=8, progress=None):
+ """Use images x to produce style or use style input to predict size of objects in image.
+
+ Object size estimation is done in two steps:
+ 1. Use a linear regression model to predict size from style in image.
+ 2. Resize image to predicted size and run CellposeModel to get output masks.
+ Take the median object size of the predicted masks as the final predicted size.
+
+ Args:
+            x (list, np.ndarray): can be list of 2D/3D/4D images, or array of 2D/3D/4D images.
+ channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
+ First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
+ Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
+ For instance, to segment grayscale images, input [0,0]. To segment images with cells
+ in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
+ image with cells in green and nuclei in blue, input [[0,0], [2,3]].
+ Defaults to None.
+ channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
+ if None, channels dimension is attempted to be automatically determined. Defaults to None.
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
+ can also pass dictionary of parameters (all keys are optional, default values shown):
+ - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
+ - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
+ - "normalize"=True ; run normalization (if False, all following parameters ignored)
+ - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
+ - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
+ - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
+ Defaults to True.
+ invert (bool, optional): Invert image pixel intensity before running network (if True, image is also normalized). Defaults to False.
+ augment (bool, optional): tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
+ batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
+ (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
+ progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
+
+
+ Returns:
+ A tuple containing (diam, diam_style):
+ diam (np.ndarray): Final estimated diameters from images x or styles style after running both steps;
+ diam_style (np.ndarray): Estimated diameters from style alone.
+ """
+ if isinstance(x, list):
+ self.timing = []
+ diams, diams_style = [], []
+ nimg = len(x)
+ tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
+ iterator = trange(nimg, file=tqdm_out,
+ mininterval=30) if nimg > 1 else range(nimg)
+ for i in iterator:
+ tic = time.time()
+ diam, diam_style = self.eval(
+ x[i], channels=channels[i] if
+ (channels is not None and len(channels) == len(x) and
+ (isinstance(channels[i], list) or
+ isinstance(channels[i], np.ndarray)) and
+ len(channels[i]) == 2) else channels, channel_axis=channel_axis,
+ normalize=normalize, invert=invert, augment=augment,
+ batch_size=batch_size, progress=progress)
+ diams.append(diam)
+ diams_style.append(diam_style)
+ self.timing.append(time.time() - tic)
+
+ return diams, diams_style
+
+ if x.squeeze().ndim > 3:
+ models_logger.warning("image is not 2D cannot compute diameter")
+ return self.diam_mean, self.diam_mean
+
+ styles = self.cp.eval(x, channels=channels, channel_axis=channel_axis,
+ normalize=normalize, invert=invert, augment=augment,
+ batch_size=batch_size, resample=False,
+ compute_masks=False)[-1]
+
+ diam_style = self._size_estimation(np.array(styles))
+ diam_style = self.diam_mean if (diam_style == 0 or
+ np.isnan(diam_style)) else diam_style
+
+ masks = self.cp.eval(
+ x, compute_masks=True, channels=channels, channel_axis=channel_axis,
+ normalize=normalize, invert=invert, augment=augment,
+ batch_size=batch_size, resample=False,
+ rescale=self.diam_mean / diam_style if self.diam_mean > 0 else 1,
+ diameter=None, interp=False)[0]
+
+ diam = utils.diameters(masks)[0]
+ diam = self.diam_mean if (diam == 0 or np.isnan(diam)) else diam
+ return diam, diam_style
+
+ def _size_estimation(self, style):
+ """ linear regression from style to size
+
+        sizes were estimated using "diameters" from square estimates, not circles;
+ therefore a conversion factor is included (to be removed)
+
+ """
+ szest = np.exp(self.params["A"] @ (style - self.params["smean"]).T +
+ np.log(self.diam_mean) + self.params["ymean"])
+ szest = np.maximum(5., szest)
+ return szest
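+
+# Illustrative usage sketch (not from the upstream docs; variable names are
+# assumptions): given a trained CellposeModel `cp` and a 2D image `img`, the
+# size model estimates a diameter that can be fed back into segmentation:
+#     size_model = SizeModel(cp, pretrained_size="path/to/size_model.npy")
+#     diam, diam_style = size_model.eval(img, channels=[0, 0])
+#     masks = cp.eval(img, diameter=diam)[0]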
diff --git a/cellpose/plot.py b/cellpose/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c89bc1e6a5f745555510db85709770cfa0d4581
--- /dev/null
+++ b/cellpose/plot.py
@@ -0,0 +1,282 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import os
+import numpy as np
+import cv2
+from scipy.ndimage import gaussian_filter
+from . import utils, io, transforms
+
+try:
+ import matplotlib
+ MATPLOTLIB_ENABLED = True
+except ImportError:
+ MATPLOTLIB_ENABLED = False
+
+try:
+ from skimage import color
+ from skimage.segmentation import find_boundaries
+ SKIMAGE_ENABLED = True
+except ImportError:
+ SKIMAGE_ENABLED = False
+
+
+# modified to use sinebow color
+def dx_to_circ(dP):
+ """Converts the optic flow representation to a circular color representation.
+
+ Args:
+ dP (ndarray): Flow field components [dy, dx].
+
+ Returns:
+ ndarray: The circular color representation of the optic flow.
+
+ """
+ mag = 255 * np.clip(transforms.normalize99(np.sqrt(np.sum(dP**2, axis=0))), 0, 1.)
+ angles = np.arctan2(dP[1], dP[0]) + np.pi
+ a = 2
+ mag /= a
+ rgb = np.zeros((*dP.shape[1:], 3), "uint8")
+ rgb[..., 0] = np.clip(mag * (np.cos(angles) + 1), 0, 255).astype("uint8")
+ rgb[..., 1] = np.clip(mag * (np.cos(angles + 2 * np.pi / 3) + 1), 0, 255).astype("uint8")
+ rgb[..., 2] = np.clip(mag * (np.cos(angles + 4 * np.pi / 3) + 1), 0, 255).astype("uint8")
+
+ return rgb
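+
+# Illustrative sketch (hypothetical variable name): for `dP`, the 2 x Ly x Lx flow
+# field predicted by the network,
+#     rgb = dx_to_circ(dP)  # uint8 array of shape Ly x Lx x 3
+# where hue encodes the local flow direction and brightness the flow magnitude.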
+
+
+def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None):
+ """Plot segmentation results (like on website).
+
+ Can save each panel of figure with file_name option. Use channels option if
+ img input is not an RGB image with 3 channels.
+
+ Args:
+ fig (matplotlib.pyplot.figure): Figure in which to make plot.
+ img (ndarray): 2D or 3D array. Image input into cellpose.
+ maski (int, ndarray): For image k, masks[k] output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
+ flowi (int, ndarray): For image k, flows[k][0] output from Cellpose.eval (RGB of flows).
+ channels (list of int, optional): Channels used to run Cellpose, no need to use if image is RGB. Defaults to [0, 0].
+ file_name (str, optional): File name of image. If file_name is not None, figure panels are saved. Defaults to None.
+ """
+ if not MATPLOTLIB_ENABLED:
+ raise ImportError(
+ "matplotlib not installed, install with 'pip install matplotlib'")
+ ax = fig.add_subplot(1, 4, 1)
+ img0 = img.copy()
+
+ if img0.shape[0] < 4:
+ img0 = np.transpose(img0, (1, 2, 0))
+ if img0.shape[-1] < 3 or img0.ndim < 3:
+ img0 = image_to_rgb(img0, channels=channels)
+ else:
+ if img0.max() <= 50.0:
+ img0 = np.uint8(np.clip(img0, 0, 1) * 255)
+ ax.imshow(img0)
+ ax.set_title("original image")
+ ax.axis("off")
+
+ outlines = utils.masks_to_outlines(maski)
+
+ overlay = mask_overlay(img0, maski)
+
+ ax = fig.add_subplot(1, 4, 2)
+ outX, outY = np.nonzero(outlines)
+ imgout = img0.copy()
+ imgout[outX, outY] = np.array([255, 0, 0]) # pure red
+
+ ax.imshow(imgout)
+ ax.set_title("predicted outlines")
+ ax.axis("off")
+
+ ax = fig.add_subplot(1, 4, 3)
+ ax.imshow(overlay)
+ ax.set_title("predicted masks")
+ ax.axis("off")
+
+ ax = fig.add_subplot(1, 4, 4)
+ ax.imshow(flowi)
+ ax.set_title("predicted cell pose")
+ ax.axis("off")
+
+ if file_name is not None:
+ save_path = os.path.splitext(file_name)[0]
+ io.imsave(save_path + "_overlay.jpg", overlay)
+ io.imsave(save_path + "_outlines.jpg", imgout)
+ io.imsave(save_path + "_flows.jpg", flowi)
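+
+# Illustrative usage sketch (assumes matplotlib is installed and that `img`, `masks`
+# and `flows` come from a previous CellposeModel.eval call; names are placeholders):
+#     import matplotlib.pyplot as plt
+#     fig = plt.figure(figsize=(12, 3))
+#     show_segmentation(fig, img, masks, flows[0], channels=[0, 0])
+#     plt.tight_layout()
+#     plt.show()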
+
+
+def mask_rgb(masks, colors=None):
+ """Masks in random RGB colors.
+
+ Args:
+ masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
+ colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range.
+
+ Returns:
+ RGB (uint8, 3D array): Array of masks overlaid on grayscale image.
+ """
+ if colors is not None:
+ if colors.max() > 1:
+ colors = np.float32(colors)
+ colors /= 255
+ colors = utils.rgb_to_hsv(colors)
+
+ HSV = np.zeros((masks.shape[0], masks.shape[1], 3), np.float32)
+ HSV[:, :, 2] = 1.0
+ for n in range(int(masks.max())):
+ ipix = (masks == n + 1).nonzero()
+ if colors is None:
+ HSV[ipix[0], ipix[1], 0] = np.random.rand()
+ else:
+ HSV[ipix[0], ipix[1], 0] = colors[n, 0]
+ HSV[ipix[0], ipix[1], 1] = np.random.rand() * 0.5 + 0.5
+ HSV[ipix[0], ipix[1], 2] = np.random.rand() * 0.5 + 0.5
+ RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8)
+ return RGB
+
+
+def mask_overlay(img, masks, colors=None):
+ """Overlay masks on image (set image to grayscale).
+
+ Args:
+ img (int or float, 2D or 3D array): Image of size [Ly x Lx (x nchan)].
+ masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
+ colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range.
+
+ Returns:
+ RGB (uint8, 3D array): Array of masks overlaid on grayscale image.
+ """
+ if colors is not None:
+ if colors.max() > 1:
+ colors = np.float32(colors)
+ colors /= 255
+ colors = utils.rgb_to_hsv(colors)
+ if img.ndim > 2:
+ img = img.astype(np.float32).mean(axis=-1)
+ else:
+ img = img.astype(np.float32)
+
+ HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
+ HSV[:, :, 2] = np.clip((img / 255. if img.max() > 1 else img) * 1.5, 0, 1)
+ hues = np.linspace(0, 1, masks.max() + 1)[np.random.permutation(masks.max())]
+ for n in range(int(masks.max())):
+ ipix = (masks == n + 1).nonzero()
+ if colors is None:
+ HSV[ipix[0], ipix[1], 0] = hues[n]
+ else:
+ HSV[ipix[0], ipix[1], 0] = colors[n, 0]
+ HSV[ipix[0], ipix[1], 1] = 1.0
+ RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8)
+ return RGB
+
+
+def image_to_rgb(img0, channels=[0, 0]):
+ """Converts image from 2 x Ly x Lx or Ly x Lx x 2 to RGB Ly x Lx x 3.
+
+ Args:
+        img0 (ndarray): Input image of shape 2 x Ly x Lx or Ly x Lx x 2.
+        channels (list of int, optional): Channel assignment used when building the RGB
+            image (segmentation channel and optional nuclear channel). Defaults to [0, 0].
+
+ Returns:
+ ndarray: RGB image of shape Ly x Lx x 3.
+
+ """
+ img = img0.copy()
+ img = img.astype(np.float32)
+ if img.ndim < 3:
+ img = img[:, :, np.newaxis]
+ if img.shape[0] < 5:
+ img = np.transpose(img, (1, 2, 0))
+ if channels[0] == 0:
+ img = img.mean(axis=-1)[:, :, np.newaxis]
+ for i in range(img.shape[-1]):
+ if np.ptp(img[:, :, i]) > 0:
+ img[:, :, i] = np.clip(transforms.normalize99(img[:, :, i]), 0, 1)
+ img[:, :, i] = np.clip(img[:, :, i], 0, 1)
+ img *= 255
+ img = np.uint8(img)
+ RGB = np.zeros((img.shape[0], img.shape[1], 3), np.uint8)
+ if img.shape[-1] == 1:
+ RGB = np.tile(img, (1, 1, 3))
+ else:
+ RGB[:, :, channels[0] - 1] = img[:, :, 0]
+ if channels[1] > 0:
+ RGB[:, :, channels[1] - 1] = img[:, :, 1]
+ return RGB
+
+
+def interesting_patch(mask, bsize=130):
+ """
+ Get patch of size bsize x bsize with most masks.
+
+ Args:
+ mask (ndarray): Input mask.
+ bsize (int): Size of the patch.
+
+ Returns:
+        list: Two index arrays giving the row (y) and column (x) coordinates of the patch.
+
+ """
+ Ly, Lx = mask.shape
+ m = np.float32(mask > 0)
+ m = gaussian_filter(m, bsize / 2)
+ y, x = np.unravel_index(np.argmax(m), m.shape)
+ ycent = max(bsize // 2, min(y, Ly - bsize // 2))
+ xcent = max(bsize // 2, min(x, Lx - bsize // 2))
+ patch = [
+ np.arange(ycent - bsize // 2, ycent + bsize // 2, 1, int),
+ np.arange(xcent - bsize // 2, xcent + bsize // 2, 1, int)
+ ]
+ return patch
+
+
+def disk(med, r, Ly, Lx):
+ """Returns the pixels of a disk with a given radius and center.
+
+ Args:
+ med (tuple): The center coordinates of the disk.
+ r (float): The radius of the disk.
+ Ly (int): The height of the image.
+ Lx (int): The width of the image.
+
+ Returns:
+ tuple: A tuple containing the y and x coordinates of the pixels within the disk.
+
+ """
+ yy, xx = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
+ indexing="ij")
+ inds = ((yy - med[0])**2 + (xx - med[1])**2)**0.5 <= r
+ y = yy[inds].flatten()
+ x = xx[inds].flatten()
+ return y, x
+
+
+def outline_view(img0, maski, color=[1, 0, 0], mode="inner"):
+ """
+ Generates a red outline overlay onto the image.
+
+ Args:
+ img0 (numpy.ndarray): The input image.
+ maski (numpy.ndarray): The mask representing the region of interest.
+ color (list, optional): The color of the outline overlay. Defaults to [1, 0, 0] (red).
+ mode (str, optional): The mode for generating the outline. Defaults to "inner".
+
+ Returns:
+ numpy.ndarray: The image with the red outline overlay.
+
+ """
+ if img0.ndim == 2:
+ img0 = np.stack([img0] * 3, axis=-1)
+ elif img0.ndim != 3:
+ raise ValueError("img0 not right size (must have ndim 2 or 3)")
+
+ if SKIMAGE_ENABLED:
+ outlines = find_boundaries(maski, mode=mode)
+ else:
+ outlines = utils.masks_to_outlines(maski, mode=mode)
+ outY, outX = np.nonzero(outlines)
+ imgout = img0.copy()
+ imgout[outY, outX] = np.array(color)
+
+ return imgout
diff --git a/cellpose/resnet_torch.py b/cellpose/resnet_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd69298e9ef6fca5048ad9d6f96321e6f3c2108a
--- /dev/null
+++ b/cellpose/resnet_torch.py
@@ -0,0 +1,345 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def batchconv(in_channels, out_channels, sz, conv_3D=False):
+ conv_layer = nn.Conv3d if conv_3D else nn.Conv2d
+ batch_norm = nn.BatchNorm3d if conv_3D else nn.BatchNorm2d
+ return nn.Sequential(
+ batch_norm(in_channels, eps=1e-5, momentum=0.05),
+ nn.ReLU(inplace=True),
+ conv_layer(in_channels, out_channels, sz, padding=sz // 2),
+ )
+
+
+def batchconv0(in_channels, out_channels, sz, conv_3D=False):
+ conv_layer = nn.Conv3d if conv_3D else nn.Conv2d
+ batch_norm = nn.BatchNorm3d if conv_3D else nn.BatchNorm2d
+ return nn.Sequential(
+ batch_norm(in_channels, eps=1e-5, momentum=0.05),
+ conv_layer(in_channels, out_channels, sz, padding=sz // 2),
+ )
+
+
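+# Note on the helpers above: `batchconv` is a pre-activation block
+# (BatchNorm -> ReLU -> Conv), while `batchconv0` omits the ReLU and is used as the
+# 1x1 projection on the skip path of the residual blocks below.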
+class resdown(nn.Module):
+
+ def __init__(self, in_channels, out_channels, sz, conv_3D=False):
+ super().__init__()
+ self.conv = nn.Sequential()
+ self.proj = batchconv0(in_channels, out_channels, 1, conv_3D)
+ for t in range(4):
+ if t == 0:
+ self.conv.add_module("conv_%d" % t,
+ batchconv(in_channels, out_channels, sz, conv_3D))
+ else:
+ self.conv.add_module("conv_%d" % t,
+ batchconv(out_channels, out_channels, sz, conv_3D))
+
+ def forward(self, x):
+ x = self.proj(x) + self.conv[1](self.conv[0](x))
+ x = x + self.conv[3](self.conv[2](x))
+ return x
+
+
+class downsample(nn.Module):
+
+ def __init__(self, nbase, sz, conv_3D=False, max_pool=True):
+ super().__init__()
+ self.down = nn.Sequential()
+ if max_pool:
+ self.maxpool = nn.MaxPool3d(2, stride=2) if conv_3D else nn.MaxPool2d(
+ 2, stride=2)
+ else:
+ self.maxpool = nn.AvgPool3d(2, stride=2) if conv_3D else nn.AvgPool2d(
+ 2, stride=2)
+ for n in range(len(nbase) - 1):
+ self.down.add_module("res_down_%d" % n,
+ resdown(nbase[n], nbase[n + 1], sz, conv_3D))
+
+ def forward(self, x):
+ xd = []
+ for n in range(len(self.down)):
+ if n > 0:
+ y = self.maxpool(xd[n - 1])
+ else:
+ y = x
+ xd.append(self.down[n](y))
+ return xd
+
+
+class batchconvstyle(nn.Module):
+
+ def __init__(self, in_channels, out_channels, style_channels, sz, conv_3D=False):
+ super().__init__()
+ self.concatenation = False
+ self.conv = batchconv(in_channels, out_channels, sz, conv_3D)
+ self.full = nn.Linear(style_channels, out_channels)
+
+ def forward(self, style, x, mkldnn=False, y=None):
+ if y is not None:
+ x = x + y
+ feat = self.full(style)
+ for k in range(len(x.shape[2:])):
+ feat = feat.unsqueeze(-1)
+ if mkldnn:
+ x = x.to_dense()
+ y = (x + feat).to_mkldnn()
+ else:
+ y = x + feat
+ y = self.conv(y)
+ return y
+
+
+class resup(nn.Module):
+
+ def __init__(self, in_channels, out_channels, style_channels, sz, conv_3D=False):
+ super().__init__()
+ self.concatenation = False
+ self.conv = nn.Sequential()
+ self.conv.add_module("conv_0",
+ batchconv(in_channels, out_channels, sz, conv_3D=conv_3D))
+ self.conv.add_module(
+ "conv_1",
+ batchconvstyle(out_channels, out_channels, style_channels, sz,
+ conv_3D=conv_3D))
+ self.conv.add_module(
+ "conv_2",
+ batchconvstyle(out_channels, out_channels, style_channels, sz,
+ conv_3D=conv_3D))
+ self.conv.add_module(
+ "conv_3",
+ batchconvstyle(out_channels, out_channels, style_channels, sz,
+ conv_3D=conv_3D))
+ self.proj = batchconv0(in_channels, out_channels, 1, conv_3D=conv_3D)
+
+ def forward(self, x, y, style, mkldnn=False):
+ x = self.proj(x) + self.conv[1](style, self.conv[0](x), y=y, mkldnn=mkldnn)
+ x = x + self.conv[3](style, self.conv[2](style, x, mkldnn=mkldnn),
+ mkldnn=mkldnn)
+ return x
+
+
+class make_style(nn.Module):
+
+ def __init__(self, conv_3D=False):
+ super().__init__()
+ self.flatten = nn.Flatten()
+ self.avg_pool = F.avg_pool3d if conv_3D else F.avg_pool2d
+
+ def forward(self, x0):
+ style = self.avg_pool(x0, kernel_size=x0.shape[2:])
+ style = self.flatten(style)
+ style = style / torch.sum(style**2, axis=1, keepdim=True)**.5
+ return style
+
+
+class upsample(nn.Module):
+
+ def __init__(self, nbase, sz, conv_3D=False):
+ super().__init__()
+ self.upsampling = nn.Upsample(scale_factor=2, mode="nearest")
+ self.up = nn.Sequential()
+ for n in range(1, len(nbase)):
+ self.up.add_module("res_up_%d" % (n - 1),
+ resup(nbase[n], nbase[n - 1], nbase[-1], sz, conv_3D))
+
+ def forward(self, style, xd, mkldnn=False):
+ x = self.up[-1](xd[-1], xd[-1], style, mkldnn=mkldnn)
+ for n in range(len(self.up) - 2, -1, -1):
+ if mkldnn:
+ x = self.upsampling(x.to_dense()).to_mkldnn()
+ else:
+ x = self.upsampling(x)
+ x = self.up[n](x, xd[n], style, mkldnn=mkldnn)
+ return x
+
+
+class CPnet(nn.Module):
+ """
+ CPnet is the Cellpose neural network model used for cell segmentation and image restoration.
+
+ Args:
+ nbase (list): List of integers representing the number of channels in each layer of the downsample path.
+ nout (int): Number of output channels.
+        sz (int): Convolution kernel size used throughout the network.
+ mkldnn (bool, optional): Whether to use MKL-DNN acceleration. Defaults to False.
+ conv_3D (bool, optional): Whether to use 3D convolution. Defaults to False.
+ max_pool (bool, optional): Whether to use max pooling. Defaults to True.
+ diam_mean (float, optional): Mean diameter of the cells. Defaults to 30.0.
+
+ Attributes:
+ nbase (list): List of integers representing the number of channels in each layer of the downsample path.
+ nout (int): Number of output channels.
+        sz (int): Convolution kernel size used throughout the network.
+        residual_on (bool): Whether to use residual connections.
+        style_on (bool): Whether to condition the upsampling path on the style vector (zeroed if False).
+ concatenation (bool): Whether to use concatenation.
+ conv_3D (bool): Whether to use 3D convolution.
+ mkldnn (bool): Whether to use MKL-DNN acceleration.
+ downsample (nn.Module): Downsample blocks of the network.
+ upsample (nn.Module): Upsample blocks of the network.
+        make_style (nn.Module): Style module; average-pools over all spatial positions.
+ output (nn.Module): Output module - batchconv layer.
+ diam_mean (nn.Parameter): Parameter representing the mean diameter to which the cells are rescaled to during training.
+ diam_labels (nn.Parameter): Parameter representing the mean diameter of the cells in the training set (before rescaling).
+
+ """
+
+ def __init__(self, nbase, nout, sz, mkldnn=False, conv_3D=False, max_pool=True,
+ diam_mean=30.):
+ super().__init__()
+ self.nchan = nbase[0]
+ self.nbase = nbase
+ self.nout = nout
+ self.sz = sz
+ self.residual_on = True
+ self.style_on = True
+ self.concatenation = False
+ self.conv_3D = conv_3D
+ self.mkldnn = mkldnn if mkldnn is not None else False
+ self.downsample = downsample(nbase, sz, conv_3D=conv_3D, max_pool=max_pool)
+ nbaseup = nbase[1:]
+ nbaseup.append(nbaseup[-1])
+ self.upsample = upsample(nbaseup, sz, conv_3D=conv_3D)
+ self.make_style = make_style(conv_3D=conv_3D)
+ self.output = batchconv(nbaseup[0], nout, 1, conv_3D=conv_3D)
+ self.diam_mean = nn.Parameter(data=torch.ones(1) * diam_mean,
+ requires_grad=False)
+ self.diam_labels = nn.Parameter(data=torch.ones(1) * diam_mean,
+ requires_grad=False)
+
+ @property
+ def device(self):
+ """
+ Get the device of the model.
+
+ Returns:
+ torch.device: The device of the model.
+ """
+ return next(self.parameters()).device
+
+ def forward(self, data):
+ """
+ Forward pass of the CPnet model.
+
+ Args:
+ data (torch.Tensor): Input data.
+
+ Returns:
+ tuple: A tuple containing the output tensor, style tensor, and downsampled tensors.
+ """
+ if self.mkldnn:
+ data = data.to_mkldnn()
+ T0 = self.downsample(data)
+ if self.mkldnn:
+ style = self.make_style(T0[-1].to_dense())
+ else:
+ style = self.make_style(T0[-1])
+ style0 = style
+ if not self.style_on:
+ style = style * 0
+ T1 = self.upsample(style, T0, self.mkldnn)
+ T1 = self.output(T1)
+ if self.mkldnn:
+ T0 = [t0.to_dense() for t0 in T0]
+ T1 = T1.to_dense()
+ return T1, style0, T0
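+
+    # Shape sketch (illustrative, assuming a 2D configuration such as
+    # nbase=[2, 32, 64, 128, 256], nout=3): an input batch of shape (B, 2, 224, 224)
+    # yields T1 of shape (B, 3, 224, 224), style0 of shape (B, 256), and T0 as the
+    # list of 4 intermediate feature maps from the downsample path.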
+
+ def save_model(self, filename):
+ """
+ Save the model to a file.
+
+ Args:
+ filename (str): The path to the file where the model will be saved.
+ """
+ torch.save(self.state_dict(), filename)
+
+ def load_model(self, filename, device=None):
+ """
+ Load the model from a file.
+
+ Args:
+ filename (str): The path to the file where the model is saved.
+ device (torch.device, optional): The device to load the model on. Defaults to None.
+ """
+ if (device is not None) and (device.type != "cpu"):
+ state_dict = torch.load(filename, map_location=device, weights_only=True)
+ else:
+            self.__init__(self.nbase, self.nout, self.sz, mkldnn=self.mkldnn,
+                          conv_3D=self.conv_3D, diam_mean=self.diam_mean)
+ state_dict = torch.load(filename, map_location=torch.device("cpu"),
+ weights_only=True)
+
+ if state_dict["output.2.weight"].shape[0] != self.nout:
+ for name in self.state_dict():
+ if "output" not in name:
+ self.state_dict()[name].copy_(state_dict[name])
+ else:
+ self.load_state_dict(
+ dict([(name, param) for name, param in state_dict.items()]),
+ strict=False)
+
+class CPnetBioImageIO(CPnet):
+ """
+ A subclass of the CPnet model compatible with the BioImage.IO Spec.
+
+ This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
+ allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
+ """
+
+ def forward(self, x):
+ """
+ Perform a forward pass of the CPnet model and return unpacked tensors.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ tuple: A tuple containing the output tensor, style tensor, and downsampled tensors.
+ """
+ output_tensor, style_tensor, downsampled_tensors = super().forward(x)
+ return output_tensor, style_tensor, *downsampled_tensors
+
+
+ def load_model(self, filename, device=None):
+ """
+ Load the model from a file.
+
+ Args:
+ filename (str): The path to the file where the model is saved.
+ device (torch.device, optional): The device to load the model on. Defaults to None.
+ """
+ if (device is not None) and (device.type != "cpu"):
+ state_dict = torch.load(filename, map_location=device, weights_only=True)
+ else:
+            self.__init__(self.nbase, self.nout, self.sz, mkldnn=self.mkldnn,
+                          conv_3D=self.conv_3D, diam_mean=self.diam_mean)
+ state_dict = torch.load(filename, map_location=torch.device("cpu"),
+ weights_only=True)
+
+ self.load_state_dict(state_dict)
+
+ def load_state_dict(self, state_dict):
+ """
+ Load the state dictionary into the model.
+
+ This method overrides the default `load_state_dict` to handle Cellpose's custom
+ loading mechanism and ensures compatibility with BioImage.IO Core.
+
+ Args:
+ state_dict (Mapping[str, Any]): A state dictionary to load into the model
+ """
+ if state_dict["output.2.weight"].shape[0] != self.nout:
+ for name in self.state_dict():
+ if "output" not in name:
+ self.state_dict()[name].copy_(state_dict[name])
+ else:
+ super().load_state_dict(
+ {name: param for name, param in state_dict.items()},
+ strict=False)
+
diff --git a/cellpose/segformer.py b/cellpose/segformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..da142f6f5d4d004b20f876d85decfd42479c4f41
--- /dev/null
+++ b/cellpose/segformer.py
@@ -0,0 +1,79 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+import torch
+from torch import nn
+
+try:
+ import segmentation_models_pytorch as smp
+
+ class Transformer(nn.Module):
+ """ Transformer encoder from segformer paper with MAnet decoder
+ (configuration from MEDIAR)
+ """
+
+ def __init__(self, encoder="mit_b5", encoder_weights=None, decoder="MAnet",
+ diam_mean=30.):
+ super().__init__()
+ net_fcn = smp.MAnet if decoder == "MAnet" else smp.FPN
+ self.encoder = encoder
+ self.decoder = decoder
+ self.net = net_fcn(
+ encoder_name=encoder,
+ encoder_weights=encoder_weights,
+ # (use "imagenet" pre-trained weights for encoder initialization if training)
+ in_channels=3,
+ classes=3,
+ activation=None)
+ self.nout = 3
+ self.mkldnn = False
+ self.diam_mean = nn.Parameter(data=torch.ones(1) * diam_mean,
+ requires_grad=False)
+ self.diam_labels = nn.Parameter(data=torch.ones(1) * diam_mean,
+ requires_grad=False)
+
+ def forward(self, X):
+ # have to convert to 3-chan (RGB)
+ if X.shape[1] < 3:
+ X = torch.cat(
+ (X,
+ torch.zeros((X.shape[0], 3 - X.shape[1], X.shape[2], X.shape[3]),
+ device=X.device)), dim=1)
+ y = self.net(X)
+ return y, torch.zeros((X.shape[0], 256), device=X.device)
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def save_model(self, filename):
+ """
+ Save the model to a file.
+
+ Args:
+ filename (str): The path to the file where the model will be saved.
+ """
+ torch.save(self.state_dict(), filename)
+
+ def load_model(self, filename, device=None):
+ """
+ Load the model from a file.
+
+ Args:
+ filename (str): The path to the file where the model is saved.
+ device (torch.device, optional): The device to load the model on. Defaults to None.
+ """
+ if (device is not None) and (device.type != "cpu"):
+ state_dict = torch.load(filename, map_location=device, weights_only=True)
+ else:
+ self.__init__(encoder=self.encoder, decoder=self.decoder,
+ diam_mean=self.diam_mean)
+ state_dict = torch.load(filename, map_location=torch.device("cpu"), weights_only=True)
+
+ self.load_state_dict(
+ dict([(name, param) for name, param in state_dict.items()]),
+ strict=False)
+
+except Exception as e:
+ print(e)
+ print("need to install segmentation_models_pytorch to run transformer")
diff --git a/cellpose/test_mkl.py b/cellpose/test_mkl.py
new file mode 100644
index 0000000000000000000000000000000000000000..142f4dbefa0b17eeb62f8a6733dd2f63454132e2
--- /dev/null
+++ b/cellpose/test_mkl.py
@@ -0,0 +1,42 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import os, sys
+
+os.environ["MKLDNN_VERBOSE"] = "1"
+import numpy as np
+import time
+
+try:
+ import mxnet as mx
+ x = mx.sym.Variable("x")
+ MXNET_ENABLED = True
+except:
+ MXNET_ENABLED = False
+
+
+def test_mkl():
+ if MXNET_ENABLED:
+ num_filter = 32
+ kernel = (3, 3)
+ pad = (1, 1)
+ shape = (32, 32, 256, 256)
+
+ x = mx.sym.Variable("x")
+ w = mx.sym.Variable("w")
+ y = mx.sym.Convolution(data=x, weight=w, num_filter=num_filter, kernel=kernel,
+ no_bias=True, pad=pad)
+ exe = y.simple_bind(mx.cpu(), x=shape)
+
+ exe.arg_arrays[0][:] = np.random.normal(size=exe.arg_arrays[0].shape)
+ exe.arg_arrays[1][:] = np.random.normal(size=exe.arg_arrays[1].shape)
+
+ exe.forward(is_train=False)
+ o = exe.outputs[0]
+ t = o.asnumpy()
+
+
+if __name__ == "__main__":
+ if MXNET_ENABLED:
+ test_mkl()
diff --git a/cellpose/train.py b/cellpose/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..10f619f1c7b20c50d0b808252d11fb21e172b676
--- /dev/null
+++ b/cellpose/train.py
@@ -0,0 +1,708 @@
+import time
+import os
+import numpy as np
+from cellpose import io, transforms, utils, models, dynamics, metrics, resnet_torch
+from cellpose.transforms import normalize_img
+from pathlib import Path
+import torch
+from torch import nn
+from tqdm import trange
+from numba import prange
+
+import logging
+
+train_logger = logging.getLogger(__name__)
+
+
+def _loss_fn_seg(lbl, y, device):
+ """
+ Calculates the loss function between true labels lbl and prediction y.
+
+ Args:
+ lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
+ y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).
+ device (torch.device): Device on which the tensors are located.
+
+ Returns:
+ torch.Tensor: Loss value.
+
+ """
+ criterion = nn.MSELoss(reduction="mean")
+ criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
+ veci = 5. * torch.from_numpy(lbl[:, 1:]).to(device)
+ loss = criterion(y[:, :2], veci)
+ loss /= 2.
+ loss2 = criterion2(y[:, -1], torch.from_numpy(lbl[:, 0] > 0.5).to(device).float())
+ loss = loss + loss2
+ return loss
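+
+# Note on the loss above (shapes are illustrative): lbl is a (B, 3, H, W) numpy array
+# ordered (cellprob, flowY, flowX) and y is a (B, 3, H, W) tensor ordered
+# (flowY, flowX, cellprob logits); the flow channels are compared against the true
+# flows scaled by 5 with a (halved) MSE term, and the probability channel against a
+# binary cell/background map with BCE-with-logits, the two terms being summed.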
+
+
+def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
+ channels=None, channel_axis=None, rgb=False,
+ normalize_params={"normalize": False}):
+ """
+ Get a batch of images and labels.
+
+ Args:
+ inds (list): List of indices indicating which images and labels to retrieve.
+ data (list or None): List of image data. If None, images will be loaded from files.
+ labels (list or None): List of label data. If None, labels will be loaded from files.
+ files (list or None): List of file paths for images.
+ labels_files (list or None): List of file paths for labels.
+ channels (list or None): List of channel indices to extract from images.
+        channel_axis (int or None): Axis along which the channels are located.
+        rgb (bool): Whether to pad images to 3-channel RGB.
+ normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize).
+
+ Returns:
+ tuple: A tuple containing two lists: the batch of images and the batch of labels.
+ """
+ if data is None:
+ lbls = None
+ imgs = [io.imread(files[i]) for i in inds]
+ imgs = _reshape_norm(imgs, channels=channels, channel_axis=channel_axis,
+ rgb=rgb, normalize_params=normalize_params)
+ if labels_files is not None:
+ lbls = [io.imread(labels_files[i])[1:] for i in inds]
+ else:
+ imgs = [data[i] for i in inds]
+ lbls = [labels[i][1:] for i in inds]
+ return imgs, lbls
+
+
+def pad_to_rgb(img):
+ if img.ndim == 2 or np.ptp(img[1]) < 1e-3:
+ if img.ndim == 2:
+ img = img[np.newaxis, :, :]
+ img = np.tile(img[:1], (3, 1, 1))
+ elif img.shape[0] < 3:
+ nc, Ly, Lx = img.shape
+ # randomly flip channels
+ if np.random.rand() > 0.5:
+ img = img[::-1]
+ # randomly insert blank channel
+ ic = np.random.randint(3)
+ img = np.insert(img, ic, np.zeros((3 - nc, Ly, Lx), dtype=img.dtype), axis=0)
+ return img
+
+
+def convert_to_rgb(img):
+ if img.ndim == 2:
+ img = img[np.newaxis, :, :]
+ img = np.tile(img, (3, 1, 1))
+ elif img.shape[0] < 3:
+ img = img.mean(axis=0, keepdims=True)
+ img = transforms.normalize99(img)
+ img = np.tile(img, (3, 1, 1))
+ return img
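+
+# Note on the two helpers above: `pad_to_rgb` tiles a single (or effectively blank
+# second) channel to three channels, and otherwise randomly flips the channel order
+# and inserts a blank channel as light augmentation; `convert_to_rgb` always
+# collapses the input to one normalized channel and tiles it to three.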
+
+
+def _reshape_norm(data, channels=None, channel_axis=None, rgb=False,
+ normalize_params={"normalize": False}):
+ """
+ Reshapes and normalizes the input data.
+
+ Args:
+ data (list): List of input data.
+ channels (int or list, optional): Number of channels or list of channel indices to keep. Defaults to None.
+        channel_axis (int, optional): Axis along which the channels are located. Defaults to None.
+        rgb (bool, optional): Whether to pad images to 3-channel RGB. Defaults to False.
+ normalize_params (dict, optional): Dictionary of normalization parameters. Defaults to {"normalize": False}.
+
+ Returns:
+ list: List of reshaped and normalized data.
+ """
+ if channels is not None or channel_axis is not None:
+ data = [
+ transforms.convert_image(td, channels=channels, channel_axis=channel_axis)
+ for td in data
+ ]
+ data = [td.transpose(2, 0, 1) for td in data]
+ if normalize_params["normalize"]:
+ data = [
+ transforms.normalize_img(td, normalize=normalize_params, axis=0)
+ for td in data
+ ]
+ if rgb:
+ data = [pad_to_rgb(td) for td in data]
+ return data
+
+
+def _reshape_norm_save(files, channels=None, channel_axis=None,
+ normalize_params={"normalize": False}):
+ """ not currently used -- normalization happening on each batch if not load_files """
+ files_new = []
+    for i in trange(len(files)):
+        f = files[i]
+ td = io.imread(f)
+ if channels is not None:
+ td = transforms.convert_image(td, channels=channels,
+ channel_axis=channel_axis)
+ td = td.transpose(2, 0, 1)
+ if normalize_params["normalize"]:
+ td = transforms.normalize_img(td, normalize=normalize_params, axis=0)
+ fnew = os.path.splitext(str(f))[0] + "_cpnorm.tif"
+ io.imsave(fnew, td)
+ files_new.append(fnew)
+ return files_new
+ # else:
+ # train_files = reshape_norm_save(train_files, channels=channels,
+ # channel_axis=channel_axis, normalize_params=normalize_params)
+ # elif test_files is not None:
+ # test_files = reshape_norm_save(test_files, channels=channels,
+ # channel_axis=channel_axis, normalize_params=normalize_params)
+
+
+def _process_train_test(train_data=None, train_labels=None, train_files=None,
+ train_labels_files=None, train_probs=None, test_data=None,
+ test_labels=None, test_files=None, test_labels_files=None,
+ test_probs=None, load_files=True, min_train_masks=5,
+ compute_flows=False, channels=None, channel_axis=None,
+ rgb=False, normalize_params={"normalize": False
+ }, device=None):
+ """
+ Process train and test data.
+
+ Args:
+ train_data (list or None): List of training data arrays.
+ train_labels (list or None): List of training label arrays.
+ train_files (list or None): List of training file paths.
+ train_labels_files (list or None): List of training label file paths.
+ train_probs (ndarray or None): Array of training probabilities.
+ test_data (list or None): List of test data arrays.
+ test_labels (list or None): List of test label arrays.
+ test_files (list or None): List of test file paths.
+ test_labels_files (list or None): List of test label file paths.
+ test_probs (ndarray or None): Array of test probabilities.
+ load_files (bool): Whether to load data from files.
+ min_train_masks (int): Minimum number of masks required for training images.
+ compute_flows (bool): Whether to compute flows.
+ channels (list or None): List of channel indices to use.
+ channel_axis (int or None): Axis of channel dimension.
+ rgb (bool): Convert training/testing images to RGB.
+ normalize_params (dict): Dictionary of normalization parameters.
+ device (torch.device): Device to use for computation.
+
+ Returns:
+ tuple: A tuple containing the processed train and test data and sampling probabilities and diameters.
+ """
+    if device is None:
+        device = (torch.device("cuda") if torch.cuda.is_available() else
+                  torch.device("mps") if torch.backends.mps.is_available() else None)
+
+ if train_data is not None and train_labels is not None:
+ # if data is loaded
+ nimg = len(train_data)
+ nimg_test = len(test_data) if test_data is not None else None
+ else:
+ # otherwise use files
+ nimg = len(train_files)
+ if train_labels_files is None:
+ train_labels_files = [
+ os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files
+ ]
+ train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)]
+ if (test_data is not None or test_files is not None) and test_labels_files is None:
+ test_labels_files = [
+ os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files
+ ]
+ test_labels_files = [tf for tf in test_labels_files if os.path.exists(tf)]
+ if not load_files:
+ train_logger.info(">>> using files instead of loading dataset")
+ else:
+ # load all images
+ train_logger.info(">>> loading images and labels")
+ train_data = [io.imread(train_files[i]) for i in trange(nimg)]
+ train_labels = [io.imread(train_labels_files[i]) for i in trange(nimg)]
+ nimg_test = len(test_files) if test_files is not None else None
+ if load_files and nimg_test:
+ test_data = [io.imread(test_files[i]) for i in trange(nimg_test)]
+ test_labels = [io.imread(test_labels_files[i]) for i in trange(nimg_test)]
+
+ ### check that arrays are correct size
+ if ((train_labels is not None and nimg != len(train_labels)) or
+ (train_labels_files is not None and nimg != len(train_labels_files))):
+ error_message = "train data and labels not same length"
+ train_logger.critical(error_message)
+ raise ValueError(error_message)
+ if ((test_labels is not None and nimg_test != len(test_labels)) or
+ (test_labels_files is not None and nimg_test != len(test_labels_files))):
+ train_logger.warning("test data and labels not same length, not using")
+ test_data, test_files = None, None
+ if train_labels is not None:
+ if train_labels[0].ndim < 2 or train_data[0].ndim < 2:
+ error_message = "training data or labels are not at least two-dimensional"
+ train_logger.critical(error_message)
+ raise ValueError(error_message)
+ if train_data[0].ndim > 3:
+ error_message = "training data is more than three-dimensional (should be 2D or 3D array)"
+ train_logger.critical(error_message)
+ raise ValueError(error_message)
+
+ ### check that flows are computed
+ if train_labels is not None:
+ train_labels = dynamics.labels_to_flows(train_labels, files=train_files,
+ device=device)
+ if test_labels is not None:
+ test_labels = dynamics.labels_to_flows(test_labels, files=test_files,
+ device=device)
+ elif compute_flows:
+        for k in trange(nimg):
+            tl = dynamics.labels_to_flows([io.imread(train_labels_files[k])],
+                                          files=[train_files[k]], device=device)
+        if test_files is not None:
+            for k in trange(nimg_test):
+                tl = dynamics.labels_to_flows([io.imread(test_labels_files[k])],
+                                              files=[test_files[k]], device=device)
+
+ ### compute diameters
+ nmasks = np.zeros(nimg)
+ diam_train = np.zeros(nimg)
+ train_logger.info(">>> computing diameters")
+ for k in trange(nimg):
+ tl = (train_labels[k][0]
+ if train_labels is not None else io.imread(train_labels_files[k])[0])
+ diam_train[k], dall = utils.diameters(tl)
+ nmasks[k] = len(dall)
+ diam_train[diam_train < 5] = 5.
+ if test_data is not None:
+ diam_test = np.array(
+ [utils.diameters(test_labels[k][0])[0] for k in trange(len(test_labels))])
+ diam_test[diam_test < 5] = 5.
+ elif test_labels_files is not None:
+ diam_test = np.array([
+ utils.diameters(io.imread(test_labels_files[k])[0])[0]
+ for k in trange(len(test_labels_files))
+ ])
+ diam_test[diam_test < 5] = 5.
+ else:
+ diam_test = None
+
+ ### check to remove training images with too few masks
+ if min_train_masks > 0:
+ nremove = (nmasks < min_train_masks).sum()
+ if nremove > 0:
+ train_logger.warning(
+ f"{nremove} train images with number of masks less than min_train_masks ({min_train_masks}), removing from train set"
+ )
+ ikeep = np.nonzero(nmasks >= min_train_masks)[0]
+ if train_data is not None:
+ train_data = [train_data[i] for i in ikeep]
+ train_labels = [train_labels[i] for i in ikeep]
+ if train_files is not None:
+ train_files = [train_files[i] for i in ikeep]
+ if train_labels_files is not None:
+ train_labels_files = [train_labels_files[i] for i in ikeep]
+ if train_probs is not None:
+ train_probs = train_probs[ikeep]
+ diam_train = diam_train[ikeep]
+            nimg = len(ikeep)
+
+ ### normalize probabilities
+ train_probs = 1. / nimg * np.ones(nimg,
+ "float64") if train_probs is None else train_probs
+ train_probs /= train_probs.sum()
+ if test_files is not None or test_data is not None:
+ test_probs = 1. / nimg_test * np.ones(
+ nimg_test, "float64") if test_probs is None else test_probs
+ test_probs /= test_probs.sum()
+
+ ### reshape and normalize train / test data
+ normed = False
+ if channels is not None or normalize_params["normalize"]:
+ if channels:
+ train_logger.info(f">>> using channels {channels}")
+ if normalize_params["normalize"]:
+ train_logger.info(f">>> normalizing {normalize_params}")
+ if train_data is not None:
+ train_data = _reshape_norm(train_data, channels=channels,
+ channel_axis=channel_axis, rgb=rgb,
+ normalize_params=normalize_params)
+ normed = True
+ if test_data is not None:
+ test_data = _reshape_norm(test_data, channels=channels,
+ channel_axis=channel_axis, rgb=rgb,
+ normalize_params=normalize_params)
+
+ return (train_data, train_labels, train_files, train_labels_files, train_probs,
+ diam_train, test_data, test_labels, test_files, test_labels_files,
+ test_probs, diam_test, normed)
+
+
+def train_seg(net, train_data=None, train_labels=None, train_files=None,
+ train_labels_files=None, train_probs=None, test_data=None,
+ test_labels=None, test_files=None, test_labels_files=None,
+ test_probs=None, load_files=True, batch_size=8, learning_rate=0.005,
+ n_epochs=2000, weight_decay=1e-5, momentum=0.9, SGD=False, channels=None,
+ channel_axis=None, rgb=False, normalize=True, compute_flows=False,
+ save_path=None, save_every=100, save_each=False, nimg_per_epoch=None,
+ nimg_test_per_epoch=None, rescale=True, scale_range=None, bsize=224,
+ min_train_masks=5, model_name=None):
+ """
+ Train the network with images for segmentation.
+
+ Args:
+ net (object): The network model to train.
+ train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None.
+ train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
+ train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None.
+ train_labels_files (list or None): List of training label file paths. Defaults to None.
+ train_probs (List[float], optional): List of floats - probabilities for each image to be selected during training. Defaults to None.
+ test_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for testing. Defaults to None.
+ test_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for test_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
+ test_files (List[str], optional): List of strings - file names for images in test_data (to save flows for future runs). Defaults to None.
+ test_labels_files (list or None): List of test label file paths. Defaults to None.
+ test_probs (List[float], optional): List of floats - probabilities for each image to be selected during testing. Defaults to None.
+ load_files (bool, optional): Boolean - whether to load images and labels from files. Defaults to True.
+ batch_size (int, optional): Integer - number of patches to run simultaneously on the GPU. Defaults to 8.
+ learning_rate (float or List[float], optional): Float or list/np.ndarray - learning rate for training. Defaults to 0.005.
+ n_epochs (int, optional): Integer - number of times to go through the whole training set during training. Defaults to 2000.
+ weight_decay (float, optional): Float - weight decay for the optimizer. Defaults to 1e-5.
+ momentum (float, optional): Float - momentum for the optimizer. Defaults to 0.9.
+        SGD (bool, optional): Boolean - whether to use SGD as the optimizer instead of AdamW. Defaults to False.
+ channels (List[int], optional): List of ints - channels to use for training. Defaults to None.
+        channel_axis (int, optional): Integer - axis of the channel dimension in the input data. Defaults to None.
+        rgb (bool, optional): Boolean - whether to pad training/testing images to 3-channel RGB. Defaults to False.
+ normalize (bool or dict, optional): Boolean or dictionary - whether to normalize the data. Defaults to True.
+ compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False.
+ save_path (str, optional): String - where to save the trained model. Defaults to None.
+ save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100.
+        save_each (bool, optional): Boolean - save the network to a new filename at each checkpoint (every [save_every] epochs). Defaults to False.
+ nimg_per_epoch (int, optional): Integer - minimum number of images to train on per epoch. Defaults to None.
+ nimg_test_per_epoch (int, optional): Integer - minimum number of images to test on per epoch. Defaults to None.
+        rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to True.
+        scale_range (float, optional): Float - amplitude of random rescaling applied during augmentation; if None, 0.5 is used when rescale is True and 1.0 otherwise. Defaults to None.
+        bsize (int, optional): Integer - size in pixels of the square patches used for training. Defaults to 224.
+ min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5.
+ model_name (str, optional): String - name of the network. Defaults to None.
+
+ Returns:
+ tuple: A tuple containing the path to the saved model weights, training losses, and test losses.
+
+ """
+ device = net.device
+
+ scale_range0 = 0.5 if rescale else 1.0
+ scale_range = scale_range if scale_range is not None else scale_range0
+
+ if isinstance(normalize, dict):
+ normalize_params = {**models.normalize_default, **normalize}
+ elif not isinstance(normalize, bool):
+ raise ValueError("normalize parameter must be a bool or a dict")
+ else:
+ normalize_params = models.normalize_default
+ normalize_params["normalize"] = normalize
+
+ out = _process_train_test(train_data=train_data, train_labels=train_labels,
+ train_files=train_files, train_labels_files=train_labels_files,
+ train_probs=train_probs,
+ test_data=test_data, test_labels=test_labels,
+ test_files=test_files, test_labels_files=test_labels_files,
+ test_probs=test_probs,
+ load_files=load_files, min_train_masks=min_train_masks,
+ compute_flows=compute_flows, channels=channels,
+ channel_axis=channel_axis, rgb=rgb,
+ normalize_params=normalize_params, device=net.device)
+ (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train,
+ test_data, test_labels, test_files, test_labels_files, test_probs, diam_test,
+ normed) = out
+ # already normalized, do not normalize during training
+ if normed:
+ kwargs = {}
+ else:
+ kwargs = {
+ "normalize_params": normalize_params,
+ "channels": channels,
+ "channel_axis": channel_axis,
+ "rgb": rgb
+ }
+
+ net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)
+
+ nimg = len(train_data) if train_data is not None else len(train_files)
+ nimg_test = len(test_data) if test_data is not None else None
+ nimg_test = len(test_files) if test_files is not None else nimg_test
+ nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
+ nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch
+
+ # learning rate schedule
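+    # warmup from 0 to learning_rate over the first 10 epochs, then constant;
+    # for runs longer than 300 (resp. 100) epochs, the final 100 (resp. 50) epochs
+    # are replaced by a tail that halves the learning rate every 10 (resp. 5) epochs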
+ LR = np.linspace(0, learning_rate, 10)
+ LR = np.append(LR, learning_rate * np.ones(max(0, n_epochs - 10)))
+ if n_epochs > 300:
+ LR = LR[:-100]
+ for i in range(10):
+ LR = np.append(LR, LR[-1] / 2 * np.ones(10))
+ elif n_epochs > 100:
+ LR = LR[:-50]
+ for i in range(10):
+ LR = np.append(LR, LR[-1] / 2 * np.ones(5))
+
+ train_logger.info(f">>> n_epochs={n_epochs}, n_train={nimg}, n_test={nimg_test}")
+
+ if not SGD:
+ train_logger.info(
+ f">>> AdamW, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}"
+ )
+ optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate,
+ weight_decay=weight_decay)
+ else:
+ train_logger.info(
+ f">>> SGD, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}, momentum={momentum:0.3f}"
+ )
+ optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate,
+ weight_decay=weight_decay, momentum=momentum)
+
+ t0 = time.time()
+ model_name = f"cellpose_{t0}" if model_name is None else model_name
+ save_path = Path.cwd() if save_path is None else Path(save_path)
+ filename = save_path / "models" / model_name
+ (save_path / "models").mkdir(exist_ok=True)
+
+ train_logger.info(f">>> saving model to {filename}")
+
+ lavg, nsum = 0, 0
+ train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs)
+ for iepoch in range(n_epochs):
+ np.random.seed(iepoch)
+ if nimg != nimg_per_epoch:
+ # choose random images for epoch with probability train_probs
+ rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
+ p=train_probs)
+ else:
+ # otherwise use all images
+ rperm = np.random.permutation(np.arange(0, nimg))
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = LR[iepoch] # set learning rate
+ net.train()
+ for k in range(0, nimg_per_epoch, batch_size):
+ kend = min(k + batch_size, nimg_per_epoch)
+ inds = rperm[k:kend]
+ imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels,
+ files=train_files, labels_files=train_labels_files,
+ **kwargs)
+ diams = np.array([diam_train[i] for i in inds])
+ rsc = diams / net.diam_mean.item() if rescale else np.ones(
+ len(diams), "float32")
+ # augmentations
+ imgi, lbl = transforms.random_rotate_and_resize(imgs, Y=lbls, rescale=rsc,
+ scale_range=scale_range,
+ xy=(bsize, bsize))[:2]
+ # network and loss optimization
+ X = torch.from_numpy(imgi).to(device)
+ y = net(X)[0]
+ loss = _loss_fn_seg(lbl, y, device)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ train_loss = loss.item()
+ train_loss *= len(imgi)
+
+ # keep track of average training loss across epochs
+ lavg += train_loss
+ nsum += len(imgi)
+ # per epoch training loss
+ train_losses[iepoch] += train_loss
+ train_losses[iepoch] /= nimg_per_epoch
+
+ if iepoch == 5 or iepoch % 10 == 0:
+ lavgt = 0.
+ if test_data is not None or test_files is not None:
+ np.random.seed(42)
+ if nimg_test != nimg_test_per_epoch:
+ rperm = np.random.choice(np.arange(0, nimg_test),
+ size=(nimg_test_per_epoch,), p=test_probs)
+ else:
+ rperm = np.random.permutation(np.arange(0, nimg_test))
+ for ibatch in range(0, len(rperm), batch_size):
+ with torch.no_grad():
+ net.eval()
+ inds = rperm[ibatch:ibatch + batch_size]
+ imgs, lbls = _get_batch(inds, data=test_data,
+ labels=test_labels, files=test_files,
+ labels_files=test_labels_files,
+ **kwargs)
+ diams = np.array([diam_test[i] for i in inds])
+ rsc = diams / net.diam_mean.item() if rescale else np.ones(
+ len(diams), "float32")
+ imgi, lbl = transforms.random_rotate_and_resize(
+ imgs, Y=lbls, rescale=rsc, scale_range=scale_range,
+ xy=(bsize, bsize))[:2]
+ X = torch.from_numpy(imgi).to(device)
+ y = net(X)[0]
+ loss = _loss_fn_seg(lbl, y, device)
+ test_loss = loss.item()
+ test_loss *= len(imgi)
+ lavgt += test_loss
+ lavgt /= len(rperm)
+ test_losses[iepoch] = lavgt
+ lavg /= nsum
+ train_logger.info(
+ f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s"
+ )
+ lavg, nsum = 0, 0
+
+ if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
+ if save_each and iepoch != n_epochs - 1: #separate files as model progresses
+ filename0 = str(filename) + f"_epoch_{iepoch:04d}"
+ else:
+ filename0 = filename
+ train_logger.info(f"saving network parameters to {filename0}")
+ net.save_model(filename0)
+
+ net.save_model(filename)
+
+ return filename, train_losses, test_losses
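+
+# Illustrative usage sketch (names and constructor arguments are placeholders;
+# assumes `imgs` and `masks` are lists of 2D arrays):
+#     model = models.CellposeModel(gpu=True)
+#     model_path, train_losses, test_losses = train_seg(
+#         model.net, train_data=imgs, train_labels=masks,
+#         channels=[0, 0], normalize=True, n_epochs=100,
+#         save_path="train_out", model_name="my_model")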
+
+
+def train_size(net, pretrained_model, train_data=None, train_labels=None,
+ train_files=None, train_labels_files=None, train_probs=None,
+ test_data=None, test_labels=None, test_files=None,
+ test_labels_files=None, test_probs=None, load_files=True,
+ min_train_masks=5, channels=None, channel_axis=None, rgb=False,
+ normalize=True, nimg_per_epoch=None, nimg_test_per_epoch=None,
+ batch_size=64, scale_range=1.0, bsize=512, l2_regularization=1.0,
+ n_epochs=10):
+ """Train the size model.
+
+ Args:
+ net (object): The neural network model.
+ pretrained_model (str): The path to the pretrained model.
+ train_data (numpy.ndarray, optional): The training data. Defaults to None.
+ train_labels (numpy.ndarray, optional): The training labels. Defaults to None.
+ train_files (list, optional): The training file paths. Defaults to None.
+ train_labels_files (list, optional): The training label file paths. Defaults to None.
+ train_probs (numpy.ndarray, optional): The training probabilities. Defaults to None.
+ test_data (numpy.ndarray, optional): The test data. Defaults to None.
+ test_labels (numpy.ndarray, optional): The test labels. Defaults to None.
+ test_files (list, optional): The test file paths. Defaults to None.
+ test_labels_files (list, optional): The test label file paths. Defaults to None.
+ test_probs (numpy.ndarray, optional): The test probabilities. Defaults to None.
+ load_files (bool, optional): Whether to load files. Defaults to True.
+ min_train_masks (int, optional): The minimum number of training masks. Defaults to 5.
+ channels (list, optional): The channels. Defaults to None.
+        channel_axis (int, optional): The channel axis. Defaults to None.
+        rgb (bool, optional): Whether to treat the images as RGB (3 channels). Defaults to False.
+        normalize (bool or dict, optional): Whether to normalize the data. Defaults to True.
+        nimg_per_epoch (int, optional): The number of images per epoch. Defaults to None.
+        nimg_test_per_epoch (int, optional): The number of test images per epoch. Defaults to None.
+        batch_size (int, optional): The batch size. Defaults to 64.
+        scale_range (float, optional): The range of random rescaling during augmentation. Defaults to 1.0.
+        bsize (int, optional): The size of the random crops used to compute styles. Defaults to 512.
+        l2_regularization (float, optional): The L2 regularization factor. Defaults to 1.0.
+ n_epochs (int, optional): The number of epochs. Defaults to 10.
+
+ Returns:
+ dict: The trained size model parameters.
+ """
+ if isinstance(normalize, dict):
+ normalize_params = {**models.normalize_default, **normalize}
+ elif not isinstance(normalize, bool):
+ raise ValueError("normalize parameter must be a bool or a dict")
+ else:
+ normalize_params = models.normalize_default
+ normalize_params["normalize"] = normalize
+
+ out = _process_train_test(
+ train_data=train_data, train_labels=train_labels, train_files=train_files,
+ train_labels_files=train_labels_files, train_probs=train_probs,
+ test_data=test_data, test_labels=test_labels, test_files=test_files,
+ test_labels_files=test_labels_files, test_probs=test_probs,
+ load_files=load_files, min_train_masks=min_train_masks, compute_flows=False,
+ channels=channels, channel_axis=channel_axis, normalize_params=normalize_params,
+ device=net.device)
+ (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train,
+ test_data, test_labels, test_files, test_labels_files, test_probs, diam_test,
+ normed) = out
+
+ # already normalized, do not normalize during training
+ if normed:
+ kwargs = {}
+ else:
+ kwargs = {
+ "normalize_params": normalize_params,
+ "channels": channels,
+ "channel_axis": channel_axis,
+ "rgb": rgb
+ }
+
+ nimg = len(train_data) if train_data is not None else len(train_files)
+ nimg_test = len(test_data) if test_data is not None else None
+ nimg_test = len(test_files) if test_files is not None else nimg_test
+ nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
+ nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch
+
+ diam_mean = net.diam_mean.item()
+ device = net.device
+ net.eval()
+
+ styles = np.zeros((n_epochs * nimg_per_epoch, 256), np.float32)
+ diams = np.zeros((n_epochs * nimg_per_epoch,), np.float32)
+ tic = time.time()
+ for iepoch in range(n_epochs):
+ np.random.seed(iepoch)
+ if nimg != nimg_per_epoch:
+ rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
+ p=train_probs)
+ else:
+ rperm = np.random.permutation(np.arange(0, nimg))
+ for ibatch in range(0, nimg_per_epoch, batch_size):
+ inds_batch = np.arange(ibatch, min(nimg_per_epoch, ibatch + batch_size))
+ inds = rperm[inds_batch]
+ imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels,
+ files=train_files, **kwargs)
+ diami = diam_train[inds].copy()
+ imgi, lbl, scale = transforms.random_rotate_and_resize(
+ imgs, scale_range=scale_range, xy=(bsize, bsize))
+ imgi = torch.from_numpy(imgi).to(device)
+ with torch.no_grad():
+ feat = net(imgi)[1]
+ indsi = inds_batch + nimg_per_epoch * iepoch
+ styles[indsi] = feat.cpu().numpy()
+ diams[indsi] = np.log(diami) - np.log(diam_mean) + np.log(scale)
+ del feat
+ train_logger.info("ran %d epochs in %0.3f sec" %
+ (iepoch + 1, time.time() - tic))
+
+ # create model
+ smean = styles.copy().mean(axis=0)
+ X = ((styles.copy() - smean).T).copy()
+ ymean = diams.copy().mean()
+ y = diams.copy() - ymean
+
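+    # Ridge regression in closed form (descriptive note): X is the [256 x N] matrix of
+    # centered style vectors and y the length-N vector of centered log-diameters, so the
+    # weights solve (X X^T + lambda * I) A = X y, i.e. A = (X X^T + lambda * I)^(-1) X y,
+    # and the fitted log-diameters are A @ X.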
+ A = np.linalg.solve(X @ X.T + l2_regularization * np.eye(X.shape[0]), X @ y)
+ ypred = A @ X
+
+ train_logger.info("train correlation: %0.4f" % np.corrcoef(y, ypred)[0, 1])
+
+ if nimg_test:
+ np.random.seed(0)
+ styles_test = np.zeros((nimg_test_per_epoch, 256), np.float32)
+ diams_test = np.zeros((nimg_test_per_epoch,), np.float32)
+ diams_test0 = np.zeros((nimg_test_per_epoch,), np.float32)
+ if nimg_test != nimg_test_per_epoch:
+ rperm = np.random.choice(np.arange(0, nimg_test),
+ size=(nimg_test_per_epoch,), p=test_probs)
+ else:
+ rperm = np.random.permutation(np.arange(0, nimg_test))
+ for ibatch in range(0, nimg_test_per_epoch, batch_size):
+ inds_batch = np.arange(ibatch, min(nimg_test_per_epoch,
+ ibatch + batch_size))
+ inds = rperm[inds_batch]
+ imgs, lbls = _get_batch(inds, data=test_data, labels=test_labels,
+ files=test_files, labels_files=test_labels_files,
+ **kwargs)
+ diami = diam_test[inds].copy()
+ imgi, lbl, scale = transforms.random_rotate_and_resize(
+ imgs, Y=lbls, scale_range=scale_range, xy=(bsize, bsize))
+ imgi = torch.from_numpy(imgi).to(device)
+ diamt = np.array([utils.diameters(lbl0[0])[0] for lbl0 in lbl])
+ diamt = np.maximum(5., diamt)
+ with torch.no_grad():
+ feat = net(imgi)[1]
+ styles_test[inds_batch] = feat.cpu().numpy()
+ diams_test[inds_batch] = np.log(diami) - np.log(diam_mean) + np.log(scale)
+ diams_test0[inds_batch] = diamt
+
+ diam_test_pred = np.exp(A @ (styles_test - smean).T + np.log(diam_mean) + ymean)
+ diam_test_pred = np.maximum(5., diam_test_pred)
+ train_logger.info("test correlation: %0.4f" %
+ np.corrcoef(diams_test0, diam_test_pred)[0, 1])
+
+ pretrained_size = str(pretrained_model) + "_size.npy"
+ params = {"A": A, "smean": smean, "diam_mean": diam_mean, "ymean": ymean}
+ np.save(pretrained_size, params)
+ train_logger.info("model saved to " + pretrained_size)
+
+ return params
diff --git a/cellpose/transforms.py b/cellpose/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc5f5160272848462c94eedcc1db4c0e7b624a99
--- /dev/null
+++ b/cellpose/transforms.py
@@ -0,0 +1,1034 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import logging
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from scipy.ndimage import gaussian_filter1d
+from torch.fft import fft2, fftshift, ifft2
+
+transforms_logger = logging.getLogger(__name__)
+
+
+def _taper_mask(ly=224, lx=224, sig=7.5):
+ """
+ Generate a taper mask.
+
+ Args:
+ ly (int): The height of the mask. Default is 224.
+ lx (int): The width of the mask. Default is 224.
+ sig (float): The sigma value for the tapering function. Default is 7.5.
+
+ Returns:
+ numpy.ndarray: The taper mask.
+
+ """
+ bsize = max(224, max(ly, lx))
+ xm = np.arange(bsize)
+ xm = np.abs(xm - xm.mean())
+ mask = 1 / (1 + np.exp((xm - (bsize / 2 - 20)) / sig))
+ mask = mask * mask[:, np.newaxis]
+ mask = mask[bsize // 2 - ly // 2:bsize // 2 + ly // 2 + ly % 2,
+ bsize // 2 - lx // 2:bsize // 2 + lx // 2 + lx % 2]
+ return mask
+
+
+def unaugment_tiles(y):
+ """Reverse test-time augmentations for averaging (includes flipping of flowsY and flowsX).
+
+ Args:
+ y (float32): Array of shape (ntiles_y, ntiles_x, chan, Ly, Lx) where chan = (flowsY, flowsX, cell prob).
+
+ Returns:
+ float32: Array of shape (ntiles_y, ntiles_x, chan, Ly, Lx).
+
+ """
+ for j in range(y.shape[0]):
+ for i in range(y.shape[1]):
+ if j % 2 == 0 and i % 2 == 1:
+ y[j, i] = y[j, i, :, ::-1, :]
+ y[j, i, 0] *= -1
+ elif j % 2 == 1 and i % 2 == 0:
+ y[j, i] = y[j, i, :, :, ::-1]
+ y[j, i, 1] *= -1
+ elif j % 2 == 1 and i % 2 == 1:
+ y[j, i] = y[j, i, :, ::-1, ::-1]
+ y[j, i, 0] *= -1
+ y[j, i, 1] *= -1
+ return y
+
+
+def average_tiles(y, ysub, xsub, Ly, Lx):
+ """
+ Average the results of the network over tiles.
+
+ Args:
+ y (float): Output of cellpose network for each tile. Shape: [ntiles x nclasses x bsize x bsize]
+ ysub (list): List of arrays with start and end of tiles in Y of length ntiles
+ xsub (list): List of arrays with start and end of tiles in X of length ntiles
+ Ly (int): Size of pre-tiled image in Y (may be larger than original image if image size is less than bsize)
+ Lx (int): Size of pre-tiled image in X (may be larger than original image if image size is less than bsize)
+
+ Returns:
+ yf (float32): Network output averaged over tiles. Shape: [nclasses x Ly x Lx]
+ """
+ Navg = np.zeros((Ly, Lx))
+ yf = np.zeros((y.shape[1], Ly, Lx), np.float32)
+ # taper edges of tiles
+ mask = _taper_mask(ly=y.shape[-2], lx=y.shape[-1])
+ for j in range(len(ysub)):
+ yf[:, ysub[j][0]:ysub[j][1], xsub[j][0]:xsub[j][1]] += y[j] * mask
+ Navg[ysub[j][0]:ysub[j][1], xsub[j][0]:xsub[j][1]] += mask
+ yf /= Navg
+ return yf
+
+
+def make_tiles(imgi, bsize=224, augment=False, tile_overlap=0.1):
+ """Make tiles of image to run at test-time.
+
+ Args:
+ imgi (np.ndarray): Array of shape (nchan, Ly, Lx) representing the input image.
+ bsize (int, optional): Size of tiles. Defaults to 224.
+        augment (bool, optional): Whether to flip tiles and use half-tile overlaps. Defaults to False.
+ tile_overlap (float, optional): Fraction of overlap of tiles. Defaults to 0.1.
+
+ Returns:
+ A tuple containing (IMG, ysub, xsub, Ly, Lx):
+ IMG (np.ndarray): Array of shape (ntiles, nchan, bsize, bsize) representing the tiles.
+ ysub (list): List of arrays with start and end of tiles in Y of length ntiles.
+ xsub (list): List of arrays with start and end of tiles in X of length ntiles.
+ Ly (int): Height of the input image.
+ Lx (int): Width of the input image.
+ """
+ nchan, Ly, Lx = imgi.shape
+ if augment:
+ bsize = np.int32(bsize)
+ # pad if image smaller than bsize
+ if Ly < bsize:
+ imgi = np.concatenate((imgi, np.zeros((nchan, bsize - Ly, Lx))), axis=1)
+ Ly = bsize
+ if Lx < bsize:
+ imgi = np.concatenate((imgi, np.zeros((nchan, Ly, bsize - Lx))), axis=2)
+ Ly, Lx = imgi.shape[-2:]
+
+ # tiles overlap by half of tile size
+ ny = max(2, int(np.ceil(2. * Ly / bsize)))
+ nx = max(2, int(np.ceil(2. * Lx / bsize)))
+ ystart = np.linspace(0, Ly - bsize, ny).astype(int)
+ xstart = np.linspace(0, Lx - bsize, nx).astype(int)
+
+ ysub = []
+ xsub = []
+
+ # flip tiles so that overlapping segments are processed in rotation
+ IMG = np.zeros((len(ystart), len(xstart), nchan, bsize, bsize), np.float32)
+ for j in range(len(ystart)):
+ for i in range(len(xstart)):
+ ysub.append([ystart[j], ystart[j] + bsize])
+ xsub.append([xstart[i], xstart[i] + bsize])
+ IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]]
+ # flip tiles to allow for augmentation of overlapping segments
+ if j % 2 == 0 and i % 2 == 1:
+ IMG[j, i] = IMG[j, i, :, ::-1, :]
+ elif j % 2 == 1 and i % 2 == 0:
+ IMG[j, i] = IMG[j, i, :, :, ::-1]
+ elif j % 2 == 1 and i % 2 == 1:
+ IMG[j, i] = IMG[j, i, :, ::-1, ::-1]
+ else:
+ tile_overlap = min(0.5, max(0.05, tile_overlap))
+ bsizeY, bsizeX = min(bsize, Ly), min(bsize, Lx)
+ bsizeY = np.int32(bsizeY)
+ bsizeX = np.int32(bsizeX)
+ # tiles overlap by 10% tile size
+ ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
+ nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))
+ ystart = np.linspace(0, Ly - bsizeY, ny).astype(int)
+ xstart = np.linspace(0, Lx - bsizeX, nx).astype(int)
+
+ ysub = []
+ xsub = []
+ IMG = np.zeros((len(ystart), len(xstart), nchan, bsizeY, bsizeX), np.float32)
+ for j in range(len(ystart)):
+ for i in range(len(xstart)):
+ ysub.append([ystart[j], ystart[j] + bsizeY])
+ xsub.append([xstart[i], xstart[i] + bsizeX])
+ IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]]
+
+ return IMG, ysub, xsub, Ly, Lx
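+
+# Usage sketch (illustrative only; `run_net` is a placeholder, not part of cellpose):
+# tile an image, run a network on each tile, then stitch the outputs with average_tiles.
+#
+#     IMG, ysub, xsub, Ly, Lx = make_tiles(img, bsize=224, tile_overlap=0.1)
+#     ny, nx, nchan, ly, lx = IMG.shape
+#     tiles = IMG.reshape(ny * nx, nchan, ly, lx)
+#     y = np.stack([run_net(t) for t in tiles])   # [ntiles x nclasses x ly x lx]
+#     yf = average_tiles(y, ysub, xsub, Ly, Lx)   # [nclasses x Ly x Lx]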
+
+
+def normalize99(Y, lower=1, upper=99, copy=True, downsample=False):
+ """
+ Normalize the image so that 0.0 corresponds to the 1st percentile and 1.0 corresponds to the 99th percentile.
+
+ Args:
+ Y (ndarray): The input image (for downsample, use [Ly x Lx] or [Lz x Ly x Lx]).
+ lower (int, optional): The lower percentile. Defaults to 1.
+ upper (int, optional): The upper percentile. Defaults to 99.
+ copy (bool, optional): Whether to create a copy of the input image. Defaults to True.
+ downsample (bool, optional): Whether to downsample image to compute percentiles. Defaults to False.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ X = Y.copy() if copy else Y
+ X = X.astype("float32") if X.dtype!="float64" and X.dtype!="float32" else X
+ if downsample and X.size > 224**3:
+ nskip = [max(1, X.shape[i] // 224) for i in range(X.ndim)]
+ nskip[0] = max(1, X.shape[0] // 50) if X.ndim == 3 else nskip[0]
+ slc = tuple([slice(0, X.shape[i], nskip[i]) for i in range(X.ndim)])
+ x01 = np.percentile(X[slc], lower)
+ x99 = np.percentile(X[slc], upper)
+ else:
+ x01 = np.percentile(X, lower)
+ x99 = np.percentile(X, upper)
+ if x99 - x01 > 1e-3:
+ X -= x01
+ X /= (x99 - x01)
+ else:
+ X[:] = 0
+ return X
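+
+# Usage sketch (illustrative only): map the 1st/99th percentiles of an image to 0 and 1.
+#
+#     img = 1000. * np.random.rand(512, 512).astype("float32")
+#     img_norm = normalize99(img)   # ~0 at the 1st percentile, ~1 at the 99th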
+
+
+def normalize99_tile(img, blocksize=100, lower=1., upper=99., tile_overlap=0.1,
+ norm3D=False, smooth3D=1, is3D=False):
+ """Compute normalization like normalize99 function but in tiles.
+
+ Args:
+ img (numpy.ndarray): Array of shape (Lz x) Ly x Lx (x nchan) containing the image.
+ blocksize (float, optional): Size of tiles. Defaults to 100.
+ lower (float, optional): Lower percentile for normalization. Defaults to 1.0.
+ upper (float, optional): Upper percentile for normalization. Defaults to 99.0.
+ tile_overlap (float, optional): Fraction of overlap of tiles. Defaults to 0.1.
+ norm3D (bool, optional): Use same tiled normalization for each z-plane. Defaults to False.
+ smooth3D (int, optional): Smoothing factor for 3D normalization. Defaults to 1.
+ is3D (bool, optional): Set to True if image is a 3D stack. Defaults to False.
+
+ Returns:
+ numpy.ndarray: Normalized image array of shape (Lz x) Ly x Lx (x nchan).
+ """
+ is1c = True if img.ndim == 2 or (is3D and img.ndim == 3) else False
+ is3D = True if img.ndim > 3 or (is3D and img.ndim == 3) else False
+ img = img[..., np.newaxis] if is1c else img
+ img = img[np.newaxis, ...] if img.ndim == 3 else img
+ Lz, Ly, Lx, nchan = img.shape
+
+ tile_overlap = min(0.5, max(0.05, tile_overlap))
+ blocksizeY, blocksizeX = min(blocksize, Ly), min(blocksize, Lx)
+ blocksizeY = np.int32(blocksizeY)
+ blocksizeX = np.int32(blocksizeX)
+ # tiles overlap by 10% tile size
+ ny = 1 if Ly <= blocksize else int(np.ceil(
+ (1. + 2 * tile_overlap) * Ly / blocksize))
+ nx = 1 if Lx <= blocksize else int(np.ceil(
+ (1. + 2 * tile_overlap) * Lx / blocksize))
+ ystart = np.linspace(0, Ly - blocksizeY, ny).astype(int)
+ xstart = np.linspace(0, Lx - blocksizeX, nx).astype(int)
+ ysub = []
+ xsub = []
+ for j in range(len(ystart)):
+ for i in range(len(xstart)):
+ ysub.append([ystart[j], ystart[j] + blocksizeY])
+ xsub.append([xstart[i], xstart[i] + blocksizeX])
+
+ x01_tiles_z = []
+ x99_tiles_z = []
+ for z in range(Lz):
+ IMG = np.zeros((len(ystart), len(xstart), blocksizeY, blocksizeX, nchan),
+ "float32")
+ k = 0
+ for j in range(len(ystart)):
+ for i in range(len(xstart)):
+ IMG[j, i] = img[z, ysub[k][0]:ysub[k][1], xsub[k][0]:xsub[k][1], :]
+ k += 1
+ x01_tiles = np.percentile(IMG, lower, axis=(-3, -2))
+ x99_tiles = np.percentile(IMG, upper, axis=(-3, -2))
+
+ # fill areas with small differences with neighboring squares
+ to_fill = np.zeros(x01_tiles.shape[:2], "bool")
+ for c in range(nchan):
+ to_fill = x99_tiles[:, :, c] - x01_tiles[:, :, c] < +1e-3
+ if to_fill.sum() > 0 and to_fill.sum() < x99_tiles[:, :, c].size:
+ fill_vals = np.nonzero(to_fill)
+ fill_neigh = np.nonzero(~to_fill)
+ nearest_neigh = (
+ (fill_vals[0] - fill_neigh[0][:, np.newaxis])**2 +
+ (fill_vals[1] - fill_neigh[1][:, np.newaxis])**2).argmin(axis=0)
+ x01_tiles[fill_vals[0], fill_vals[1],
+ c] = x01_tiles[fill_neigh[0][nearest_neigh],
+ fill_neigh[1][nearest_neigh], c]
+ x99_tiles[fill_vals[0], fill_vals[1],
+ c] = x99_tiles[fill_neigh[0][nearest_neigh],
+ fill_neigh[1][nearest_neigh], c]
+ elif to_fill.sum() > 0 and to_fill.sum() == x99_tiles[:, :, c].size:
+ x01_tiles[:, :, c] = 0
+ x99_tiles[:, :, c] = 1
+ x01_tiles_z.append(x01_tiles)
+ x99_tiles_z.append(x99_tiles)
+
+ x01_tiles_z = np.array(x01_tiles_z)
+ x99_tiles_z = np.array(x99_tiles_z)
+ # do not smooth over z-axis if not normalizing separately per plane
+ for a in range(2):
+ x01_tiles_z = gaussian_filter1d(x01_tiles_z, 1, axis=a)
+ x99_tiles_z = gaussian_filter1d(x99_tiles_z, 1, axis=a)
+ if norm3D:
+ smooth3D = 1 if smooth3D == 0 else smooth3D
+ x01_tiles_z = gaussian_filter1d(x01_tiles_z, smooth3D, axis=a)
+ x99_tiles_z = gaussian_filter1d(x99_tiles_z, smooth3D, axis=a)
+
+ if not norm3D and Lz > 1:
+ x01 = np.zeros((len(x01_tiles_z), Ly, Lx, nchan), "float32")
+ x99 = np.zeros((len(x01_tiles_z), Ly, Lx, nchan), "float32")
+ for z in range(Lz):
+ x01_rsz = cv2.resize(x01_tiles_z[z], (Lx, Ly),
+ interpolation=cv2.INTER_LINEAR)
+ x01[z] = x01_rsz[..., np.newaxis] if nchan == 1 else x01_rsz
+ x99_rsz = cv2.resize(x99_tiles_z[z], (Lx, Ly),
+ interpolation=cv2.INTER_LINEAR)
+            x99[z] = x99_rsz[..., np.newaxis] if nchan == 1 else x99_rsz
+ if (x99 - x01).min() < 1e-3:
+ raise ZeroDivisionError(
+ "cannot use norm3D=False with tile_norm, sample is too sparse; set norm3D=True or tile_norm=0"
+ )
+ else:
+ x01 = cv2.resize(x01_tiles_z.mean(axis=0), (Lx, Ly),
+ interpolation=cv2.INTER_LINEAR)
+ x99 = cv2.resize(x99_tiles_z.mean(axis=0), (Lx, Ly),
+ interpolation=cv2.INTER_LINEAR)
+ if x01.ndim < 3:
+ x01 = x01[..., np.newaxis]
+ x99 = x99[..., np.newaxis]
+
+ if is1c:
+ img, x01, x99 = img.squeeze(), x01.squeeze(), x99.squeeze()
+ elif not is3D:
+ img, x01, x99 = img[0], x01[0], x99[0]
+
+ # normalize
+ img -= x01
+ img /= (x99 - x01)
+
+ return img
+
+
+def gaussian_kernel(sigma, Ly, Lx, device=torch.device("cpu")):
+ """
+ Generates a 2D Gaussian kernel.
+
+ Args:
+ sigma (float): Standard deviation of the Gaussian distribution.
+ Ly (int): Number of pixels in the y-axis.
+ Lx (int): Number of pixels in the x-axis.
+ device (torch.device, optional): Device to store the kernel tensor. Defaults to torch.device("cpu").
+
+ Returns:
+ torch.Tensor: 2D Gaussian kernel tensor.
+
+ """
+ y = torch.linspace(-Ly / 2, Ly / 2 + 1, Ly, device=device)
+    x = torch.linspace(-Lx / 2, Lx / 2 + 1, Lx, device=device)
+ y, x = torch.meshgrid(y, x, indexing="ij")
+ kernel = torch.exp(-(y**2 + x**2) / (2 * sigma**2))
+ kernel /= kernel.sum()
+ return kernel
+
+
+def smooth_sharpen_img(img, smooth_radius=6, sharpen_radius=12,
+ device=torch.device("cpu"), is3D=False):
+ """Sharpen blurry images with surround subtraction and/or smooth noisy images.
+
+ Args:
+ img (float32): Array that's (Lz x) Ly x Lx (x nchan).
+ smooth_radius (float, optional): Size of gaussian smoothing filter, recommended to be 1/10-1/4 of cell diameter
+ (if also sharpening, should be 2-3x smaller than sharpen_radius). Defaults to 6.
+ sharpen_radius (float, optional): Size of gaussian surround filter, recommended to be 1/8-1/2 of cell diameter
+ (if also smoothing, should be 2-3x larger than smooth_radius). Defaults to 12.
+ device (torch.device, optional): Device on which to perform sharpening.
+ Will be faster on GPU but need to ensure GPU has RAM for image. Defaults to torch.device("cpu").
+ is3D (bool, optional): If image is 3D stack (only necessary to set if img.ndim==3). Defaults to False.
+
+ Returns:
+ img_sharpen (float32): Array that's (Lz x) Ly x Lx (x nchan).
+ """
+ img_sharpen = torch.from_numpy(img.astype("float32")).to(device)
+ shape = img_sharpen.shape
+
+ is1c = True if img_sharpen.ndim == 2 or (is3D and img_sharpen.ndim == 3) else False
+ is3D = True if img_sharpen.ndim > 3 or (is3D and img_sharpen.ndim == 3) else False
+ img_sharpen = img_sharpen.unsqueeze(-1) if is1c else img_sharpen
+ img_sharpen = img_sharpen.unsqueeze(0) if img_sharpen.ndim == 3 else img_sharpen
+ Lz, Ly, Lx, nchan = img_sharpen.shape
+
+ if smooth_radius > 0:
+ kernel = gaussian_kernel(smooth_radius, Ly, Lx, device=device)
+ if sharpen_radius > 0:
+ kernel += -1 * gaussian_kernel(sharpen_radius, Ly, Lx, device=device)
+ elif sharpen_radius > 0:
+ kernel = -1 * gaussian_kernel(sharpen_radius, Ly, Lx, device=device)
+ kernel[Ly // 2, Lx // 2] = 1
+
+ fhp = fft2(kernel)
+ for z in range(Lz):
+ for c in range(nchan):
+ img_filt = torch.real(ifft2(
+ fft2(img_sharpen[z, :, :, c]) * torch.conj(fhp)))
+ img_filt = fftshift(img_filt)
+ img_sharpen[z, :, :, c] = img_filt
+
+ img_sharpen = img_sharpen.reshape(shape)
+ return img_sharpen.cpu().numpy()
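+
+# Usage sketch (illustrative only): for cells roughly 30 px in diameter, a smooth radius of
+# ~4 px and a sharpen radius of ~10 px follow the recommendations in the docstring above.
+#
+#     img_filtered = smooth_sharpen_img(img, smooth_radius=4, sharpen_radius=10)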
+
+
+def move_axis(img, m_axis=-1, first=True):
+ """ move axis m_axis to first or last position """
+ if m_axis == -1:
+ m_axis = img.ndim - 1
+ m_axis = min(img.ndim - 1, m_axis)
+ axes = np.arange(0, img.ndim)
+ if first:
+ axes[1:m_axis + 1] = axes[:m_axis]
+ axes[0] = m_axis
+ else:
+ axes[m_axis:-1] = axes[m_axis + 1:]
+ axes[-1] = m_axis
+ img = img.transpose(tuple(axes))
+ return img
+
+
+def move_min_dim(img, force=False):
+ """Move the minimum dimension last as channels if it is less than 10 or force is True.
+
+ Args:
+ img (ndarray): The input image.
+ force (bool, optional): If True, the minimum dimension will always be moved.
+ Defaults to False.
+
+ Returns:
+ ndarray: The image with the minimum dimension moved to the last axis as channels.
+ """
+ if len(img.shape) > 2:
+ min_dim = min(img.shape)
+ if min_dim < 10 or force:
+ if img.shape[-1] == min_dim:
+ channel_axis = -1
+ else:
+ channel_axis = (img.shape).index(min_dim)
+ img = move_axis(img, m_axis=channel_axis, first=False)
+ return img
+
+
+def update_axis(m_axis, to_squeeze, ndim):
+ """
+ Squeeze the axis value based on the given parameters.
+
+ Args:
+ m_axis (int): The current axis value.
+ to_squeeze (numpy.ndarray): An array of indices to squeeze.
+ ndim (int): The number of dimensions.
+
+ Returns:
+ int or None: The updated axis value.
+ """
+ if m_axis == -1:
+ m_axis = ndim - 1
+ if (to_squeeze == m_axis).sum() == 1:
+ m_axis = None
+ else:
+ inds = np.ones(ndim, bool)
+ inds[to_squeeze] = False
+ m_axis = np.nonzero(np.arange(0, ndim)[inds] == m_axis)[0]
+ if len(m_axis) > 0:
+ m_axis = m_axis[0]
+ else:
+ m_axis = None
+ return m_axis
+
+
+def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, nchan=2):
+ """Converts the image to have the z-axis first, channels last.
+
+ Args:
+ x (numpy.ndarray or torch.Tensor): The input image.
+ channels (list or None): The list of channels to use (ones-based, 0=gray). If None, all channels are kept.
+ channel_axis (int or None): The axis of the channels in the input image. If None, the axis is determined automatically.
+ z_axis (int or None): The axis of the z-dimension in the input image. If None, the axis is determined automatically.
+ do_3D (bool): Whether to process the image in 3D mode. Defaults to False.
+ nchan (int): The number of channels to keep if the input image has more than nchan channels.
+
+ Returns:
+ numpy.ndarray: The converted image.
+
+ Raises:
+ ValueError: If the input image has less than two channels and channels are not specified.
+ ValueError: If the input image is 2D and do_3D is True.
+ ValueError: If the input image is 4D and do_3D is False.
+ """
+ # check if image is a torch array instead of numpy array
+ # converts torch to numpy
+ ndim = x.ndim
+ if torch.is_tensor(x):
+ transforms_logger.warning("torch array used as input, converting to numpy")
+ x = x.cpu().numpy()
+
+ # squeeze image, and if channel_axis or z_axis given, transpose image
+ if x.ndim > 3:
+ to_squeeze = np.array([int(isq) for isq, s in enumerate(x.shape) if s == 1])
+ # remove channel axis if number of channels is 1
+ if len(to_squeeze) > 0:
+ channel_axis = update_axis(
+ channel_axis, to_squeeze,
+ x.ndim) if channel_axis is not None else None
+ z_axis = update_axis(z_axis, to_squeeze,
+ x.ndim) if z_axis is not None else None
+ x = x.squeeze()
+
+ # put z axis first
+ if z_axis is not None and x.ndim > 2 and z_axis != 0:
+ x = move_axis(x, m_axis=z_axis, first=True)
+ if channel_axis is not None:
+ channel_axis += 1
+ z_axis = 0
+ elif z_axis is None and x.ndim > 2 and channels is not None and min(x.shape) > 5 :
+ # if there are > 5 channels and channels!=None, assume first dimension is z
+ min_dim = min(x.shape)
+ if min_dim != channel_axis:
+ z_axis = (x.shape).index(min_dim)
+ if z_axis != 0:
+ x = move_axis(x, m_axis=z_axis, first=True)
+ if channel_axis is not None:
+ channel_axis += 1
+ transforms_logger.warning(f"z_axis not specified, assuming it is dim {z_axis}")
+ transforms_logger.warning(f"if this is actually the channel_axis, use 'model.eval(channel_axis={z_axis}, ...)'")
+ z_axis = 0
+
+ if z_axis is not None:
+ if x.ndim == 3:
+ x = x[..., np.newaxis]
+
+ # put channel axis last
+ if channel_axis is not None and x.ndim > 2:
+ x = move_axis(x, m_axis=channel_axis, first=False)
+ elif x.ndim == 2:
+ x = x[:, :, np.newaxis]
+
+ if do_3D:
+ if ndim < 3:
+ transforms_logger.critical("ERROR: cannot process 2D images in 3D mode")
+ raise ValueError("ERROR: cannot process 2D images in 3D mode")
+ elif x.ndim < 4:
+ x = x[..., np.newaxis]
+
+ if channel_axis is None:
+ x = move_min_dim(x)
+
+ if x.ndim > 3:
+ transforms_logger.info(
+ "multi-stack tiff read in as having %d planes %d channels" %
+ (x.shape[0], x.shape[-1]))
+
+ # convert to float32
+ x = x.astype("float32")
+
+ if channels is not None:
+ channels = channels[0] if len(channels) == 1 else channels
+ if len(channels) < 2:
+ transforms_logger.critical("ERROR: two channels not specified")
+ raise ValueError("ERROR: two channels not specified")
+ x = reshape(x, channels=channels)
+
+ else:
+ # code above put channels last
+ if nchan is not None and x.shape[-1] > nchan:
+ transforms_logger.warning(
+ "WARNING: more than %d channels given, use 'channels' input for specifying channels - just using first %d channels to run processing"
+ % (nchan, nchan))
+ x = x[..., :nchan]
+
+ # if not do_3D and x.ndim > 3:
+ # transforms_logger.critical("ERROR: cannot process 4D images in 2D mode")
+ # raise ValueError("ERROR: cannot process 4D images in 2D mode")
+
+ if nchan is not None and x.shape[-1] < nchan:
+ x = np.concatenate((x, np.tile(np.zeros_like(x), (1, 1, nchan - 1))),
+ axis=-1)
+
+ return x
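+
+# Usage sketch (illustrative only; `rgb_img` is a placeholder [Ly x Lx x 3] array): an RGB
+# image with cells in green and nuclei in blue becomes a float32 [Ly x Lx x 2] array
+# (segmentation channel first, nuclear channel second).
+#
+#     x = convert_image(rgb_img, channels=[2, 3])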
+
+
+def reshape(data, channels=[0, 0], chan_first=False):
+ """Reshape data using channels.
+
+ Args:
+ data (numpy.ndarray): The input data. It should have shape (Z x ) Ly x Lx x nchan
+ if data.ndim==3 and data.shape[0]<8, it is assumed to be nchan x Ly x Lx.
+ channels (list of int, optional): The channels to use for reshaping. The first element
+ of the list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). The
+ second element of the list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
+ For instance, to train on grayscale images, input [0,0]. To train on images with cells
+ in green and nuclei in blue, input [2,3]. Defaults to [0, 0].
+ chan_first (bool, optional): Whether to return the reshaped data with channel as the first
+ dimension. Defaults to False.
+
+ Returns:
+ numpy.ndarray: The reshaped data with shape (Z x ) Ly x Lx x nchan (if chan_first==False).
+ """
+ if data.ndim < 3:
+ data = data[:, :, np.newaxis]
+ elif data.shape[0] < 8 and data.ndim == 3:
+ data = np.transpose(data, (1, 2, 0))
+
+ # use grayscale image
+ if data.shape[-1] == 1:
+ data = np.concatenate((data, np.zeros(data.shape, "float32")), axis=-1)
+ else:
+ if channels[0] == 0:
+ data = data.mean(axis=-1, keepdims=True)
+ data = np.concatenate((data, np.zeros(data.shape, "float32")), axis=-1)
+ else:
+ chanid = [channels[0] - 1]
+ if channels[1] > 0:
+ chanid.append(channels[1] - 1)
+ data = data[..., chanid]
+ for i in range(data.shape[-1]):
+ if np.ptp(data[..., i]) == 0.0:
+ if i == 0:
+                        warnings.warn("'chan to seg' has value range of ZERO")
+ else:
+ warnings.warn(
+ "'chan2 (opt)' has value range of ZERO, can instead set chan2 to 0"
+ )
+ if data.shape[-1] == 1:
+ data = np.concatenate((data, np.zeros(data.shape, "float32")), axis=-1)
+ if chan_first:
+ if data.ndim == 4:
+ data = np.transpose(data, (3, 0, 1, 2))
+ else:
+ data = np.transpose(data, (2, 0, 1))
+ return data
+
+
+def normalize_img(img, normalize=True, norm3D=True, invert=False, lowhigh=None,
+ percentile=(1., 99.), sharpen_radius=0, smooth_radius=0,
+ tile_norm_blocksize=0, tile_norm_smooth3D=1, axis=-1):
+ """Normalize each channel of the image with optional inversion, smoothing, and sharpening.
+
+ Args:
+ img (ndarray): The input image. It should have at least 3 dimensions.
+ If it is 4-dimensional, it assumes the first non-channel axis is the Z dimension.
+ normalize (bool, optional): Whether to perform normalization. Defaults to True.
+ norm3D (bool, optional): Whether to normalize in 3D. If True, the entire 3D stack will
+            be normalized per channel. If False, normalization is applied per Z-slice. Defaults to True.
+ invert (bool, optional): Whether to invert the image. Useful if cells are dark instead of bright.
+ Defaults to False.
+ lowhigh (tuple or ndarray, optional): The lower and upper bounds for normalization.
+ Can be a tuple of two values (applied to all channels) or an array of shape (nchan, 2)
+ for per-channel normalization. Incompatible with smoothing and sharpening.
+ Defaults to None.
+ percentile (tuple, optional): The lower and upper percentiles for normalization. If provided, it should be
+ a tuple of two values. Each value should be between 0 and 100. Defaults to (1.0, 99.0).
+ sharpen_radius (int, optional): The radius for sharpening the image. Defaults to 0.
+ smooth_radius (int, optional): The radius for smoothing the image. Defaults to 0.
+ tile_norm_blocksize (int, optional): The block size for tile-based normalization. Defaults to 0.
+ tile_norm_smooth3D (int, optional): The smoothness factor for tile-based normalization in 3D. Defaults to 1.
+ axis (int, optional): The channel axis to loop over for normalization. Defaults to -1.
+
+ Returns:
+ ndarray: The normalized image of the same size.
+
+ Raises:
+ ValueError: If the image has less than 3 dimensions.
+ ValueError: If the provided lowhigh or percentile values are invalid.
+ ValueError: If the image is inverted without normalization.
+
+ """
+ if img.ndim < 3:
+ error_message = "Image needs to have at least 3 dimensions"
+ transforms_logger.critical(error_message)
+ raise ValueError(error_message)
+
+ img_norm = img if img.dtype=="float32" else img.astype(np.float32)
+ if axis != -1 and axis != img_norm.ndim - 1:
+ img_norm = np.moveaxis(img_norm, axis, -1) # Move channel axis to last
+
+ nchan = img_norm.shape[-1]
+
+ # Validate and handle lowhigh bounds
+ if lowhigh is not None:
+ lowhigh = np.array(lowhigh)
+ if lowhigh.shape == (2,):
+ lowhigh = np.tile(lowhigh, (nchan, 1)) # Expand to per-channel bounds
+ elif lowhigh.shape != (nchan, 2):
+ error_message = "`lowhigh` must have shape (2,) or (nchan, 2)"
+ transforms_logger.critical(error_message)
+ raise ValueError(error_message)
+
+ # Validate percentile
+ if percentile is None:
+ percentile = (1.0, 99.0)
+ elif not (0 <= percentile[0] < percentile[1] <= 100):
+ error_message = "Invalid percentile range, should be between 0 and 100"
+ transforms_logger.critical(error_message)
+ raise ValueError(error_message)
+
+ # Apply normalization based on lowhigh or percentile
+ cgood = np.zeros(nchan, "bool")
+ if lowhigh is not None:
+ for c in range(nchan):
+ lower = lowhigh[c, 0]
+ upper = lowhigh[c, 1]
+ img_norm[..., c] -= lower
+ img_norm[..., c] /= (upper - lower)
+ cgood[c] = True
+ else:
+ # Apply sharpening and smoothing if specified
+ if sharpen_radius > 0 or smooth_radius > 0:
+ img_norm = smooth_sharpen_img(
+ img_norm, sharpen_radius=sharpen_radius, smooth_radius=smooth_radius
+ )
+
+ # Apply tile-based normalization or standard normalization
+ if tile_norm_blocksize > 0:
+ img_norm = normalize99_tile(
+ img_norm,
+ blocksize=tile_norm_blocksize,
+ lower=percentile[0],
+ upper=percentile[1],
+ smooth3D=tile_norm_smooth3D,
+ norm3D=norm3D,
+ )
+ elif normalize:
+ if img_norm.ndim == 3 or norm3D: # i.e. if YXC, or ZYXC with norm3D=True
+ for c in range(nchan):
+ if np.ptp(img_norm[..., c]) > 0.:
+ img_norm[..., c] = normalize99(
+ img_norm[..., c],
+ lower=percentile[0],
+ upper=percentile[1],
+ copy=False, downsample=True,
+ )
+ cgood[c] = True
+ else: # i.e. if ZYXC with norm3D=False then per Z-slice
+ for z in range(img_norm.shape[0]):
+ for c in range(nchan):
+ if np.ptp(img_norm[z, ..., c]) > 0.:
+ img_norm[z, ..., c] = normalize99(
+ img_norm[z, ..., c],
+ lower=percentile[0],
+ upper=percentile[1],
+ copy=False, downsample=True,
+ )
+ cgood[c] = True
+
+
+ if invert:
+ if lowhigh is not None or tile_norm_blocksize > 0 or normalize:
+ for c in range(nchan):
+ if cgood[c]:
+ img_norm[..., c] = 1 - img_norm[..., c]
+ else:
+ error_message = "Cannot invert image without normalization"
+ transforms_logger.critical(error_message)
+ raise ValueError(error_message)
+
+ # Move channel axis back to the original position
+ if axis != -1 and axis != img_norm.ndim - 1:
+ img_norm = np.moveaxis(img_norm, -1, axis)
+
+ return img_norm
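+
+# Usage sketch (illustrative only): per-channel percentile normalization of a
+# [Ly x Lx x nchan] image; invert=True can be used when cells are dark on a bright background.
+#
+#     img_norm = normalize_img(img, percentile=(1., 99.), invert=False)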
+
+def resize_safe(img, Ly, Lx, interpolation=cv2.INTER_LINEAR):
+ """OpenCV resize function does not support uint32.
+
+    This function casts the image to float32 before resizing and back to uint32 afterwards; label values above 2**24 may lose precision in the round trip.
+ References issue: https://github.com/MouseLand/cellpose/issues/937
+
+ Implications:
+        * Runtime: type casting increases runtime by roughly 5x-50x, but resizing itself is so fast that
+          this is rarely an issue; a 10,000x10,000 image takes about 0.47s instead of 0.016s to cast and resize on a 32-core machine.
+        * Memory: memory usage increases as well; the exact overhead has not been measured.
+
+ Args:
+ img (ndarray): Image of size [Ly x Lx].
+ Ly (int): Desired height of the resized image.
+ Lx (int): Desired width of the resized image.
+ interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR.
+
+ Returns:
+ ndarray: Resized image of size [Ly x Lx].
+
+ """
+
+ # cast image
+ cast = img.dtype == np.uint32
+ if cast:
+ img = img.astype(np.float32)
+
+ # resize
+ img = cv2.resize(img, (Lx, Ly), interpolation=interpolation)
+
+ # cast back
+ if cast:
+ img = img.round().astype(np.uint32)
+
+ return img
+
+
+def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEAR,
+ no_channels=False):
+ """Resize image for computing flows / unresize for computing dynamics.
+
+ Args:
+ img0 (ndarray): Image of size [Y x X x nchan] or [Lz x Y x X x nchan] or [Lz x Y x X].
+ Ly (int, optional): Desired height of the resized image. Defaults to None.
+ Lx (int, optional): Desired width of the resized image. Defaults to None.
+ rsz (float, optional): Resize coefficient(s) for the image. If Ly is None, rsz is used. Defaults to None.
+ interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR.
+ no_channels (bool, optional): Flag indicating whether to treat the third dimension as a channel.
+ Defaults to False.
+
+ Returns:
+ ndarray: Resized image of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan].
+
+ Raises:
+ ValueError: If Ly is None and rsz is None.
+
+ """
+ if Ly is None and rsz is None:
+ error_message = "must give size to resize to or factor to use for resizing"
+ transforms_logger.critical(error_message)
+ raise ValueError(error_message)
+
+ if Ly is None:
+ # determine Ly and Lx using rsz
+ if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
+ rsz = [rsz, rsz]
+ if no_channels:
+ Ly = int(img0.shape[-2] * rsz[-2])
+ Lx = int(img0.shape[-1] * rsz[-1])
+ else:
+ Ly = int(img0.shape[-3] * rsz[-2])
+ Lx = int(img0.shape[-2] * rsz[-1])
+
+ # no_channels useful for z-stacks, so the third dimension is not treated as a channel
+    # but grayscale images first become [Ly,Lx,2], so ndim==3 with a trailing channel axis,
+    # which is why the channel case below requires ndim == 4
+ if (img0.ndim > 2 and no_channels) or (img0.ndim == 4 and not no_channels):
+ if Ly == 0 or Lx == 0:
+ raise ValueError(
+ "anisotropy too high / low -- not enough pixels to resize to ratio")
+ for i, img in enumerate(img0):
+ imgi = resize_safe(img, Ly, Lx, interpolation=interpolation)
+ if i==0:
+ if no_channels:
+ imgs = np.zeros((img0.shape[0], Ly, Lx), imgi.dtype)
+ else:
+ imgs = np.zeros((img0.shape[0], Ly, Lx, img0.shape[-1]), imgi.dtype)
+ imgs[i] = imgi if imgi.ndim > 2 or no_channels else imgi[..., np.newaxis]
+ else:
+ imgs = resize_safe(img0, Ly, Lx, interpolation=interpolation)
+ return imgs
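+
+# Usage sketch (illustrative only): downsample a [Ly x Lx x nchan] image by 2x in each
+# dimension, then resize back to the original shape.
+#
+#     small = resize_image(img, rsz=0.5)
+#     back = resize_image(small, Ly=img.shape[0], Lx=img.shape[1])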
+
+def get_pad_yx(Ly, Lx, div=16, extra=1, min_size=None):
+ if min_size is None or Ly >= min_size[-2]:
+ Lpad = int(div * np.ceil(Ly / div) - Ly)
+ else:
+ Lpad = min_size[-2] - Ly
+ ypad1 = extra * div // 2 + Lpad // 2
+ ypad2 = extra * div // 2 + Lpad - Lpad // 2
+ if min_size is None or Lx >= min_size[-1]:
+ Lpad = int(div * np.ceil(Lx / div) - Lx)
+ else:
+ Lpad = min_size[-1] - Lx
+ xpad1 = extra * div // 2 + Lpad // 2
+ xpad2 = extra * div // 2 + Lpad - Lpad // 2
+
+ return ypad1, ypad2, xpad1, xpad2
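+
+# Worked example (illustrative only): with div=16 and extra=1, a 100 x 200 image gives
+# (ypad1, ypad2, xpad1, xpad2) = (14, 14, 12, 12), i.e. it is padded to 128 x 224,
+# both multiples of 16.
+#
+#     ypad1, ypad2, xpad1, xpad2 = get_pad_yx(100, 200)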
+
+
+def pad_image_ND(img0, div=16, extra=1, min_size=None, zpad=False):
+ """Pad image for test-time so that its dimensions are a multiple of 16 (2D or 3D).
+
+ Args:
+ img0 (ndarray): Image of size [nchan (x Lz) x Ly x Lx].
+ div (int, optional): Divisor for padding. Defaults to 16.
+ extra (int, optional): Extra padding. Defaults to 1.
+ min_size (tuple, optional): Minimum size of the image. Defaults to None.
+
+ Returns:
+ A tuple containing (I, ysub, xsub) or (I, ysub, xsub, zsub), I is padded image, -sub are ranges of pixels in the padded image corresponding to img0.
+
+ """
+ Ly, Lx = img0.shape[-2:]
+ ypad1, ypad2, xpad1, xpad2 = get_pad_yx(Ly, Lx, div=div, extra=extra, min_size=min_size)
+
+ if img0.ndim > 3:
+ if zpad:
+ Lpad = int(div * np.ceil(img0.shape[-3] / div) - img0.shape[-3])
+ zpad1 = extra * div // 2 + Lpad // 2
+ zpad2 = extra * div // 2 + Lpad - Lpad // 2
+ else:
+ zpad1, zpad2 = 0, 0
+ pads = np.array([[0, 0], [zpad1, zpad2], [ypad1, ypad2], [xpad1, xpad2]])
+ else:
+ pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])
+
+ I = np.pad(img0, pads, mode="constant")
+
+ ysub = np.arange(ypad1, ypad1 + Ly)
+ xsub = np.arange(xpad1, xpad1 + Lx)
+ if zpad:
+ zsub = np.arange(zpad1, zpad1 + img0.shape[-3])
+ return I, ysub, xsub, zsub
+ else:
+ return I, ysub, xsub
+
+
+def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=False,
+ zcrop=48, do_flip=True, rotate=True, rescale=None, unet=False,
+ random_per_image=True):
+ """Augmentation by random rotation and resizing.
+
+ Args:
+ X (list of ND-arrays, float): List of image arrays of size [nchan x Ly x Lx] or [Ly x Lx].
+ Y (list of ND-arrays, float, optional): List of image labels of size [nlabels x Ly x Lx] or [Ly x Lx].
+ The 1st channel of Y is always nearest-neighbor interpolated (assumed to be masks or 0-1 representation).
+ If Y.shape[0]==3 and not unet, then the labels are assumed to be [cell probability, Y flow, X flow].
+ If unet, second channel is dist_to_bound. Defaults to None.
+ scale_range (float, optional): Range of resizing of images for augmentation.
+ Images are resized by (1-scale_range/2) + scale_range * np.random.rand(). Defaults to 1.0.
+        xy (tuple, int, optional): Size of transformed images to return. Defaults to (224,224).
+        do_3D (bool, optional): Whether the images are 3D stacks to be cropped and augmented in Z. Defaults to False.
+        zcrop (int, optional): Size of the Z-crop used when do_3D is True. Defaults to 48.
+ do_flip (bool, optional): Whether or not to flip images horizontally. Defaults to True.
+ rotate (bool, optional): Whether or not to rotate images. Defaults to True.
+ rescale (array, float, optional): How much to resize images by before performing augmentations. Defaults to None.
+ unet (bool, optional): Whether or not to use unet. Defaults to False.
+ random_per_image (bool, optional): Different random rotate and resize per image. Defaults to True.
+
+ Returns:
+ A tuple containing (imgi, lbl, scale): imgi (ND-array, float): Transformed images in array [nimg x nchan x xy[0] x xy[1]];
+ lbl (ND-array, float): Transformed labels in array [nimg x nchan x xy[0] x xy[1]];
+ scale (array, float): Amount each image was resized by.
+ """
+ scale_range = max(0, min(2, float(scale_range)))
+ nimg = len(X)
+ if X[0].ndim > 2:
+ nchan = X[0].shape[0]
+ else:
+ nchan = 1
+ if do_3D and X[0].ndim > 3:
+ shape = (zcrop, xy[0], xy[1])
+ else:
+ shape = (xy[0], xy[1])
+ imgi = np.zeros((nimg, nchan, *shape), "float32")
+
+ lbl = []
+ if Y is not None:
+ if Y[0].ndim > 2:
+ nt = Y[0].shape[0]
+ else:
+ nt = 1
+ lbl = np.zeros((nimg, nt, *shape), np.float32)
+
+ scale = np.ones(nimg, np.float32)
+
+ for n in range(nimg):
+
+ if random_per_image or n == 0:
+ Ly, Lx = X[n].shape[-2:]
+ # generate random augmentation parameters
+ flip = np.random.rand() > .5
+ theta = np.random.rand() * np.pi * 2 if rotate else 0.
+ scale[n] = (1 - scale_range / 2) + scale_range * np.random.rand()
+ if rescale is not None:
+ scale[n] *= 1. / rescale[n]
+ dxy = np.maximum(0, np.array([Lx * scale[n] - xy[1],
+ Ly * scale[n] - xy[0]]))
+ dxy = (np.random.rand(2,) - .5) * dxy
+
+ # create affine transform
+ cc = np.array([Lx / 2, Ly / 2])
+ cc1 = cc - np.array([Lx - xy[1], Ly - xy[0]]) / 2 + dxy
+ pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
+ pts2 = np.float32([
+ cc1,
+ cc1 + scale[n] * np.array([np.cos(theta), np.sin(theta)]),
+ cc1 + scale[n] *
+ np.array([np.cos(np.pi / 2 + theta),
+ np.sin(np.pi / 2 + theta)])
+ ])
+ M = cv2.getAffineTransform(pts1, pts2)
+
+ img = X[n].copy()
+ if Y is not None:
+ labels = Y[n].copy()
+ if labels.ndim < 3:
+ labels = labels[np.newaxis, :, :]
+
+ if do_3D:
+ Lz = X[n].shape[-3]
+ flip_z = np.random.rand() > .5
+ lz = int(np.round(zcrop / scale[n]))
+ iz = np.random.randint(0, Lz - lz)
+ img = img[:,iz:iz + lz,:,:]
+ if Y is not None:
+ labels = labels[:,iz:iz + lz,:,:]
+
+ if do_flip:
+ if flip:
+ img = img[..., ::-1]
+ if Y is not None:
+ labels = labels[..., ::-1]
+ if nt > 1 and not unet:
+ labels[-1] = -labels[-1]
+ if do_3D and flip_z:
+ img = img[:, ::-1]
+ if Y is not None:
+ labels = labels[:,::-1]
+ if nt > 1 and not unet:
+ labels[-3] = -labels[-3]
+
+ for k in range(nchan):
+ if do_3D:
+ img0 = np.zeros((lz, xy[0], xy[1]), "float32")
+ for z in range(lz):
+ I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
+ flags=cv2.INTER_LINEAR)
+ img0[z] = I
+ if scale[n] != 1.0:
+ for y in range(imgi.shape[-2]):
+ imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
+ interpolation=cv2.INTER_LINEAR)
+ else:
+ imgi[n, k] = img0
+ else:
+ I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
+ imgi[n, k] = I
+
+ if Y is not None:
+ for k in range(nt):
+ flag = cv2.INTER_NEAREST if k == 0 else cv2.INTER_LINEAR
+ if do_3D:
+ lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
+ for z in range(lz):
+ I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
+ flags=flag)
+ lbl0[z] = I
+ if scale[n] != 1.0:
+ for y in range(lbl.shape[-2]):
+ lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
+ interpolation=flag)
+ else:
+ lbl[n, k] = lbl0
+ else:
+ lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
+
+ if nt > 1 and not unet:
+ v1 = lbl[n, -1].copy()
+ v2 = lbl[n, -2].copy()
+ lbl[n, -2] = (-v1 * np.sin(-theta) + v2 * np.cos(-theta))
+ lbl[n, -1] = (v1 * np.cos(-theta) + v2 * np.sin(-theta))
+
+ return imgi, lbl, scale
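+
+# Usage sketch (illustrative only; `imgs` and `flows` are placeholder lists of
+# [nchan x Ly x Lx] images and [3 x Ly x Lx] labels): augment a batch into 224 x 224 crops
+# with random rotation and scaling in [0.5, 1.5].
+#
+#     imgi, lbl, scale = random_rotate_and_resize(imgs, Y=flows, scale_range=1., xy=(224, 224))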
diff --git a/cellpose/utils.py b/cellpose/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac6549608ebd8fd7309fe17322b176bd8598e642
--- /dev/null
+++ b/cellpose/utils.py
@@ -0,0 +1,659 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+import logging
+import os, tempfile, shutil, io
+from tqdm import tqdm, trange
+from urllib.request import urlopen
+import cv2
+from scipy.ndimage import find_objects, gaussian_filter, generate_binary_structure, label, maximum_filter1d, binary_fill_holes
+from scipy.spatial import ConvexHull
+import numpy as np
+import colorsys
+import fastremap
+from multiprocessing import Pool, cpu_count
+
+from . import metrics
+
+try:
+ from skimage.morphology import remove_small_holes
+ SKIMAGE_ENABLED = True
+except ImportError:
+ SKIMAGE_ENABLED = False
+
+
+class TqdmToLogger(io.StringIO):
+ """
+ Output stream for TQDM which will output to logger module instead of
+ the StdOut.
+ """
+ logger = None
+ level = None
+ buf = ""
+
+ def __init__(self, logger, level=None):
+ super(TqdmToLogger, self).__init__()
+ self.logger = logger
+ self.level = level or logging.INFO
+
+ def write(self, buf):
+ self.buf = buf.strip("\r\n\t ")
+
+ def flush(self):
+ self.logger.log(self.level, self.buf)
+
+
+def rgb_to_hsv(arr):
+ rgb_to_hsv_channels = np.vectorize(colorsys.rgb_to_hsv)
+ r, g, b = np.rollaxis(arr, axis=-1)
+ h, s, v = rgb_to_hsv_channels(r, g, b)
+ hsv = np.stack((h, s, v), axis=-1)
+ return hsv
+
+
+def hsv_to_rgb(arr):
+ hsv_to_rgb_channels = np.vectorize(colorsys.hsv_to_rgb)
+ h, s, v = np.rollaxis(arr, axis=-1)
+ r, g, b = hsv_to_rgb_channels(h, s, v)
+ rgb = np.stack((r, g, b), axis=-1)
+ return rgb
+
+
+def download_url_to_file(url, dst, progress=True):
+ r"""Download object at the given URL to a local path.
+ Thanks to torch, slightly modified
+ Args:
+ url (string): URL of the object to download
+ dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
+ progress (bool, optional): whether or not to display a progress bar to stderr
+ Default: True
+ """
+ file_size = None
+ import ssl
+ ssl._create_default_https_context = ssl._create_unverified_context
+ u = urlopen(url)
+ meta = u.info()
+ if hasattr(meta, "getheaders"):
+ content_length = meta.getheaders("Content-Length")
+ else:
+ content_length = meta.get_all("Content-Length")
+ if content_length is not None and len(content_length) > 0:
+ file_size = int(content_length[0])
+ # We deliberately save it in a temp file and move it after
+ dst = os.path.expanduser(dst)
+ dst_dir = os.path.dirname(dst)
+ f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
+ try:
+ with tqdm(total=file_size, disable=not progress, unit="B", unit_scale=True,
+ unit_divisor=1024) as pbar:
+ while True:
+ buffer = u.read(8192)
+ if len(buffer) == 0:
+ break
+ f.write(buffer)
+ pbar.update(len(buffer))
+ f.close()
+ shutil.move(f.name, dst)
+ finally:
+ f.close()
+ if os.path.exists(f.name):
+ os.remove(f.name)
+
+
+def distance_to_boundary(masks):
+ """Get the distance to the boundary of mask pixels.
+
+ Args:
+ masks (int, 2D or 3D array): The masks array. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
+
+ Returns:
+ dist_to_bound (2D or 3D array): The distance to the boundary. Size [Ly x Lx] or [Lz x Ly x Lx].
+
+ Raises:
+ ValueError: If the masks array is not 2D or 3D.
+
+ """
+ if masks.ndim > 3 or masks.ndim < 2:
+ raise ValueError("distance_to_boundary takes 2D or 3D array, not %dD array" %
+ masks.ndim)
+ dist_to_bound = np.zeros(masks.shape, np.float64)
+
+ if masks.ndim == 3:
+ for i in range(masks.shape[0]):
+ dist_to_bound[i] = distance_to_boundary(masks[i])
+ return dist_to_bound
+ else:
+ slices = find_objects(masks)
+ for i, si in enumerate(slices):
+ if si is not None:
+ sr, sc = si
+ mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_NONE)
+ pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
+ ypix, xpix = np.nonzero(mask)
+ min_dist = ((ypix[:, np.newaxis] - pvr)**2 +
+ (xpix[:, np.newaxis] - pvc)**2).min(axis=1)
+ dist_to_bound[ypix + sr.start, xpix + sc.start] = min_dist
+ return dist_to_bound
+
+
+def masks_to_edges(masks, threshold=1.0):
+ """Get edges of masks as a 0-1 array.
+
+ Args:
+ masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
+ threshold (float, optional): Threshold value for distance to boundary. Defaults to 1.0.
+
+ Returns:
+ edges (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are edge pixels.
+ """
+ dist_to_bound = distance_to_boundary(masks)
+ edges = (dist_to_bound < threshold) * (masks > 0)
+ return edges
+
+
+def remove_edge_masks(masks, change_index=True):
+ """Removes masks with pixels on the edge of the image.
+
+ Args:
+ masks (int, 2D or 3D array): The masks to be processed. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
+ change_index (bool, optional): If True, after removing masks, changes the indexing so that there are no missing label numbers. Defaults to True.
+
+ Returns:
+ outlines (2D or 3D array): The processed masks. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
+ """
+ slices = find_objects(masks.astype(int))
+ for i, si in enumerate(slices):
+ remove = False
+ if si is not None:
+ for d, sid in enumerate(si):
+ if sid.start == 0 or sid.stop == masks.shape[d]:
+ remove = True
+ break
+ if remove:
+ masks[si][masks[si] == i + 1] = 0
+ shape = masks.shape
+ if change_index:
+ _, masks = np.unique(masks, return_inverse=True)
+ masks = np.reshape(masks, shape).astype(np.int32)
+
+ return masks
+
+
+def masks_to_outlines(masks):
+ """Get outlines of masks as a 0-1 array.
+
+ Args:
+ masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
+
+ Returns:
+ outlines (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are outlines.
+ """
+ if masks.ndim > 3 or masks.ndim < 2:
+ raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
+ masks.ndim)
+ outlines = np.zeros(masks.shape, bool)
+
+ if masks.ndim == 3:
+ for i in range(masks.shape[0]):
+ outlines[i] = masks_to_outlines(masks[i])
+ return outlines
+ else:
+ slices = find_objects(masks.astype(int))
+ for i, si in enumerate(slices):
+ if si is not None:
+ sr, sc = si
+ mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_NONE)
+ pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
+ vr, vc = pvr + sr.start, pvc + sc.start
+ outlines[vr, vc] = 1
+ return outlines
+
+
+def outlines_list(masks, multiprocessing_threshold=1000, multiprocessing=None):
+ """Get outlines of masks as a list to loop over for plotting.
+
+ Args:
+ masks (ndarray): Array of masks.
+ multiprocessing_threshold (int, optional): Threshold for enabling multiprocessing. Defaults to 1000.
+ multiprocessing (bool, optional): Flag to enable multiprocessing. Defaults to None.
+
+ Returns:
+ list: List of outlines.
+
+ Raises:
+ None
+
+ Notes:
+ - This function is a wrapper for outlines_list_single and outlines_list_multi.
+ - Multiprocessing is disabled for Windows.
+ """
+ # default to use multiprocessing if not few_masks, but allow user to override
+ if multiprocessing is None:
+ few_masks = np.max(masks) < multiprocessing_threshold
+ multiprocessing = not few_masks
+
+ # disable multiprocessing for Windows
+ if os.name == "nt":
+ if multiprocessing:
+ logging.getLogger(__name__).warning(
+ "Multiprocessing is disabled for Windows")
+ multiprocessing = False
+
+ if multiprocessing:
+ return outlines_list_multi(masks)
+ else:
+ return outlines_list_single(masks)
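+
+# Usage sketch (illustrative only; assumes `import matplotlib.pyplot as plt`): plot the
+# outline of every cell in an integer mask image.
+#
+#     outlines = outlines_list(masks)            # list of [npts x 2] pixel coordinates
+#     for o in outlines:
+#         plt.plot(o[:, 0], o[:, 1], lw=0.5)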
+
+
+def outlines_list_single(masks):
+ """Get outlines of masks as a list to loop over for plotting.
+
+ Args:
+ masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+
+ Returns:
+ list: List of outlines as pixel coordinates.
+
+ """
+ outpix = []
+ for n in np.unique(masks)[1:]:
+ mn = masks == n
+ if mn.sum() > 0:
+ contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
+ method=cv2.CHAIN_APPROX_NONE)
+ contours = contours[-2]
+ cmax = np.argmax([c.shape[0] for c in contours])
+ pix = contours[cmax].astype(int).squeeze()
+ if len(pix) > 4:
+ outpix.append(pix)
+ else:
+ outpix.append(np.zeros((0, 2)))
+ return outpix
+
+
+def outlines_list_multi(masks, num_processes=None):
+ """
+ Get outlines of masks as a list to loop over for plotting.
+
+ Args:
+ masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+
+ Returns:
+ list: List of outlines as pixel coordinates.
+ """
+ if num_processes is None:
+ num_processes = cpu_count()
+
+ unique_masks = np.unique(masks)[1:]
+ with Pool(processes=num_processes) as pool:
+ outpix = pool.map(get_outline_multi, [(masks, n) for n in unique_masks])
+ return outpix
+
+
+def get_outline_multi(args):
+ """Get the outline of a specific mask in a multi-mask image.
+
+ Args:
+ args (tuple): A tuple containing the masks and the mask number.
+
+ Returns:
+ numpy.ndarray: The outline of the specified mask as an array of coordinates.
+
+ """
+ masks, n = args
+ mn = masks == n
+ if mn.sum() > 0:
+ contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
+ method=cv2.CHAIN_APPROX_NONE)
+ contours = contours[-2]
+ cmax = np.argmax([c.shape[0] for c in contours])
+ pix = contours[cmax].astype(int).squeeze()
+ return pix if len(pix) > 4 else np.zeros((0, 2))
+ return np.zeros((0, 2))
+
+
+def dilate_masks(masks, n_iter=5):
+ """Dilate masks by n_iter pixels.
+
+ Args:
+ masks (ndarray): Array of masks.
+ n_iter (int, optional): Number of pixels to dilate the masks. Defaults to 5.
+
+ Returns:
+ ndarray: Dilated masks.
+ """
+ dilated_masks = masks.copy()
+ for n in range(n_iter):
+ # define the structuring element to use for dilation
+ kernel = np.ones((3, 3), "uint8")
+ # find the distance to each mask (distances are zero within masks)
+ dist_transform = cv2.distanceTransform((dilated_masks == 0).astype("uint8"),
+ cv2.DIST_L2, 5)
+ # dilate each mask and assign to it the pixels along the border of the mask
+ # (does not allow dilation into other masks since dist_transform is zero there)
+ for i in range(1, np.max(masks) + 1):
+ mask = (dilated_masks == i).astype("uint8")
+ dilated_mask = cv2.dilate(mask, kernel, iterations=1)
+ dilated_mask = np.logical_and(dist_transform < 2, dilated_mask)
+ dilated_masks[dilated_mask > 0] = i
+ return dilated_masks
+
+
+def get_perimeter(points):
+ """
+ Calculate the perimeter of a set of points.
+
+ Parameters:
+ points (ndarray): An array of points with shape (npoints, ndim).
+
+ Returns:
+ float: The perimeter of the points.
+
+ """
+ if points.shape[0] > 4:
+ points = np.append(points, points[:1], axis=0)
+ return ((np.diff(points, axis=0)**2).sum(axis=1)**0.5).sum()
+ else:
+ return 0
+
+
+def get_mask_compactness(masks):
+ """
+ Calculate the compactness of masks.
+
+ Parameters:
+ masks (ndarray): Binary masks representing objects.
+
+ Returns:
+ ndarray: Array of compactness values for each mask.
+ """
+ perimeters = get_mask_perimeters(masks)
+ npoints = np.unique(masks, return_counts=True)[1][1:]
+ areas = npoints
+ compactness = 4 * np.pi * areas / perimeters**2
+ compactness[perimeters == 0] = 0
+ compactness[compactness > 1.0] = 1.0
+ return compactness
+
+
+def get_mask_perimeters(masks):
+ """
+ Calculate the perimeters of the given masks.
+
+ Parameters:
+ masks (numpy.ndarray): Binary masks representing objects.
+
+ Returns:
+ numpy.ndarray: Array containing the perimeters of each mask.
+ """
+ perimeters = np.zeros(masks.max())
+ for n in range(masks.max()):
+ mn = masks == (n + 1)
+ if mn.sum() > 0:
+ contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
+ method=cv2.CHAIN_APPROX_NONE)[-2]
+ perimeters[n] = np.array(
+ [get_perimeter(c.astype(int).squeeze()) for c in contours]).sum()
+
+ return perimeters
+
+
+def circleMask(d0):
+ """
+ Creates an array with indices which are the radius of that x,y point.
+
+ Args:
+ d0 (tuple): Patch of (-d0, d0+1) over which radius is computed.
+
+ Returns:
+ tuple: A tuple containing:
+ - rs (ndarray): Array of radii with shape (2*d0[0]+1, 2*d0[1]+1).
+ - dx (ndarray): Indices of the patch along the x-axis.
+ - dy (ndarray): Indices of the patch along the y-axis.
+ """
+ dx = np.tile(np.arange(-d0[1], d0[1] + 1), (2 * d0[0] + 1, 1))
+ dy = np.tile(np.arange(-d0[0], d0[0] + 1), (2 * d0[1] + 1, 1))
+ dy = dy.transpose()
+
+ rs = (dy**2 + dx**2)**0.5
+ return rs, dx, dy
+
+
+def get_mask_stats(masks_true):
+ """
+ Calculate various statistics for the given binary masks.
+
+ Parameters:
+ masks_true (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+
+ Returns:
+ convexity (ndarray): Convexity values for each mask.
+ solidity (ndarray): Solidity values for each mask.
+ compactness (ndarray): Compactness values for each mask.
+ """
+ mask_perimeters = get_mask_perimeters(masks_true)
+
+ # disk for compactness
+ rs, dy, dx = circleMask(np.array([100, 100]))
+ rsort = np.sort(rs.flatten())
+
+ # area for solidity
+ npoints = np.unique(masks_true, return_counts=True)[1][1:]
+ areas = npoints - mask_perimeters / 2 - 1
+
+ compactness = np.zeros(masks_true.max())
+ convexity = np.zeros(masks_true.max())
+ solidity = np.zeros(masks_true.max())
+ convex_perimeters = np.zeros(masks_true.max())
+ convex_areas = np.zeros(masks_true.max())
+ for ic in range(masks_true.max()):
+ points = np.array(np.nonzero(masks_true == (ic + 1))).T
+ if len(points) > 15 and mask_perimeters[ic] > 0:
+ med = np.median(points, axis=0)
+ # compute compactness of ROI
+ r2 = ((points - med)**2).sum(axis=1)**0.5
+ compactness[ic] = (rsort[:r2.size].mean() + 1e-10) / r2.mean()
+ try:
+ hull = ConvexHull(points)
+ convex_perimeters[ic] = hull.area
+ convex_areas[ic] = hull.volume
+ except:
+ convex_perimeters[ic] = 0
+
+ convexity[mask_perimeters > 0.0] = (convex_perimeters[mask_perimeters > 0.0] /
+ mask_perimeters[mask_perimeters > 0.0])
+ solidity[convex_areas > 0.0] = (areas[convex_areas > 0.0] /
+ convex_areas[convex_areas > 0.0])
+ convexity = np.clip(convexity, 0.0, 1.0)
+ solidity = np.clip(solidity, 0.0, 1.0)
+ compactness = np.clip(compactness, 0.0, 1.0)
+ return convexity, solidity, compactness
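+
+# Example (illustrative sketch, assuming the scipy.spatial `ConvexHull` import at the top of this module):
+#   >>> m = np.zeros((64, 64), "int32"); m[8:40, 8:40] = 1
+#   >>> convexity, solidity, compactness = get_mask_stats(m)
+#   >>> convexity.shape, solidity.shape, compactness.shape
+#   ((1,), (1,), (1,))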
+
+
+def get_masks_unet(output, cell_threshold=0, boundary_threshold=0):
+ """Create masks using cell probability and cell boundary.
+
+ Args:
+ output (ndarray): The output array containing cell probability and cell boundary.
+ cell_threshold (float, optional): The threshold value for cell probability. Defaults to 0.
+ boundary_threshold (float, optional): The threshold value for cell boundary. Defaults to 0.
+
+ Returns:
+ ndarray: The masks representing the segmented cells.
+
+ """
+ cells = (output[..., 1] - output[..., 0]) > cell_threshold
+ selem = generate_binary_structure(cells.ndim, connectivity=1)
+ labels, nlabels = label(cells, selem)
+
+ if output.shape[-1] > 2:
+ slices = find_objects(labels)
+ dists = 10000 * np.ones(labels.shape, np.float32)
+ mins = np.zeros(labels.shape, np.int32)
+ borders = np.logical_and(~(labels > 0), output[..., 2] > boundary_threshold)
+ pad = 10
+ for i, slc in enumerate(slices):
+ if slc is not None:
+ slc_pad = tuple([
+ slice(max(0, sli.start - pad), min(labels.shape[j], sli.stop + pad))
+ for j, sli in enumerate(slc)
+ ])
+ msk = (labels[slc_pad] == (i + 1)).astype(np.float32)
+ msk = 1 - gaussian_filter(msk, 5)
+ dists[slc_pad] = np.minimum(dists[slc_pad], msk)
+ mins[slc_pad][dists[slc_pad] == msk] = (i + 1)
+ labels[labels == 0] = borders[labels == 0] * mins[labels == 0]
+
+ masks = labels
+ shape0 = masks.shape
+ _, masks = np.unique(masks, return_inverse=True)
+ masks = np.reshape(masks, shape0)
+ return masks
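+
+# Example (illustrative): `output` holds per-pixel background / cell / boundary scores.
+#   >>> out = np.zeros((32, 32, 3), np.float32); out[8:24, 8:24, 1] = 1.0
+#   >>> int(get_masks_unet(out).max())   # a single connected cell region
+#   1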
+
+
+def stitch3D(masks, stitch_threshold=0.25):
+ """
+ Stitch 2D masks into a 3D volume using a stitch_threshold on IOU.
+
+ Args:
+ masks (list or ndarray): List of 2D masks.
+ stitch_threshold (float, optional): Threshold value for stitching. Defaults to 0.25.
+
+ Returns:
+ list: List of stitched 3D masks.
+ """
+ mmax = masks[0].max()
+ empty = 0
+ for i in trange(len(masks) - 1):
+ iou = metrics._intersection_over_union(masks[i + 1], masks[i])[1:, 1:]
+ if not iou.size and empty == 0:
+ masks[i + 1] = masks[i + 1]
+ mmax = masks[i + 1].max()
+ elif not iou.size and not empty == 0:
+ icount = masks[i + 1].max()
+ istitch = np.arange(mmax + 1, mmax + icount + 1, 1, masks.dtype)
+ mmax += icount
+ istitch = np.append(np.array(0), istitch)
+ masks[i + 1] = istitch[masks[i + 1]]
+ else:
+ iou[iou < stitch_threshold] = 0.0
+ iou[iou < iou.max(axis=0)] = 0.0
+ istitch = iou.argmax(axis=1) + 1
+ ino = np.nonzero(iou.max(axis=1) == 0.0)[0]
+ istitch[ino] = np.arange(mmax + 1, mmax + len(ino) + 1, 1, masks.dtype)
+ mmax += len(ino)
+ istitch = np.append(np.array(0), istitch)
+ masks[i + 1] = istitch[masks[i + 1]]
+ empty = 1
+
+ return masks
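+
+# Example (illustrative sketch; `trange` and `metrics` are the module-level imports above):
+#   >>> frame = np.zeros((32, 32), "int32"); frame[5:15, 5:15] = 1
+#   >>> stack = np.stack([frame, frame])                 # the same object appears in both planes
+#   >>> stitched = stitch3D(stack, stitch_threshold=0.25)
+#   >>> bool(np.array_equal(stitched[0], stitched[1]))   # so it keeps a single label across planes
+#   True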
+
+
+def diameters(masks):
+ """
+ Calculate the diameters of the objects in the given masks.
+
+ Parameters:
+ masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+
+ Returns:
+ tuple: A tuple containing the median diameter and an array of diameters for each object.
+
+ Examples:
+        >>> masks = np.array([[0, 1, 1], [1, 0, 0], [1, 1, 0]])  # one labelled object of 5 pixels
+        >>> diameters(masks)
+        (2.523..., array([2.23606798]))
+ """
+ uniq, counts = fastremap.unique(masks.astype("int32"), return_counts=True)
+ counts = counts[1:]
+ md = np.median(counts**0.5)
+ if np.isnan(md):
+ md = 0
+ md /= (np.pi**0.5) / 2
+ return md, counts**0.5
+
+
+def radius_distribution(masks, bins):
+ """
+ Calculate the radius distribution of masks.
+
+ Args:
+ masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+ bins (int): Number of bins for the histogram.
+
+ Returns:
+ A tuple containing a normalized histogram of radii, median radius, array of radii.
+
+ """
+ unique, counts = np.unique(masks, return_counts=True)
+ counts = counts[unique != 0]
+ nb, _ = np.histogram((counts**0.5) * 0.5, bins)
+ nb = nb.astype(np.float32)
+ if nb.sum() > 0:
+ nb = nb / nb.sum()
+ md = np.median(counts**0.5) * 0.5
+ if np.isnan(md):
+ md = 0
+ md /= (np.pi**0.5) / 2
+ return nb, md, (counts**0.5) / 2
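+
+# Example (illustrative): radii are estimated from mask areas and binned into a normalised histogram.
+#   >>> m = np.zeros((64, 64), "int32"); m[4:14, 4:14] = 1; m[30:50, 30:50] = 2
+#   >>> nb, md, radii = radius_distribution(m, bins=10)
+#   >>> float(nb.sum())
+#   1.0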
+
+
+def size_distribution(masks):
+ """
+ Calculates the size distribution of masks.
+
+ Args:
+ masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+
+ Returns:
+ float: The ratio of the 25th percentile of mask sizes to the 75th percentile of mask sizes.
+ """
+ counts = np.unique(masks, return_counts=True)[1][1:]
+ return np.percentile(counts, 25) / np.percentile(counts, 75)
+
+
+def fill_holes_and_remove_small_masks(masks, min_size=15):
+ """ Fills holes in masks (2D/3D) and discards masks smaller than min_size.
+
+    This function fills holes in each mask using scipy.ndimage.binary_fill_holes.
+ It also removes masks that are smaller than the specified min_size.
+
+ Parameters:
+ masks (ndarray): Int, 2D or 3D array of labelled masks.
+ 0 represents no mask, while positive integers represent mask labels.
+ The size can be [Ly x Lx] or [Lz x Ly x Lx].
+ min_size (int, optional): Minimum number of pixels per mask.
+ Masks smaller than min_size will be removed.
+ Set to -1 to turn off this functionality. Default is 15.
+
+ Returns:
+ ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed.
+ 0 represents no mask, while positive integers represent mask labels.
+ The size is [Ly x Lx] or [Lz x Ly x Lx].
+ """
+
+ if masks.ndim > 3 or masks.ndim < 2:
+        raise ValueError("fill_holes_and_remove_small_masks takes 2D or 3D array, not %dD array" %
+                         masks.ndim)
+
+ slices = find_objects(masks)
+ j = 0
+ for i, slc in enumerate(slices):
+ if slc is not None:
+ msk = masks[slc] == (i + 1)
+ npix = msk.sum()
+ if min_size > 0 and npix < min_size:
+ masks[slc][msk] = 0
+ elif npix > 0:
+ if msk.ndim == 3:
+ for k in range(msk.shape[0]):
+ msk[k] = binary_fill_holes(msk[k])
+ else:
+ msk = binary_fill_holes(msk)
+ masks[slc][msk] = (j + 1)
+ j += 1
+ return masks
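+
+# Example (illustrative): interior holes are filled and the mask is relabelled in place.
+#   >>> m = np.zeros((32, 32), "int32"); m[5:25, 5:25] = 1; m[12:16, 12:16] = 0   # a mask with a hole
+#   >>> filled = fill_holes_and_remove_small_masks(m.copy(), min_size=15)
+#   >>> bool((filled == 1).sum() == 20 * 20)
+#   True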
diff --git a/cellpose/version.py b/cellpose/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4bc34f08462dbb2c8e3c41c80240fbbf4334d38
--- /dev/null
+++ b/cellpose/version.py
@@ -0,0 +1,19 @@
+"""
+Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+from importlib.metadata import PackageNotFoundError, version
+import sys
+from platform import python_version
+import torch
+
+try:
+ version = version("cellpose")
+except PackageNotFoundError:
+ version = "unknown"
+
+version_str = f"""
+cellpose version: \t{version}
+platform: \t{sys.platform}
+python version: \t{python_version()}
+torch version: \t{torch.__version__}"""
diff --git a/example_images/er_cos-7.png b/example_images/er_cos-7.png
new file mode 100644
index 0000000000000000000000000000000000000000..3d88935594de634a67da1fb23e62c1ad5cd85f48
--- /dev/null
+++ b/example_images/er_cos-7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c56d2010cee1cd6d023979577c59b07ce97b1026b76e7562dd23c7bdf9393ca
+size 153569
diff --git a/example_images/er_hela.png b/example_images/er_hela.png
new file mode 100644
index 0000000000000000000000000000000000000000..2ba48891af510b6d58d8e94ce4517017bfcf3923
--- /dev/null
+++ b/example_images/er_hela.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:823ba07e97664d1887382272ff086796956ecdc567a59cfe6f064b1ad5fb1483
+size 124319
diff --git a/example_images/f-actin_cos-7.png b/example_images/f-actin_cos-7.png
new file mode 100644
index 0000000000000000000000000000000000000000..0f2216147fa2feea0982b32374806d552528f3af
--- /dev/null
+++ b/example_images/f-actin_cos-7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:00a8ee7bcfb7d16c9ee8e95aaf8f54417877085bdd819f19f99459931efb170d
+size 180910
diff --git a/example_images/microtubules_hela.png b/example_images/microtubules_hela.png
new file mode 100644
index 0000000000000000000000000000000000000000..73c3ed75b39ac756fd071865f8228b6d2992acfe
--- /dev/null
+++ b/example_images/microtubules_hela.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3733f4a2c8de149cb94f91a24e4b7fb02985c62bb573410d94b4575ff4399d7d
+size 90439
diff --git a/example_images/mitochondria_bpae.png b/example_images/mitochondria_bpae.png
new file mode 100644
index 0000000000000000000000000000000000000000..d31e29ff41053f9e57c1d6ee206e5a16f7d190c4
--- /dev/null
+++ b/example_images/mitochondria_bpae.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b5a81c59d2aaa187f1737a673b25da8bd2dd58e944d6eb2c38d5e8927b5a4f96
+size 99075
diff --git a/example_images/nucleus_bpae.png b/example_images/nucleus_bpae.png
new file mode 100644
index 0000000000000000000000000000000000000000..c8c191dc4491d4cdde1f953f951766af46e4fee9
--- /dev/null
+++ b/example_images/nucleus_bpae.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0425ddce55927bd67407a356ab16ac5e1c484ecf8258efb6ed7033c9413a1a1
+size 75708
diff --git a/example_images_cls/cls_input_1754670277543.png b/example_images_cls/cls_input_1754670277543.png
new file mode 100644
index 0000000000000000000000000000000000000000..522eda5f6bf11465c1493339f361aa4ac825d025
--- /dev/null
+++ b/example_images_cls/cls_input_1754670277543.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc2c084aecd4610274280ffc0b8dd2eafdb0899a7cbf70d4b15708b2f13cb733
+size 63959
diff --git a/example_images_cls/cls_input_1754670358223.png b/example_images_cls/cls_input_1754670358223.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a4288ee086455ace24309f6033a6c23ebcdae70
--- /dev/null
+++ b/example_images_cls/cls_input_1754670358223.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:922489201a874d70d6bed45e30dc98fb75f7d90d4910f868d53b2a380caa8821
+size 37023
diff --git a/example_images_cls/cls_input_1754670363268.png b/example_images_cls/cls_input_1754670363268.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a4288ee086455ace24309f6033a6c23ebcdae70
--- /dev/null
+++ b/example_images_cls/cls_input_1754670363268.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:922489201a874d70d6bed45e30dc98fb75f7d90d4910f868d53b2a380caa8821
+size 37023
diff --git a/example_images_cls/cls_input_1754670366893.png b/example_images_cls/cls_input_1754670366893.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c8bfc632735cb6080ebf6dea426363cebb4ba48
--- /dev/null
+++ b/example_images_cls/cls_input_1754670366893.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f704fe2bbd72985491fb419dd9e8998e76b21e0b059cff092cfd1dbf2da9d5d7
+size 42489
diff --git a/example_images_cls/cls_input_1754670372624.png b/example_images_cls/cls_input_1754670372624.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c8bfc632735cb6080ebf6dea426363cebb4ba48
--- /dev/null
+++ b/example_images_cls/cls_input_1754670372624.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f704fe2bbd72985491fb419dd9e8998e76b21e0b059cff092cfd1dbf2da9d5d7
+size 42489
diff --git a/example_images_cls/erdak_1.png b/example_images_cls/erdak_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0ba4daecdd67b9b3d5ee0d104656367161be541b
--- /dev/null
+++ b/example_images_cls/erdak_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b37c672e4debb464d1d3756a3a862721b310f762472914236325c44779a13a8a
+size 37826
diff --git a/example_images_cls/gidap_1.png b/example_images_cls/gidap_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..ccd2b45720a749dba46d6f52be277f785d0dd1f5
--- /dev/null
+++ b/example_images_cls/gidap_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6743b29ae6cbfbf118102e15370e8e3510efc81613e535ae4d231af68f84c37
+size 25636
diff --git a/example_images_cls/tubul_1.png b/example_images_cls/tubul_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ab81138d994d037274e0e0305ed05818e592a80
--- /dev/null
+++ b/example_images_cls/tubul_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc21b98eb667f6560f846293a60534e455e619ff8c42f3b66f4d9fb1ab32d0a5
+size 21904
diff --git a/example_images_dn/dn_input_MICE_1754056108129.tif b/example_images_dn/dn_input_MICE_1754056108129.tif
new file mode 100644
index 0000000000000000000000000000000000000000..b0717a6c509111081d2dfa1248d3161d82c10bd8
--- /dev/null
+++ b/example_images_dn/dn_input_MICE_1754056108129.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e95984e9259e8b08f653693d692808c0001794fe180a10faf2e2640cfbba92bd
+size 262400
diff --git a/example_images_m2i/0002.tif b/example_images_m2i/0002.tif
new file mode 100644
index 0000000000000000000000000000000000000000..d9e8d54d32fa4ee24a945176df8b81aa248ce38c
--- /dev/null
+++ b/example_images_m2i/0002.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:566676c8e6bfe1cb3e952a63504d35fbff43ad78a41985a1d0357c2a0e5210c6
+size 262266
diff --git a/example_images_seg/seg_input_1754626817311.tif b/example_images_seg/seg_input_1754626817311.tif
new file mode 100644
index 0000000000000000000000000000000000000000..e51336fff932e7e396c0208d3b3ca8db5b48cdc7
--- /dev/null
+++ b/example_images_seg/seg_input_1754626817311.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2de23ffc7eb40791dfea1ec07e92d7295c55e3f5408827b1591e4fcc7a1e489c
+size 262400
diff --git a/example_images_sr/00000001.tif b/example_images_sr/00000001.tif
new file mode 100644
index 0000000000000000000000000000000000000000..653f4cacdb2d1d85957fbb005ecad12406ecc6f3
--- /dev/null
+++ b/example_images_sr/00000001.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:436a899290607814ba5f5c0e9456e9d6551bb2cb986223aff7fc4ae396d7a86a
+size 1181232
diff --git a/example_images_t2i/er_cos-7.png b/example_images_t2i/er_cos-7.png
new file mode 100644
index 0000000000000000000000000000000000000000..7ee5823b6ce8160f7c87b1c6ca8bd931e88a9e90
--- /dev/null
+++ b/example_images_t2i/er_cos-7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:67d75d10bd20a95f10e4a4a8993e4260c36f4ff576686321f2582deef738ae1e
+size 138638
diff --git a/example_images_t2i/er_hela.png b/example_images_t2i/er_hela.png
new file mode 100644
index 0000000000000000000000000000000000000000..31d3f2b71b1c229ab2166210bdfc44f06c0c4b37
--- /dev/null
+++ b/example_images_t2i/er_hela.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec846061c99ac794d0bd44292792d762dab9136438d0f7db892630ba9e715573
+size 124240
diff --git a/example_images_t2i/f-actin_cos-7.png b/example_images_t2i/f-actin_cos-7.png
new file mode 100644
index 0000000000000000000000000000000000000000..d1a99e4a2451cfb4543404cca83c0e9a9d44455a
--- /dev/null
+++ b/example_images_t2i/f-actin_cos-7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbfcd14c047749fb3a304c3e70e92ce4f8b890ac2d1f1785260aab3313bcda05
+size 158679
diff --git a/example_images_t2i/microtubules_hela.png b/example_images_t2i/microtubules_hela.png
new file mode 100644
index 0000000000000000000000000000000000000000..bafb3034b2fca58d53afff8dc5b9013a90fb2c9e
--- /dev/null
+++ b/example_images_t2i/microtubules_hela.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:deb5da46a71c3d1796e3353b9fd5127afed6e723de0250ff4eda8a005be21573
+size 136049
diff --git a/example_images_t2i/mitochondria_bpae.png b/example_images_t2i/mitochondria_bpae.png
new file mode 100644
index 0000000000000000000000000000000000000000..39e8abad8765766da69515a2a181aad8b06867e3
--- /dev/null
+++ b/example_images_t2i/mitochondria_bpae.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4f7b3eb68ac603e8868e196289517a16a00579d5968ad5c4c06d825e8a6d001
+size 94118
diff --git a/example_images_t2i/nucleus_bpae.png b/example_images_t2i/nucleus_bpae.png
new file mode 100644
index 0000000000000000000000000000000000000000..01add92f86b7e66d821ed7c61164c0d6ed277546
--- /dev/null
+++ b/example_images_t2i/nucleus_bpae.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63987c6ccef723b9e7f71a9bc4ebdee54fc9b5c800f89848b212f0a7afd8ffce
+size 31676
diff --git a/models/__pycache__/controlnet.cpython-311.pyc b/models/__pycache__/controlnet.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc14d9367d1441e0607e920112b4cf12a6b2a38b
Binary files /dev/null and b/models/__pycache__/controlnet.cpython-311.pyc differ
diff --git a/models/__pycache__/pipeline_controlnet.cpython-311.pyc b/models/__pycache__/pipeline_controlnet.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..949440f11b21a0ff28ce4acc422ede50d575b1e9
Binary files /dev/null and b/models/__pycache__/pipeline_controlnet.cpython-311.pyc differ
diff --git a/models/__pycache__/pipeline_ddpm_text_encoder.cpython-311.pyc b/models/__pycache__/pipeline_ddpm_text_encoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d690cb649898871fe3b65e48ce39537a31c80e75
Binary files /dev/null and b/models/__pycache__/pipeline_ddpm_text_encoder.cpython-311.pyc differ
diff --git a/models/__pycache__/unet_2d.cpython-311.pyc b/models/__pycache__/unet_2d.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94444c02e9b8cc26287fdab50b9c3c2df08e8b98
Binary files /dev/null and b/models/__pycache__/unet_2d.cpython-311.pyc differ
diff --git a/models/__pycache__/unet_2d_condition.cpython-311.pyc b/models/__pycache__/unet_2d_condition.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3dbfcf2d6ef7045ee62018481f48dc93a77fc585
Binary files /dev/null and b/models/__pycache__/unet_2d_condition.cpython-311.pyc differ
diff --git a/models/controlnet.py b/models/controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9df3908b3cccad8e06725f25bfbbc7285b16a3eb
--- /dev/null
+++ b/models/controlnet.py
@@ -0,0 +1,494 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unets.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ DownBlock2D,
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+)
+from models.unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+ """
+ The output of [`ControlNetModel`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
+            used to condition the original UNet's downsampling activations.
+        mid_block_res_sample (`torch.Tensor`):
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 "latent images" for stabilized
+    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=1))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
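+
+# Illustrative shape check for ControlNetConditioningEmbedding (note that, unlike the upstream diffusers
+# embedding, every conv here uses stride 1, so the spatial size of the condition is preserved; the final
+# zero-initialised conv also means the embedding starts out as all zeros):
+#   >>> emb = ControlNetConditioningEmbedding(conditioning_embedding_channels=320, conditioning_channels=3)
+#   >>> emb(torch.zeros(1, 3, 64, 64)).shape
+#   torch.Size([1, 320, 64, 64])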
+
+
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ """
+ A ControlNet model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 3,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2D",
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ resnet_time_scale_shift: str = "default",
+ add_attention: bool = True,
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ use_prompt: bool = False,
+ encoder_size: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+
+ # Check inputs
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ # prompt
+ if use_prompt:
+ self.prompt_embedding = nn.Sequential(
+ nn.Linear(encoder_size, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim)
+ )
+ else:
+ self.prompt_embedding = None
+
+ # control net conditioning embedding
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+ if mid_block_type == "UNetMidBlock2D":
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=add_attention,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 3,
+ ):
+ r"""
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
+ where applicable.
+ """
+ mid_block_type = unet.config.mid_block_type if "mid_block_type" in unet.config else "UNetMidBlock2D"
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+
+ controlnet = cls(
+ in_channels=unet.config.in_channels,
+ conditioning_channels=conditioning_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=unet.config.down_block_types,
+ mid_block_type=mid_block_type,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ transformer_layers_per_block=transformer_layers_per_block,
+ attention_head_dim=unet.config.attention_head_dim,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ use_prompt=unet.config.use_prompt,
+ encoder_size=unet.config.encoder_size,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ )
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+            if controlnet.prompt_embedding is not None and unet.prompt_embedding is not None:
+                controlnet.prompt_embedding.load_state_dict(unet.prompt_embedding.state_dict())
+
+ if hasattr(controlnet, "add_embedding"):
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
+
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+ return controlnet
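+
+    # Illustrative use (assumptions: `unet` is this repo's prompt-conditioned UNet2DConditionModel, e.g. the
+    # mask-to-image checkpoint, and the conditioning mask is a single-channel image):
+    #   controlnet = ControlNetModel.from_unet(unet, conditioning_channels=1)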
+
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ controlnet_cond: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ conditioning_scale: float = 1.0,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
+ """
+ The [`ControlNetModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+            controlnet_cond (`torch.Tensor`):
+                The conditioning image tensor of shape `(batch_size, conditioning_channels, height, width)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ guess_mode (`bool`, defaults to `False`):
+                In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ if self.prompt_embedding is not None:
+ encoder_hidden_states = encoder_hidden_states.reshape(sample.shape[0], -1).contiguous()
+ prompt_emb = self.prompt_embedding(encoder_hidden_states)
+ emb = emb + prompt_emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ sample = sample + controlnet_cond
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(sample, emb)
+
+ # 5. Control net blocks
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+ scales = scales * conditioning_scale
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
+ else:
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
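+
+# Illustrative check: a zero-initialised block contributes nothing to the frozen UNet until it is trained.
+#   >>> blk = zero_module(nn.Conv2d(4, 4, kernel_size=1))
+#   >>> bool((blk.weight == 0).all() and (blk.bias == 0).all())
+#   True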
diff --git a/models/pipeline_controlnet.py b/models/pipeline_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f83e12ba7f9f5608c2aec6154ed7c6c7d98a6d15
--- /dev/null
+++ b/models/pipeline_controlnet.py
@@ -0,0 +1,106 @@
+from typing import Optional
+
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.utils.torch_utils import randn_tensor
+
+class DDPMControlnetPipeline(DiffusionPipeline):
+ def __init__(
+ self,
+ unet,
+ scheduler,
+ controlnet,
+ text_encoder: CLIPTextModel | None = None,
+ tokenizer: CLIPTokenizer | None = None
+ ):
+ super().__init__()
+
+ self.register_modules(
+ unet=unet,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ controlnet=controlnet,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ image_cond: PipelineImageInput = None,
+ generator: torch.Generator | None = None,
+ num_inference_steps: int = 1000,
+ output_type: str | None = "pil",
+ return_dict: bool = True,
+ prompt: Optional[str] = None,
+    ) -> ImagePipelineOutput:
+        if prompt is None:
+            raise ValueError("A text `prompt` is required: this pipeline conditions generation on a prompt.")
+        text_inputs = self.tokenizer(
+ prompt.lower(),
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids.to(self.device)
+ encoder_hidden_states = self.text_encoder(text_input_ids, return_dict=False)[0]
+
+ if isinstance(self.unet.config.sample_size, int):
+ image_shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ self.unet.config.sample_size,
+ self.unet.config.sample_size,
+ )
+ else:
+ image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
+
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
+ image = image.to(self.device)
+ else:
+ image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+ # denoising loop
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. controlnet output
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ sample=image,
+ timestep=t,
+ encoder_hidden_states = encoder_hidden_states,
+ controlnet_cond=image_cond,
+ return_dict=False,
+ )
+ # 2. predict noise model_output
+ model_output = self.unet(
+ sample=image,
+ timestep=t,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ encoder_hidden_states = encoder_hidden_states,
+ return_dict=False,
+ )[0]
+
+ # 3. compute previous image: x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
\ No newline at end of file
diff --git a/models/pipeline_ddpm_text_encoder.py b/models/pipeline_ddpm_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e79be7a8c12a95ea8621475ca322061dd166d184
--- /dev/null
+++ b/models/pipeline_ddpm_text_encoder.py
@@ -0,0 +1,155 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from transformers import CLIPTextModel, CLIPTokenizer
+
+
+class DDPMPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Parameters:
+ unet ([`UNet2DModel`]):
+ A `UNet2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ model_cpu_offload_seq = "unet"
+
+ def __init__(
+ self,
+ unet,
+ scheduler,
+ text_encoder: Optional[CLIPTextModel]=None,
+ tokenizer: Optional[CLIPTokenizer]=None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ unet=unet,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ num_inference_steps: int = 1000,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ The call function to the pipeline for generation.
+
+            prompt (`str` or `List[str]`, *optional*):
+                The text prompt(s) to condition generation on; the batch size is inferred from the number of prompts.
+            generator (`torch.Generator`, *optional*):
+ generator (`torch.Generator`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 1000):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Example:
+
+        ```py
+        >>> from models.pipeline_ddpm_text_encoder import DDPMPipeline
+
+        >>> # load model, scheduler, tokenizer and text encoder
+        >>> # (this fork conditions generation on a text prompt, so the checkpoint must also provide a
+        >>> # tokenizer and text encoder; the path below is a placeholder)
+        >>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
+
+        >>> # run pipeline in inference (sample random noise and denoise, conditioned on a prompt)
+        >>> image = pipe(prompt="F-actin of COS-7").images[0]
+
+        >>> # save image
+        >>> image.save("ddpm_generated_image.png")
+        ```
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+        # Prepare prompt embedding (accept a single string or a list of strings)
+        if prompt is None:
+            raise ValueError("A text `prompt` is required: this pipeline conditions generation on a prompt.")
+        if isinstance(prompt, str):
+            prompt = [prompt]
+        batch_size = len(prompt)
+
+        text_inputs = self.tokenizer(
+            [p.lower() for p in prompt],
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids.to(self.device)
+        encoder_hidden_states = self.text_encoder(text_input_ids, return_dict=False)[0]
+
+ # Sample gaussian noise to begin loop
+ if isinstance(self.unet.config.sample_size, int):
+ image_shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ self.unet.config.sample_size,
+ self.unet.config.sample_size,
+ )
+ else:
+ image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
+
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ image = randn_tensor(image_shape, generator=generator)
+ image = image.to(self.device)
+ else:
+ image = randn_tensor(image_shape, generator=generator, device=self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t, encoder_hidden_states).sample
+ # 2. compute previous image: x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
\ No newline at end of file
diff --git a/models/unet_2d.py b/models/unet_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed20ebb352c5b16aecc0503d940af8384cfc6372
--- /dev/null
+++ b/models/unet_2d.py
@@ -0,0 +1,340 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput
+from diffusers.models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+import pdb
+
+@dataclass
+class UNet2DOutput(BaseOutput):
+ """
+ The output of [`UNet2DModel`].
+
+ Args:
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output from the last layer of the model.
+ """
+
+ sample: torch.Tensor
+
+
+class UNet2DModel(ModelMixin, ConfigMixin):
+ r"""
+ A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
+ 1)`.
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip sin to cos for Fourier time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
+ Tuple of downsample block types.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
+ Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
+ Tuple of block output channels.
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
+ downsample_type (`str`, *optional*, defaults to `conv`):
+ The downsample type for downsampling layers. Choose between "conv" and "resnet"
+ upsample_type (`str`, *optional*, defaults to `conv`):
+ The upsample type for upsampling layers. Choose between "conv" and "resnet"
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
+ attn_norm_num_groups (`int`, *optional*, defaults to `None`):
+ If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
+ given number of groups. If left as `None`, the group norm layer will only be created if
+ `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, or `"identity"`.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
+ conditioning with `class_embed_type` equal to `None`.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ center_input_sample: bool = False,
+ time_embedding_type: str = "positional",
+ freq_shift: int = 0,
+ flip_sin_to_cos: bool = True,
+ down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+ up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+ block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
+ layers_per_block: int = 2,
+ mid_block_scale_factor: float = 1,
+ downsample_padding: int = 1,
+ downsample_type: str = "conv",
+ upsample_type: str = "conv",
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ attention_head_dim: Optional[int] = 8,
+ norm_num_groups: int = 32,
+ attn_norm_num_groups: Optional[int] = None,
+ norm_eps: float = 1e-5,
+ resnet_time_scale_shift: str = "default",
+ add_attention: bool = True,
+ num_train_timesteps: Optional[int] = None,
+ use_prompt: bool = False,
+ encoder_size: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ if time_embedding_type == "fourier":
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
+ timestep_input_dim = 2 * block_out_channels[0]
+ elif time_embedding_type == "positional":
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ elif time_embedding_type == "learned":
+ self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # prompt
+ if use_prompt:
+ self.prompt_embedding = nn.Sequential(
+ nn.Linear(encoder_size, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim)
+ )
+ else:
+ self.prompt_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ downsample_type=downsample_type,
+ dropout=dropout,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ attn_groups=attn_norm_num_groups,
+ add_attention=add_attention,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ upsample_type=upsample_type,
+ dropout=dropout,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DOutput, Tuple]:
+ r"""
+ The [`UNet2DModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`, *optional*, defaults to `None`):
+ Flattened text-encoder features used for prompt conditioning; when `use_prompt` is enabled
+ they are projected and added to the timestep embeddings.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is the sample tensor.
+ """
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
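+ # Prompt conditioning: flatten the encoder hidden states per sample, project them through the
+ # prompt MLP, and add the result to the timestep embedding (requires `use_prompt=True` at init).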
+ if self.prompt_embedding is not None:
+ encoder_hidden_states = encoder_hidden_states.reshape(sample.shape[0], -1)
+ prompt_emb = self.prompt_embedding(encoder_hidden_states)
+ emb = emb + prompt_emb
+
+ # 2. pre-process
+ skip_sample = sample
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "skip_conv"):
+ sample, res_samples, skip_sample = downsample_block(
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb)
+
+ # 5. up
+ skip_sample = None
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "skip_conv"):
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
+ else:
+ sample = upsample_block(sample, res_samples, emb)
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if skip_sample is not None:
+ sample += skip_sample
+
+ if self.config.time_embedding_type == "fourier":
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
+ sample = sample / timesteps
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DOutput(sample=sample)
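+
+# Minimal usage sketch (illustrative only, not part of the pipeline code): builds a small
+# prompt-conditioned UNet2DModel and runs one denoising step. The shapes and `encoder_size`
+# below are assumptions for demonstration, not values taken from the released checkpoints.
+#
+#   import torch
+#   model = UNet2DModel(
+#       sample_size=64,
+#       in_channels=1,
+#       out_channels=1,
+#       block_out_channels=(64, 128, 256, 256),
+#       use_prompt=True,
+#       encoder_size=768,
+#   )
+#   noisy = torch.randn(2, 1, 64, 64)
+#   text_feat = torch.randn(2, 768)  # flattened text-encoder features, one vector per sample
+#   out = model(noisy, timestep=10, encoder_hidden_states=text_feat).sample
+#   assert out.shape == noisy.shape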
diff --git a/models/unet_2d_condition.py b/models/unet_2d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f7a4b4d928882fda3c7e1006b5d360ee0da8e7c
--- /dev/null
+++ b/models/unet_2d_condition.py
@@ -0,0 +1,215 @@
+from typing import Optional, Tuple, Union
+
+from diffusers.models.unets.unet_2d import UNet2DOutput
+from diffusers.utils import deprecate
+
+from models.unet_2d import UNet2DModel
+
+import torch
+
+class UNet2DConditionModel(UNet2DModel):
+ def __init__(
+ self,
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ center_input_sample: bool = False,
+ time_embedding_type: str = "positional",
+ freq_shift: int = 0,
+ flip_sin_to_cos: bool = False,
+ down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+ up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+ block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
+ layers_per_block: int = 2,
+ mid_block_scale_factor: float = 1,
+ downsample_padding: int = 1,
+ downsample_type: str = "conv",
+ upsample_type: str = "conv",
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ attention_head_dim: Optional[int] = 8,
+ norm_num_groups: int = 32,
+ attn_norm_num_groups: Optional[int] = None,
+ norm_eps: float = 1e-5,
+ resnet_time_scale_shift: str = "default",
+ add_attention: bool = False,
+ num_train_timesteps: Optional[int] = None,
+ use_prompt: bool = False,
+ encoder_size: Optional[int] = None,
+ ):
+ super().__init__(
+ sample_size,
+ in_channels,
+ out_channels,
+ center_input_sample,
+ time_embedding_type,
+ freq_shift,
+ flip_sin_to_cos,
+ down_block_types,
+ up_block_types,
+ block_out_channels,
+ layers_per_block,
+ mid_block_scale_factor,
+ downsample_padding,
+ downsample_type,
+ upsample_type,
+ dropout,
+ act_fn,
+ attention_head_dim,
+ norm_num_groups,
+ attn_norm_num_groups,
+ norm_eps,
+ resnet_time_scale_shift,
+ add_attention,
+ num_train_timesteps,
+ use_prompt,
+ encoder_size
+ )
+
+
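+ # This subclass reuses the parent constructor unchanged; it only extends forward() so that
+ # ControlNet residuals (`down_block_additional_residuals`, `mid_block_additional_residual`) and
+ # T2I-Adapter residuals (`down_intrablock_additional_residuals`) can be injected into the UNet.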
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`, *optional*, defaults to `None`):
+ Flattened text-encoder features used for prompt conditioning; when `use_prompt` is enabled
+ they are projected and added to the timestep embeddings.
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
+ """
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ if self.prompt_embedding is not None:
+ encoder_hidden_states = encoder_hidden_states.reshape(sample.shape[0], -1).to(dtype=self.dtype).contiguous()
+ prompt_emb = self.prompt_embedding(encoder_hidden_states)
+ emb = emb + prompt_emb
+
+ # 2. pre-process
+ skip_sample = sample
+ sample = self.conv_in(sample)
+
+ # 3. down
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
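+ # ControlNet conditioning: add each ControlNet residual to the matching down-block skip
+ # connection so the up blocks consume the conditioned features; the mid-block residual is
+ # added after the mid block below.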
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ skip_sample = None
+ for i, upsample_block in enumerate(self.up_blocks):
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ if hasattr(upsample_block, "skip_conv"):
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
+ else:
+ sample = upsample_block(sample, res_samples, emb)
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if skip_sample is not None:
+ sample += skip_sample
+
+ if self.config.time_embedding_type == "fourier":
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
+ sample = sample / timesteps
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DOutput(sample=sample)
+
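+# Minimal wiring sketch (illustrative only): how ControlNet outputs are typically fed into this
+# forward(). The names `controlnet`, `unet`, `noisy_sample`, `text_emb`, and `condition_image`
+# are hypothetical stand-ins, and the call pattern mirrors the diffusers-style ControlNet API;
+# the custom models/controlnet.py may differ in detail.
+#
+#   down_res, mid_res = controlnet(
+#       noisy_sample, timestep,
+#       encoder_hidden_states=text_emb,
+#       controlnet_cond=condition_image,
+#       return_dict=False,
+#   )
+#   pred = unet(
+#       noisy_sample, timestep,
+#       encoder_hidden_states=text_emb,
+#       down_block_additional_residuals=down_res,
+#       mid_block_additional_residual=mid_res,
+#   ).sample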
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7b148c94777f29f22c9091dd5be8cc8d5644a6b7
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,105 @@
+accelerate==1.8.1
+aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.15
+aiosignal==1.4.0
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+anyio==4.9.0
+attrs==25.3.0
+certifi==2025.6.15
+charset-normalizer==3.4.2
+click==8.2.1
+datasets==4.0.0
+diffusers==0.34.0
+dill==0.3.8
+einops==0.8.0
+fastapi==0.116.0
+fastremap==1.17.2
+ffmpy==0.6.0
+filelock==3.18.0
+frozenlist==1.7.0
+fsspec==2025.3.0
+gradio==5.35.0
+gradio_client==1.10.4
+groovy==0.1.2
+h11==0.16.0
+hf-xet==1.1.5
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.33.2
+hydra-core==1.3.2
+idna==3.10
+imageio==2.37.0
+importlib_metadata==8.7.0
+iopath==0.1.10
+isat-sam-backend==1.0.0
+Jinja2==3.1.6
+lazy_loader==0.4
+llvmlite==0.44.0
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.6.3
+multiprocess==0.70.16
+natsort==8.4.0
+networkx
+numba==0.61.2
+numpy==1.25.2
+omegaconf==2.3.0
+opencv-python==4.10.0.82
+orjson==3.10.18
+packaging==25.0
+pandas==2.3.1
+pillow==11.3.0
+portalocker==3.2.0
+propcache==0.3.2
+psutil==7.0.0
+pyarrow==21.0.0
+pydantic==2.11.7
+pydantic_core==2.33.2
+pydub==0.25.1
+Pygments==2.19.2
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytorch-msssim==1.0.0
+pytz==2025.2
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.4
+rich==14.0.0
+roifile==2025.5.10
+ruff==0.12.2
+safehttpx==0.1.6
+safetensors==0.5.3
+scikit-image==0.25.2
+scipy
+semantic-version==2.10.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+starlette==0.46.2
+sympy==1.14.0
+tabulate==0.9.0
+tifffile
+timm==1.0.19
+tokenizers==0.21.2
+tomlkit==0.13.3
+torch==2.7.1
+torchaudio==2.7.1
+torchvision==0.22.1
+tqdm==4.67.1
+transformers==4.53.1
+triton==3.3.1
+typer==0.16.0
+typing-inspection==0.4.1
+typing_extensions==4.14.1
+tzdata==2025.2
+urllib3==2.5.0
+uvicorn==0.35.0
+warmup_scheduler==0.3
+websockets==15.0.1
+xxhash==3.5.0
+yarl==1.20.1
+zipp==3.23.0
diff --git a/utils/logo2_resize.png b/utils/logo2_resize.png
new file mode 100644
index 0000000000000000000000000000000000000000..1fe5e7b1f3de97cd7f5547a49148b8f6cf6eb138
--- /dev/null
+++ b/utils/logo2_resize.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94f7d69a514ce901b27788782160194cca676b1a44a5565643e08800eb23e3b5
+size 251472
diff --git a/utils/logo2_transparent.png b/utils/logo2_transparent.png
new file mode 100644
index 0000000000000000000000000000000000000000..6176eb2063d671f430872cc55d32f6c5f249d7c8
--- /dev/null
+++ b/utils/logo2_transparent.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddd12a8a1702ab903916cf99d4c5f555e8ba111f6bedf915f79d87cf54092ae6
+size 211713
diff --git a/utils/logo_0801_1.png b/utils/logo_0801_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..f9f200a8af59dce3fd2266bdaf78a291afd8328f
--- /dev/null
+++ b/utils/logo_0801_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cad83ec9f0fbf178f50dcbfb0661928fba007a06f223b23d940ec904bf8f8569
+size 1529724
diff --git a/utils/logo_0801_2.png b/utils/logo_0801_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..27b1f6d343026a72ab4d835c31275bca2426c9da
--- /dev/null
+++ b/utils/logo_0801_2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:26465d64eaa2767c32f9b763ccf57bef72bbdf16ef13c82bf36886dc29e30328
+size 1226031