diff --git a/README.md b/README.md
index 758e3be788552de668ae0964e4f8812a45c8f6cb..c02cf011506a1b534ec255d540149c881bdf80ce 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,14 @@
 ---
-title: FluoGen
-emoji: 🚀
-colorFrom: indigo
-colorTo: pink
+title: FluoGen Demo
+emoji: 📉
+colorFrom: red
+colorTo: red
 sdk: gradio
 sdk_version: 5.49.1
 app_file: app.py
 pinned: false
 license: mit
-short_description: TMP
+short_description: 'Demo space of FluoGen: An Open-Source Generative Foundation '
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda41e1a96f67f80f8dd4ebb6e5f742d9c9d69d2
--- /dev/null
+++ b/app.py
@@ -0,0 +1,664 @@
+import torch
+import numpy as np
+import gradio as gr
+from PIL import Image
+import os
+import json
+import glob
+import random
+import tifffile
+import re
+import imageio
+from torchvision import transforms, models
+import accelerate
+import shutil
+import time
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import OrderedDict
+
+
+# --- Imports from both scripts ---
+from diffusers import DDPMScheduler, DDIMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer
+from accelerate.state import AcceleratorState
+from transformers.utils import ContextManagers
+
+# --- Custom Model Imports ---
+from models.pipeline_ddpm_text_encoder import DDPMPipeline
+from models.unet_2d import UNet2DModel
+from models.controlnet import ControlNetModel
+from models.unet_2d_condition import UNet2DConditionModel
+from models.pipeline_controlnet import DDPMControlnetPipeline
+
+# --- New Import for Segmentation ---
+from cellpose import models as cellpose_models
+from cellpose import plot as cellpose_plot
+from huggingface_hub import hf_hub_download
+
+# --- 0. Configuration & Constants ---
+# --- General ---
+MODEL_TITLE = "🔬 FluoGen: AI-Powered Fluorescence Microscopy Suite"
+MODEL_DESCRIPTION = """
+**Paper**: *Generative AI empowering fluorescence microscopy imaging and analysis*
+<br>
+Select a task from the tabs below: generate new images from text, enhance existing images using super-resolution, denoise them, generate training data from segmentation masks, perform cell segmentation, or classify cell images. +""" +DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" +WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32 +LOGO_PATH = "utils/logo_0801_2.png" + +# --- Global switch to control example saving --- +SAVE_EXAMPLES = False + +# --- Base directory for all models --- +# NOTE: All model paths are now relative. +# Run the `copy_weights.py` script once to copy all necessary model files into this local directory. +REPO_ID = "rayquaza384mega/FluoGen-demo-test-ckpts" +MODELS_ROOT_DIR = hf_hub_download(repo_id=REPO_ID) #"models_collection" + + +# --- Tab 1: Mask-to-Image Config (Formerly Segmentation-to-Image) --- +M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000" +M2I_EXAMPLE_IMG_DIR = "example_images_m2i" + +# --- Tab 2: Text-to-Image Config --- +T2I_PROMPTS = ["F-actin of COS-7", "ER of COS-7", "Mitochondria of BPAE", "Nucleus of BPAE", "ER of HeLa", "Microtubules of HeLa"] +T2I_EXAMPLE_IMG_DIR = "example_images" +T2I_CHECKPOINT = 285000 +T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5" +T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-{T2I_CHECKPOINT}" + +# --- Tab 3, 4: ControlNet-based Tasks Config --- +CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5" +CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000" + +# Super-Resolution Models +SR_CONTROLNET_MODELS = { + "Checkpoint CCPs": f"{MODELS_ROOT_DIR}/ControlNet_SR/CCPs/checkpoint-100000", + "Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000", +} +SR_EXAMPLE_IMG_DIR = "example_images_sr" + +# Denoising Model +DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000" +DN_PROMPT_RULES = {'MICE': 'mouse brain tissues', 'FISH': 'zebrafish embryos', 'BPAE_B': 'nucleus of BPAE', 'BPAE_R': 'mitochondria of BPAE', 'BPAE_G': 'F-actin of BPAE'} +DN_EXAMPLE_IMG_DIR = "example_images_dn" + +# --- Tab 5: Cell Segmentation Config --- +SEG_MODELS = { + "DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100", + "DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300", + "DSB Model": f"{MODELS_ROOT_DIR}/Cellpose/DSB_baseline/CP_dsb_baseline_ratio_1_epoch_0135", + "DSB Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DSB_FluoGen/CP_dsb_ten_epoch_0135", +} +SEG_EXAMPLE_IMG_DIR = "example_images_seg" + + +# --- Tab 6: Classification Config --- +CLS_MODEL_PATHS = OrderedDict({ + "5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_re", + #"10shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_10_shot_re", + #"15shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_15_shot_re", + #"20shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_20_shot_re", + "5shot+FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug_re", + #"10shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_10_shot_aug_re", + #"15shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_15_shot_aug_re", + #"20shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_20_shot_aug_re", +}) +CLS_CLASS_NAMES = ['dap', 'erdak', 'giant', 'gpp130', 'h4b4', 'mc151', 'nucle', 'phal', 'tfr', 'tubul'] +CLS_EXAMPLE_IMG_DIR = "example_images_cls" + + +# --- Helper Functions --- 
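A minimal sketch, not part of the original app, assuming the checkpoint repository is meant to be mirrored in full: `hf_hub_download` resolves a single file and normally requires a `filename` argument, so the `MODELS_ROOT_DIR` defined above is more typically obtained with `snapshot_download` from the same `huggingface_hub` package, which returns the local folder holding a complete copy of the repo:

    from huggingface_hub import snapshot_download

    # Local root folder containing a full copy of the checkpoint collection
    MODELS_ROOT_DIR = snapshot_download(repo_id="rayquaza384mega/FluoGen-demo-test-ckpts")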
+def sanitize_prompt_for_filename(prompt): + prompt = prompt.lower(); prompt = re.sub(r'\s+of\s+', '_', prompt); prompt = re.sub(r'[^a-z0-9-_]+', '', prompt) + return f"{prompt}.png" + +def min_max_norm(x): + x = x.astype(np.float32); min_val, max_val = np.min(x), np.max(x) + if max_val - min_val < 1e-8: return np.zeros_like(x) + return (x - min_val) / (max_val - min_val) + +def numpy_to_pil(image_np, target_mode="RGB"): + # If the input is already a PIL image, just ensure mode and return + if isinstance(image_np, Image.Image): + if target_mode == "RGB" and image_np.mode != "RGB": return image_np.convert("RGB") + if target_mode == "L" and image_np.mode != "L": return image_np.convert("L") + return image_np + + # Handle numpy array conversion + squeezed_np = np.squeeze(image_np); + if squeezed_np.dtype == np.uint8: + # If it's already uint8, it's likely in the 0-255 range. + image_8bit = squeezed_np + else: + # Normalize and scale for other types + normalized_np = min_max_norm(squeezed_np) + image_8bit = (normalized_np * 255).astype(np.uint8) + + pil_image = Image.fromarray(image_8bit) + + if target_mode == "RGB" and pil_image.mode != "RGB": pil_image = pil_image.convert("RGB") + elif target_mode == "L" and pil_image.mode != "L": pil_image = pil_image.convert("L") + return pil_image + +# --- 1. Model Loading --- +print("--- Initializing FluoGen Application ---") +t2i_pipe, controlnet_pipe = None, None +try: + print("Loading Text-to-Image model...") + t2i_noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=True, timestep_spacing="trailing") + t2i_unet = UNet2DModel.from_pretrained(T2I_UNET_PATH, subfolder="unet") + t2i_text_encoder = CLIPTextModel.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="text_encoder").to(DEVICE) + t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer") + t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler, text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer) + t2i_pipe.to(DEVICE) + print("βœ“ Text-to-Image model loaded successfully!") +except Exception as e: + print(f"!!!!!! 
FATAL: Text-to-Image Model Loading Failed !!!!!!\nError: {e}") + +try: + print("Loading shared ControlNet pipeline components...") + controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE) + default_controlnet_path = M2I_CONTROLNET_PATH # Start with the first tab's model + controlnet_controlnet = ControlNetModel.from_pretrained(default_controlnet_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE) + controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing") + controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer") + with ContextManagers([]): + controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE) + controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet, scheduler=controlnet_scheduler, text_encoder=controlnet_text_encoder, tokenizer=controlnet_tokenizer) + controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE) + controlnet_pipe.current_controlnet_path = default_controlnet_path + print("βœ“ Shared ControlNet pipeline loaded successfully!") +except Exception as e: + print(f"!!!!!! FATAL: ControlNet Pipeline Loading Failed !!!!!!\nError: {e}") + +# --- 2. Core Logic Functions --- +def swap_controlnet(pipe, target_path): + if os.path.normpath(getattr(pipe, 'current_controlnet_path', '')) != os.path.normpath(target_path): + print(f"Swapping ControlNet model to: {target_path}") + try: + pipe.controlnet = ControlNetModel.from_pretrained(target_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE) + pipe.current_controlnet_path = target_path + except Exception as e: + raise gr.Error(f"Failed to load ControlNet model '{target_path}'. Error: {e}") + return pipe + +def generate_t2i(prompt, num_inference_steps): + if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.") + print(f"\nTask started... | Prompt: '{prompt}' | Steps: {num_inference_steps}") + image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images + generated_image = numpy_to_pil(image_np) + print("βœ“ Image generated") + if SAVE_EXAMPLES: + example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt)) + if not os.path.exists(example_filepath): + generated_image.save(example_filepath); print(f"βœ“ New T2I example saved: {example_filepath}") + return generated_image + +def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed): + if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.") + if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask TIF file.") + if not cell_type or not cell_type.strip(): raise gr.Error("Please enter a cell type.") + + if SAVE_EXAMPLES: + input_path = mask_file_obj.name + filename = os.path.basename(input_path) + dest_path = os.path.join(M2I_EXAMPLE_IMG_DIR, filename) + if not os.path.exists(dest_path): + shutil.copy(input_path, dest_path) + print(f"βœ“ New Mask-to-Image example saved: {dest_path}") + + pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH) + try: + mask_np = tifffile.imread(mask_file_obj.name) + except Exception as e: + raise gr.Error(f"Failed to read the TIF file. 
Error: {e}") + + input_display_image = numpy_to_pil(mask_np, "L") + mask_normalized = min_max_norm(mask_np) + image_tensor = torch.from_numpy(mask_normalized.astype(np.float32)) + image_tensor = image_tensor.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE) + image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor) + + prompt = f"nuclei of {cell_type.strip()}" + print(f"\nTask started... | Task: Mask-to-Image | Prompt: '{prompt}' | Steps: {steps} | Images: {num_images}") + + generated_images_pil = [] + for i in range(int(num_images)): + current_seed = int(seed) + i + generator = torch.Generator(device=DEVICE).manual_seed(current_seed) + with torch.autocast("cuda"): + output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images + pil_image = numpy_to_pil(output_np) + generated_images_pil.append(pil_image) + print(f"βœ“ Generated image {i+1}/{int(num_images)}") + + return input_display_image, generated_images_pil + +def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed): + if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.") + if low_res_file_obj is None: raise gr.Error("Please upload a low-resolution TIF file.") + + if SAVE_EXAMPLES: + input_path = low_res_file_obj.name + filename = os.path.basename(input_path) + dest_path = os.path.join(SR_EXAMPLE_IMG_DIR, filename) + if not os.path.exists(dest_path): + shutil.copy(input_path, dest_path) + print(f"βœ“ New SR example saved: {dest_path}") + + target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name) + if not target_path: raise gr.Error(f"ControlNet model '{controlnet_model_name}' not found.") + + pipe = swap_controlnet(controlnet_pipe, target_path) + try: + image_stack_np = tifffile.imread(low_res_file_obj.name) + except Exception as e: + raise gr.Error(f"Failed to read the TIF file. Error: {e}") + + if image_stack_np.ndim != 3 or image_stack_np.shape[-3] != 9: + raise gr.Error(f"Invalid TIF shape. 
Expected 9 channels (shape 9, H, W), but got {image_stack_np.shape}.") + + average_projection_np = np.mean(image_stack_np, axis=0) + input_display_image = numpy_to_pil(average_projection_np, "L") + + image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE) + image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor) + + generator = torch.Generator(device=DEVICE).manual_seed(int(seed)) + with torch.autocast("cuda"): + output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images + + return input_display_image, numpy_to_pil(output_np) + +def run_denoising(noisy_image_np, image_type, steps, seed): + if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.") + if noisy_image_np is None: raise gr.Error("Please upload a noisy image.") + + if SAVE_EXAMPLES: + timestamp = int(time.time() * 1000) + filename = f"dn_input_{image_type}_{timestamp}.tif" + dest_path = os.path.join(DN_EXAMPLE_IMG_DIR, filename) + try: + img_to_save = noisy_image_np.astype(np.uint8) if noisy_image_np.dtype != np.uint8 else noisy_image_np + tifffile.imwrite(dest_path, img_to_save) + print(f"βœ“ New Denoising example saved: {dest_path}") + except Exception as e: + print(f"βœ— Failed to save denoising example: {e}") + + pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH) + prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image') + print(f"\nTask started... | Task: Denoising | Prompt: '{prompt}' | Steps: {steps}") + + image_tensor = torch.from_numpy(noisy_image_np.astype(np.float32) / 255.0) + image_tensor = image_tensor.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE) + image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor) + + generator = torch.Generator(device=DEVICE).manual_seed(int(seed)) + with torch.autocast("cuda"): + output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images + + return numpy_to_pil(noisy_image_np, "L"), numpy_to_pil(output_np) + +def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold): + """ + Runs cell segmentation and creates a dark red overlay. + """ + if input_image_np is None: + raise gr.Error("Please upload an image to segment.") + + model_path = SEG_MODELS.get(model_name) + if not model_path: + raise gr.Error(f"Segmentation model '{model_name}' not found.") + + if not os.path.exists(model_path): + raise gr.Error(f"Model file not found at path: {model_path}. Please check the configuration.") + + print(f"\nTask started... | Task: Cell Segmentation | Model: '{model_name}'") + + # 1. Load Cellpose Model + try: + use_gpu = torch.cuda.is_available() + model = cellpose_models.CellposeModel(gpu=use_gpu, pretrained_model=model_path) + except Exception as e: + raise gr.Error(f"Failed to load Cellpose model. Error: {e}") + + diameter_to_use = model.diam_labels if diameter == 0 else float(diameter) + print(f"Using Diameter: {diameter_to_use}") + + # 2. Run model evaluation + try: + masks, _, _ = model.eval( + [input_image_np], + channels=[0, 0], + diameter=diameter_to_use, + flow_threshold=flow_threshold, + cellprob_threshold=cellprob_threshold + ) + mask_output = masks[0] + except Exception as e: + raise gr.Error(f"Cellpose model evaluation failed. Error: {e}") + + # 3. 
Create custom dark red overlay + # Ensure input image is uint8 and 3-channel for blending + original_rgb = numpy_to_pil(input_image_np, "RGB") + original_rgb_np = np.array(original_rgb) + + # Create a blank layer for the red mask + red_mask_layer = np.zeros_like(original_rgb_np) + dark_red_color = [139, 0, 0] + + # Apply the red color where the mask is present + is_mask_pixels = mask_output > 0 + red_mask_layer[is_mask_pixels] = dark_red_color + + # Blend the original image with the red mask layer + alpha = 0.4 # Opacity of the mask + blended_image_np = ((1 - alpha) * original_rgb_np + alpha * red_mask_layer).astype(np.uint8) + + # 4. Save example if enabled + if SAVE_EXAMPLES: + timestamp = int(time.time() * 1000) + filename = f"seg_input_{timestamp}.tif" + dest_path = os.path.join(SEG_EXAMPLE_IMG_DIR, filename) + try: + img_to_save = input_image_np.astype(np.uint8) if input_image_np.dtype != np.uint8 else input_image_np + tifffile.imwrite(dest_path, img_to_save) + print(f"βœ“ New Segmentation example saved: {dest_path}") + except Exception as e: + print(f"βœ— Failed to save segmentation example: {e}") + + print("βœ“ Segmentation complete") + + return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended_image_np, "RGB") + +def run_classification(input_image_np, model_name): + """ + Runs classification on a single image using a pre-trained ResNet50 model. + """ + if input_image_np is None: + raise gr.Error("Please upload an image to classify.") + + model_dir = CLS_MODEL_PATHS.get(model_name) + if not model_dir: + raise gr.Error(f"Classification model '{model_name}' not found.") + + model_path = os.path.join(model_dir, "best_resnet50.pth") + if not os.path.exists(model_path): + raise gr.Error(f"Model file not found at {model_path}. Please check the configuration.") + + print(f"\nTask started... | Task: Classification | Model: '{model_name}'") + + # 1. Load Model + try: + model = models.resnet50(weights=None) + num_features = model.fc.in_features + model.fc = nn.Linear(num_features, len(CLS_CLASS_NAMES)) + model.load_state_dict(torch.load(model_path, map_location=DEVICE)) + model.to(DEVICE) + model.eval() + except Exception as e: + raise gr.Error(f"Failed to load classification model. Error: {e}") + + # 2. Preprocess Image + # Grayscale numpy -> RGB PIL -> transform -> tensor + input_pil = numpy_to_pil(input_image_np, "RGB") + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # ResNet needs 3-channel norm + ]) + input_tensor = transform_test(input_pil).unsqueeze(0).to(DEVICE) + + # 3. Perform Inference + with torch.no_grad(): + outputs = model(input_tensor) + probabilities = F.softmax(outputs, dim=1).squeeze().cpu().numpy() + + # 4. Format output for Gradio Label component + confidences = {name: float(prob) for name, prob in zip(CLS_CLASS_NAMES, probabilities)} + + # 5. Save example + if SAVE_EXAMPLES: + timestamp = int(time.time() * 1000) + filename = f"cls_input_{timestamp}.png" # Save as png for compatibility + dest_path = os.path.join(CLS_EXAMPLE_IMG_DIR, filename) + try: + input_pil.save(dest_path) + print(f"βœ“ New Classification example saved: {dest_path}") + except Exception as e: + print(f"βœ— Failed to save classification example: {e}") + + print("βœ“ Classification complete") + + return numpy_to_pil(input_image_np, "L"), confidences + + +# --- 3. 
Gradio UI Layout --- +print("Building Gradio interface...") +# Create directories for all example types +os.makedirs(M2I_EXAMPLE_IMG_DIR, exist_ok=True) +os.makedirs(T2I_EXAMPLE_IMG_DIR, exist_ok=True) +os.makedirs(SR_EXAMPLE_IMG_DIR, exist_ok=True) +os.makedirs(DN_EXAMPLE_IMG_DIR, exist_ok=True) +os.makedirs(SEG_EXAMPLE_IMG_DIR, exist_ok=True) +os.makedirs(CLS_EXAMPLE_IMG_DIR, exist_ok=True) + +# --- Load examples --- +filename_to_prompt_map = { sanitize_prompt_for_filename(prompt): prompt for prompt in T2I_PROMPTS } +t2i_gallery_examples = [] +for filename in os.listdir(T2I_EXAMPLE_IMG_DIR): + if filename in filename_to_prompt_map: + filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, filename) + prompt = filename_to_prompt_map[filename] + t2i_gallery_examples.append((filepath, prompt)) + +def load_image_examples(example_dir, is_stack=False): + examples = [] + if not os.path.exists(example_dir): return examples + for f in sorted(os.listdir(example_dir)): + if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')): + filepath = os.path.join(example_dir, f) + try: + if f.lower().endswith(('.tif', '.tiff')): + img_np = tifffile.imread(filepath) + else: + img_np = np.array(Image.open(filepath).convert("L")) + + if is_stack and img_np.ndim == 3: + img_np = np.mean(img_np, axis=0) + + display_img = numpy_to_pil(img_np, "L") + examples.append((display_img, filepath)) + except Exception as e: + print(f"Warning: Could not load gallery image {filepath}. Error: {e}") + return examples + +m2i_gallery_examples = load_image_examples(M2I_EXAMPLE_IMG_DIR) +sr_gallery_examples = load_image_examples(SR_EXAMPLE_IMG_DIR, is_stack=True) +dn_gallery_examples = load_image_examples(DN_EXAMPLE_IMG_DIR) +seg_gallery_examples = load_image_examples(SEG_EXAMPLE_IMG_DIR) +cls_gallery_examples = load_image_examples(CLS_EXAMPLE_IMG_DIR) + +# --- Universal event handlers --- +def select_example_prompt(evt: gr.SelectData): + return evt.value['caption'] + +def select_example_input_file(evt: gr.SelectData): + return evt.value['caption'] + +with gr.Blocks(theme=gr.themes.Soft()) as demo: + with gr.Row(): + gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False) + gr.Markdown(f"# {MODEL_TITLE}\n{MODEL_DESCRIPTION}") + + with gr.Tabs(): + # --- TAB 1: Mask-to-Image --- + with gr.Tab("Mask-to-Image", id="mask2img"): + gr.Markdown(""" + ### Instructions + 1. Upload a single-channel segmentation mask (`.tif` file), or select one from the examples gallery below. + 2. Enter the corresponding 'Cell Type' (e.g., 'CoNSS', 'HeLa') to create the prompt. + 3. Select how many sample images you want to generate. + 4. Adjust 'Inference Steps' and 'Seed' as needed. + 5. Click 'Generate Training Samples' to start the process. + 6. The 'Generated Samples' will appear in the main gallery, with the 'Input Mask' shown below for reference. 
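            A minimal sketch (the array contents and 512Γ—512 size are placeholders) of saving a single-channel label mask in the format this tab reads:

            ```python
            import numpy as np
            import tifffile

            # One integer label per cell, 0 = background; any single-channel 2-D array works.
            mask = np.zeros((512, 512), dtype=np.uint16)
            tifffile.imwrite("mask.tif", mask)
            ```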
+ """) # Content hidden for brevity + with gr.Row(variant="panel"): + with gr.Column(scale=1, min_width=350): + m2i_input_file = gr.File(label="Upload Segmentation Mask (.tif)", file_types=['.tif', '.tiff']) + m2i_cell_type_input = gr.Textbox(label="Cell Type (for prompt)", placeholder="e.g., CoNSS, HeLa, MCF-7") + m2i_num_images_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Images to Generate") + m2i_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps") + m2i_seed_input = gr.Number(label="Seed", value=42) + m2i_generate_button = gr.Button("Generate Training Samples", variant="primary") + with gr.Column(scale=2): + m2i_output_gallery = gr.Gallery(label="Generated Samples", columns=5, object_fit="contain", height="auto") + m2i_input_display = gr.Image(label="Input Mask", type="pil", interactive=False) + m2i_gallery = gr.Gallery(value=m2i_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto") + + # --- TAB 2: Text-to-Image --- + with gr.Tab("Text-to-Image Generation", id="txt2img"): + gr.Markdown(""" + ### Instructions + 1. Select a desired prompt from the dropdown menu. + 2. Adjust the 'Inference Steps' slider to control generation quality. + 3. Click the 'Generate' button to create a new image. + 4. Explore the 'Examples' gallery; clicking an image will load its prompt. + + **Notice:** This model currently supports 3566 prompt categories. However, data for many cell structures and lines is still lacking. We welcome data source contributions to improve the model. + """) # Content hidden for brevity + with gr.Row(variant="panel"): + with gr.Column(scale=1, min_width=350): + t2i_prompt_input = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Select a Prompt") + t2i_steps_slider = gr.Slider(minimum=10, maximum=200, step=1, value=50, label="Inference Steps") + t2i_generate_button = gr.Button("Generate", variant="primary") + with gr.Column(scale=2): + t2i_generated_output = gr.Image(label="Generated Image", type="pil", interactive=False) + t2i_gallery = gr.Gallery(value=t2i_gallery_examples, label="Examples (Click an image to use its prompt)", columns=6, object_fit="contain", height="auto") + + # --- TAB 3: Super-Resolution --- + with gr.Tab("Super-Resolution", id="super_res"): + gr.Markdown(""" + ### Instructions + 1. Upload a low-resolution 9-channel TIF stack, or select one from the examples. + 2. Select a 'Super-Resolution Model' from the dropdown. + 3. Enter a descriptive 'Prompt' related to the image content (e.g., 'CCPs of COS-7'). + 4. Adjust 'Inference Steps' and 'Seed' as needed. + 5. Click 'Generate Super-Resolution' to process the image. + + **Notice:** This model was trained on the **BioSR** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results. 
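            A minimal sketch (file names and folder are placeholders) of assembling the expected 9-frame, 16-bit input stack with `tifffile`:

            ```python
            import glob
            import numpy as np
            import tifffile

            # Nine low-resolution frames of the same field of view, saved as individual TIFs.
            frames = [tifffile.imread(p) for p in sorted(glob.glob("low_res_frames/*.tif"))[:9]]
            stack = np.stack(frames, axis=0).astype(np.uint16)  # shape (9, H, W); the app rescales by 65535
            tifffile.imwrite("input_stack.tif", stack)
            ```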
+ """) # Content hidden for brevity + with gr.Row(variant="panel"): + with gr.Column(scale=1, min_width=350): + sr_input_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff']) + sr_model_selector = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Select Super-Resolution Model") + sr_prompt_input = gr.Textbox(label="Prompt (e.g., structure name)", value="CCPs of COS-7") + sr_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps") + sr_seed_input = gr.Number(label="Seed", value=42) + sr_generate_button = gr.Button("Generate Super-Resolution", variant="primary") + with gr.Column(scale=2): + with gr.Row(): + sr_input_display = gr.Image(label="Input (Average Projection)", type="pil", interactive=False) + sr_output_image = gr.Image(label="Super-Resolved Image", type="pil", interactive=False) + sr_gallery = gr.Gallery(value=sr_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto") + + # --- TAB 4: Denoising --- + with gr.Tab("Denoising", id="denoising"): + gr.Markdown(""" + ### Instructions + 1. Upload a noisy single-channel image, or select one from the examples. + 2. Select the 'Image Type' from the dropdown to provide context for the model. + 3. Adjust 'Inference Steps' and 'Seed' as needed. + 4. Click 'Denoise Image' to reduce the noise. + + **Notice:** This model was trained on the **FMD** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results. + """) # Content hidden for brevity + with gr.Row(variant="panel"): + with gr.Column(scale=1, min_width=350): + dn_input_image = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L") + dn_image_type_selector = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Select Image Type (for Prompt)") + dn_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps") + dn_seed_input = gr.Number(label="Seed", value=42) + dn_generate_button = gr.Button("Denoise Image", variant="primary") + with gr.Column(scale=2): + with gr.Row(): + dn_original_display = gr.Image(label="Original Noisy Image", type="pil", interactive=False) + dn_output_image = gr.Image(label="Denoised Image", type="pil", interactive=False) + dn_gallery = gr.Gallery(value=dn_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto") + + # --- TAB 5: Cell Segmentation --- + with gr.Tab("Cell Segmentation", id="segmentation"): + gr.Markdown(""" + ### Instructions + 1. Upload a single-channel image for segmentation, or select one from the examples. + 2. Select a 'Segmentation Model' from the dropdown menu. + 3. Set the expected 'Diameter' of the cells in pixels. Set to 0 to let the model automatically estimate it. + 4. Adjust 'Flow Threshold' and 'Cell Probability Threshold' for finer control. + 5. Click 'Segment Cells'. The result will be shown as a dark red overlay on the original image. + """) + with gr.Row(variant="panel"): + with gr.Column(scale=1, min_width=350): + gr.Markdown("### 1. 
Inputs & Controls") + seg_input_image = gr.Image(type="numpy", label="Upload Image for Segmentation", image_mode="L") + seg_model_selector = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Select Segmentation Model") + seg_diameter_input = gr.Number(label="Cell Diameter (pixels, 0=auto)", value=30) + seg_flow_slider = gr.Slider(minimum=0.0, maximum=3.0, step=0.1, value=0.4, label="Flow Threshold") + seg_cellprob_slider = gr.Slider(minimum=-6.0, maximum=6.0, step=0.5, value=0.0, label="Cell Probability Threshold") + seg_generate_button = gr.Button("Segment Cells", variant="primary") + with gr.Column(scale=2): + gr.Markdown("### 2. Results") + with gr.Row(): + seg_original_display = gr.Image(label="Original Image", type="pil", interactive=False) + seg_output_image = gr.Image(label="Segmented Image (Overlay)", type="pil", interactive=False) + seg_gallery = gr.Gallery(value=seg_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto") + + # --- NEW TAB 6: Classification --- + with gr.Tab("Classification", id="classification"): + gr.Markdown(""" + ### Instructions + 1. Upload a single-channel image for classification, or select an example. + 2. Select a pre-trained 'Classification Model' from the dropdown menu. + 3. Click 'Classify Image' to view the prediction probabilities for each class. + + **Note:** The models provided are ResNet50 trained on the 2D HeLa dataset. + """) + with gr.Row(variant="panel"): + with gr.Column(scale=1, min_width=350): + gr.Markdown("### 1. Inputs & Controls") + cls_input_image = gr.Image(type="numpy", label="Upload Image for Classification", image_mode="L") + cls_model_selector = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Select Classification Model") + cls_generate_button = gr.Button("Classify Image", variant="primary") + with gr.Column(scale=2): + gr.Markdown("### 2. 
Results") + cls_original_display = gr.Image(label="Input Image", type="pil", interactive=False) + cls_output_label = gr.Label(label="Classification Results", num_top_classes=len(CLS_CLASS_NAMES)) + cls_gallery = gr.Gallery(value=cls_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto") + + + # --- Event Handlers --- + m2i_generate_button.click(fn=run_mask_to_image_generation, inputs=[m2i_input_file, m2i_cell_type_input, m2i_num_images_slider, m2i_steps_slider, m2i_seed_input], outputs=[m2i_input_display, m2i_output_gallery]) + m2i_gallery.select(fn=select_example_input_file, outputs=m2i_input_file) + + t2i_generate_button.click(fn=generate_t2i, inputs=[t2i_prompt_input, t2i_steps_slider], outputs=[t2i_generated_output]) + t2i_gallery.select(fn=select_example_prompt, outputs=t2i_prompt_input) + + sr_generate_button.click(fn=run_super_resolution, inputs=[sr_input_file, sr_model_selector, sr_prompt_input, sr_steps_slider, sr_seed_input], outputs=[sr_input_display, sr_output_image]) + sr_gallery.select(fn=select_example_input_file, outputs=sr_input_file) + + dn_generate_button.click(fn=run_denoising, inputs=[dn_input_image, dn_image_type_selector, dn_steps_slider, dn_seed_input], outputs=[dn_original_display, dn_output_image]) + dn_gallery.select(fn=select_example_input_file, outputs=dn_input_image) + + seg_generate_button.click(fn=run_segmentation, inputs=[seg_input_image, seg_model_selector, seg_diameter_input, seg_flow_slider, seg_cellprob_slider], outputs=[seg_original_display, seg_output_image]) + seg_gallery.select(fn=select_example_input_file, outputs=seg_input_image) + + cls_generate_button.click(fn=run_classification, inputs=[cls_input_image, cls_model_selector], outputs=[cls_original_display, cls_output_label]) + cls_gallery.select(fn=select_example_input_file, outputs=cls_input_image) + + +# --- 4. Launch Application --- +if __name__ == "__main__": + print("Interface built. Launching server...") + demo.launch() \ No newline at end of file diff --git a/cache/00000001.tif b/cache/00000001.tif new file mode 100644 index 0000000000000000000000000000000000000000..653f4cacdb2d1d85957fbb005ecad12406ecc6f3 --- /dev/null +++ b/cache/00000001.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:436a899290607814ba5f5c0e9456e9d6551bb2cb986223aff7fc4ae396d7a86a +size 1181232 diff --git a/cellpose/__init__.py b/cellpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9750249a2a5ab1f0c17bbfb0c26e3e19b88f13 --- /dev/null +++ b/cellpose/__init__.py @@ -0,0 +1 @@ +from cellpose.version import version, version_str diff --git a/cellpose/__main__.py b/cellpose/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..fea1d9c9b58cfef9831cbfaeb5b2e3fa85b2c9cb --- /dev/null +++ b/cellpose/__main__.py @@ -0,0 +1,380 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 
+""" + +import sys, os, glob, pathlib, time +import numpy as np +from natsort import natsorted +from tqdm import tqdm +from cellpose import utils, models, io, version_str, train, denoise +from cellpose.cli import get_arg_parser + +try: + from cellpose.gui import gui3d, gui + GUI_ENABLED = True +except ImportError as err: + GUI_ERROR = err + GUI_ENABLED = False + GUI_IMPORT = True +except Exception as err: + GUI_ENABLED = False + GUI_ERROR = err + GUI_IMPORT = False + raise + +import logging + + +# settings re-grouped a bit +def main(): + """ Run cellpose from command line + """ + + args = get_arg_parser().parse_args( + ) # this has to be in a separate file for autodoc to work + + if args.version: + print(version_str) + return + + if args.check_mkl: + mkl_enabled = models.check_mkl() + else: + mkl_enabled = True + + if len(args.dir) == 0 and len(args.image_path) == 0: + if args.add_model: + io.add_model(args.add_model) + else: + if not GUI_ENABLED: + print("GUI ERROR: %s" % GUI_ERROR) + if GUI_IMPORT: + print( + "GUI FAILED: GUI dependencies may not be installed, to install, run" + ) + print(" pip install 'cellpose[gui]'") + else: + if args.Zstack: + gui3d.run() + else: + gui.run() + + else: + if args.verbose: + from .io import logger_setup + logger, log_file = logger_setup() + else: + print( + ">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose") + print("No --verbose => no progress or info printed") + logger = logging.getLogger(__name__) + + use_gpu = False + channels = [args.chan, args.chan2] + + # find images + if len(args.img_filter) > 0: + imf = args.img_filter + else: + imf = None + + # Check with user if they REALLY mean to run without saving anything + if not (args.train or args.train_size): + saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt + + device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu, + device=args.gpu_device) + + if args.pretrained_model is None or args.pretrained_model == "None" or args.pretrained_model == "False" or args.pretrained_model == "0": + pretrained_model = False + else: + pretrained_model = args.pretrained_model + + restore_type = args.restore_type + if restore_type is not None: + try: + denoise.model_path(restore_type) + except Exception as e: + raise ValueError("restore_type invalid") + if args.train or args.train_size: + raise ValueError("restore_type cannot be used with training on CLI yet") + + if args.transformer and (restore_type is None): + default_model = "transformer_cp3" + backbone = "transformer" + elif args.transformer and restore_type is not None: + raise ValueError("no transformer based restoration") + else: + default_model = "cyto3" + backbone = "default" + + if args.norm_percentile is not None: + value1, value2 = args.norm_percentile + normalize = {'percentile': (float(value1), float(value2))} + else: + normalize = (not args.no_norm) + + + model_type = None + if pretrained_model and not os.path.exists(pretrained_model): + model_type = pretrained_model if pretrained_model is not None else "cyto3" + model_strings = models.get_user_models() + all_models = models.MODEL_NAMES.copy() + all_models.extend(model_strings) + if ~np.any([model_type == s for s in all_models]): + model_type = default_model + logger.warning( + f"pretrained model has incorrect path, using {default_model}") + if model_type == "nuclei": + szmean = 17. + else: + szmean = 30. 
+ builtin_size = (model_type == "cyto" or model_type == "cyto2" or + model_type == "nuclei" or model_type == "cyto3") + + if len(args.image_path) > 0 and (args.train or args.train_size): + raise ValueError("ERROR: cannot train model with single image input") + + if not args.train and not args.train_size: + tic = time.time() + if len(args.dir) > 0: + image_names = io.get_image_files( + args.dir, args.mask_filter, imf=imf, + look_one_level_down=args.look_one_level_down) + else: + if os.path.exists(args.image_path): + image_names = [args.image_path] + else: + raise ValueError(f"ERROR: no file found at {args.image_path}") + nimg = len(image_names) + + if args.savedir: + if not os.path.exists(args.savedir): + raise FileExistsError("--savedir {args.savedir} does not exist") + + cstr0 = ["GRAY", "RED", "GREEN", "BLUE"] + cstr1 = ["NONE", "RED", "GREEN", "BLUE"] + logger.info( + ">>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s" + % (nimg, cstr0[channels[0]], cstr1[channels[1]])) + + # handle built-in model exceptions + if builtin_size and restore_type is None and not args.pretrained_model_ortho: + model = models.Cellpose(gpu=gpu, device=device, model_type=model_type, + backbone=backbone) + else: + builtin_size = False + if args.all_channels: + channels = None + img = io.imread(image_names[0]) + if img.ndim == 3: + nchan = min(img.shape) + elif img.ndim == 2: + nchan = 1 + channels = None + else: + nchan = 2 + + pretrained_model = None if model_type is not None else pretrained_model + if restore_type is None: + pretrained_model_ortho = None if args.pretrained_model_ortho is None else args.pretrained_model_ortho + model = models.CellposeModel(gpu=gpu, device=device, + pretrained_model=pretrained_model, + model_type=model_type, + nchan=nchan, + backbone=backbone, + pretrained_model_ortho=pretrained_model_ortho) + else: + model = denoise.CellposeDenoiseModel( + gpu=gpu, device=device, pretrained_model=pretrained_model, + model_type=model_type, restore_type=restore_type, nchan=nchan, + chan2_restore=args.chan2_restore) + + # handle diameters + if args.diameter == 0: + if builtin_size: + diameter = None + logger.info(">>>> estimating diameter for each image") + else: + if restore_type is None: + logger.info( + ">>>> not using cyto3, cyto, cyto2, or nuclei model, cannot auto-estimate diameter" + ) + else: + logger.info( + ">>>> cannot auto-estimate diameter for image restoration") + diameter = model.diam_labels + logger.info(">>>> using diameter %0.3f for all images" % diameter) + else: + diameter = args.diameter + logger.info(">>>> using diameter %0.3f for all images" % diameter) + + tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO) + + for image_name in tqdm(image_names, file=tqdm_out): + image = io.imread(image_name) + out = model.eval( + image, channels=channels, diameter=diameter, do_3D=args.do_3D, + augment=args.augment, resample=(not args.no_resample), + flow_threshold=args.flow_threshold, + cellprob_threshold=args.cellprob_threshold, + stitch_threshold=args.stitch_threshold, min_size=args.min_size, + invert=args.invert, batch_size=args.batch_size, + interp=(not args.no_interp), normalize=normalize, + channel_axis=args.channel_axis, z_axis=args.z_axis, + anisotropy=args.anisotropy, niter=args.niter, + flow3D_smooth=args.flow3D_smooth) + masks, flows = out[:2] + if len(out) > 3 and restore_type is None: + diams = out[-1] + else: + diams = diameter + ratio = 1. 
+ if restore_type is not None: + imgs_dn = out[-1] + ratio = diams / model.dn.diam_mean if "upsample" in restore_type else 1. + diams = model.dn.diam_mean if "upsample" in restore_type and model.dn.diam_mean > diams else diams + else: + imgs_dn = None + if args.exclude_on_edges: + masks = utils.remove_edge_masks(masks) + if not args.no_npy: + io.masks_flows_to_seg(image, masks, flows, image_name, + imgs_restore=imgs_dn, channels=channels, + diams=diams, restore_type=restore_type, + ratio=1.) + if saving_something: + suffix = "_cp_masks" + if args.output_name is not None: + # (1) If `savedir` is not defined, then must have a non-zero `suffix` + if args.savedir is None and len(args.output_name) > 0: + suffix = args.output_name + elif args.savedir is not None and not os.path.samefile(args.savedir, args.dir): + # (2) If `savedir` is defined, and different from `dir` then + # takes the value passed as a param. (which can be empty string) + suffix = args.output_name + + io.save_masks(image, masks, flows, image_name, + suffix=suffix, png=args.save_png, + tif=args.save_tif, save_flows=args.save_flows, + save_outlines=args.save_outlines, + dir_above=args.dir_above, savedir=args.savedir, + save_txt=args.save_txt, in_folders=args.in_folders, + save_mpl=args.save_mpl) + if args.save_rois: + io.save_rois(masks, image_name) + logger.info(">>>> completed in %0.3f sec" % (time.time() - tic)) + else: + test_dir = None if len(args.test_dir) == 0 else args.test_dir + images, labels, image_names, train_probs = None, None, None, None + test_images, test_labels, image_names_test, test_probs = None, None, None, None + compute_flows = False + if len(args.file_list) > 0: + if os.path.exists(args.file_list): + dat = np.load(args.file_list, allow_pickle=True).item() + image_names = dat["train_files"] + image_names_test = dat.get("test_files", None) + train_probs = dat.get("train_probs", None) + test_probs = dat.get("test_probs", None) + compute_flows = dat.get("compute_flows", False) + load_files = False + else: + logger.critical(f"ERROR: {args.file_list} does not exist") + else: + output = io.load_train_test_data(args.dir, test_dir, imf, + args.mask_filter, + args.look_one_level_down) + images, labels, image_names, test_images, test_labels, image_names_test = output + load_files = True + + # training with all channels + if args.all_channels: + img = images[0] if images is not None else io.imread(image_names[0]) + if img.ndim == 3: + nchan = min(img.shape) + elif img.ndim == 2: + nchan = 1 + channels = None + else: + nchan = 2 + + # model path + szmean = args.diam_mean + if not os.path.exists(pretrained_model) and model_type is None: + if not args.train: + error_message = "ERROR: model path missing or incorrect - cannot train size model" + logger.critical(error_message) + raise ValueError(error_message) + pretrained_model = False + logger.info(">>>> training from scratch") + if args.train: + logger.info( + ">>>> during training rescaling images to fixed diameter of %0.1f pixels" + % args.diam_mean) + + # initialize model + model = models.CellposeModel( + device=device, model_type=model_type, diam_mean=szmean, nchan=nchan, + pretrained_model=pretrained_model if model_type is None else None, + backbone=backbone) + + # train segmentation model + if args.train: + cpmodel_path = train.train_seg( + model.net, images, labels, train_files=image_names, + test_data=test_images, test_labels=test_labels, + test_files=image_names_test, train_probs=train_probs, + test_probs=test_probs, compute_flows=compute_flows, + 
load_files=load_files, normalize=normalize, + channels=channels, channel_axis=args.channel_axis, rgb=(nchan == 3), + learning_rate=args.learning_rate, weight_decay=args.weight_decay, + SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.batch_size, + min_train_masks=args.min_train_masks, + nimg_per_epoch=args.nimg_per_epoch, + nimg_test_per_epoch=args.nimg_test_per_epoch, + save_path=os.path.realpath(args.dir), save_every=args.save_every, + model_name=args.model_name_out)[0] + model.pretrained_model = cpmodel_path + logger.info(">>>> model trained and saved to %s" % cpmodel_path) + + # train size model + if args.train_size: + sz_model = models.SizeModel(cp_model=model, device=device) + # data has already been normalized and reshaped + sz_model.params = train.train_size( + model.net, model.pretrained_model, images, labels, + train_files=image_names, test_data=test_images, + test_labels=test_labels, test_files=image_names_test, + train_probs=train_probs, test_probs=test_probs, + load_files=load_files, channels=channels, + min_train_masks=args.min_train_masks, + channel_axis=args.channel_axis, rgb=(nchan == 3), + nimg_per_epoch=args.nimg_per_epoch, normalize=normalize, + nimg_test_per_epoch=args.nimg_test_per_epoch, + batch_size=args.batch_size) + if test_images is not None: + test_masks = [lbl[0] for lbl in test_labels + ] if test_labels is not None else test_labels + predicted_diams, diams_style = sz_model.eval( + test_images, channels=channels) + ccs = np.corrcoef( + diams_style, + np.array([utils.diameters(lbl)[0] for lbl in test_masks]))[0, 1] + cc = np.corrcoef( + predicted_diams, + np.array([utils.diameters(lbl)[0] for lbl in test_masks]))[0, 1] + logger.info( + "style test correlation: %0.4f; final test correlation: %0.4f" % + (ccs, cc)) + np.save( + os.path.join( + args.test_dir, + "%s_predicted_diams.npy" % os.path.split(cpmodel_path)[1]), + { + "predicted_diams": predicted_diams, + "diams_style": diams_style + }) + + +if __name__ == "__main__": + main() diff --git a/cellpose/__pycache__/__init__.cpython-310.pyc b/cellpose/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4711a7c4a12e6f8aa4502c9f3d502f87d6d9610 Binary files /dev/null and b/cellpose/__pycache__/__init__.cpython-310.pyc differ diff --git a/cellpose/__pycache__/__init__.cpython-311.pyc b/cellpose/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddeeae369d0e38bd8427c87ca5708112ff6ed3e8 Binary files /dev/null and b/cellpose/__pycache__/__init__.cpython-311.pyc differ diff --git a/cellpose/__pycache__/__init__.cpython-312.pyc b/cellpose/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13d99532b2ec3a9109f7cd12b63171cfde0be08c Binary files /dev/null and b/cellpose/__pycache__/__init__.cpython-312.pyc differ diff --git a/cellpose/__pycache__/core.cpython-310.pyc b/cellpose/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82c745924bf4c29d4ac6d11dd6f81a8fbd22952f Binary files /dev/null and b/cellpose/__pycache__/core.cpython-310.pyc differ diff --git a/cellpose/__pycache__/core.cpython-311.pyc b/cellpose/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a7d57f5bddc8f4d1b4878b43611d5c17561c039 Binary files /dev/null and b/cellpose/__pycache__/core.cpython-311.pyc differ diff --git a/cellpose/__pycache__/core.cpython-312.pyc 
b/cellpose/__pycache__/core.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e760a67005aab4954dbd3b3a4e1a646c5eb1cdbe Binary files /dev/null and b/cellpose/__pycache__/core.cpython-312.pyc differ diff --git a/cellpose/__pycache__/dynamics.cpython-310.pyc b/cellpose/__pycache__/dynamics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21dcc91e54563d7df02f963113c7c9f8ab948c0a Binary files /dev/null and b/cellpose/__pycache__/dynamics.cpython-310.pyc differ diff --git a/cellpose/__pycache__/dynamics.cpython-311.pyc b/cellpose/__pycache__/dynamics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b36365a8468f847a3bb8c0cab1e7b4be6064276 Binary files /dev/null and b/cellpose/__pycache__/dynamics.cpython-311.pyc differ diff --git a/cellpose/__pycache__/dynamics.cpython-312.pyc b/cellpose/__pycache__/dynamics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b270ed4b6adb3d085c4ac268f8ec3735fa71362 Binary files /dev/null and b/cellpose/__pycache__/dynamics.cpython-312.pyc differ diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py310.1.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.1.nbc new file mode 100644 index 0000000000000000000000000000000000000000..ed51d53bcd4f6c0cd8a0490d274e82ca7a6cf445 Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.1.nbc differ diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py310.2.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.2.nbc new file mode 100644 index 0000000000000000000000000000000000000000..c95b9a6259ad7a6987b7d0ac6241b0ea5bf3a4d9 Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.2.nbc differ diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py310.nbi b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.nbi new file mode 100644 index 0000000000000000000000000000000000000000..259b67b759402723411cfd4b214b242c3090c993 Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py310.nbi differ diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py311.1.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.1.nbc new file mode 100644 index 0000000000000000000000000000000000000000..9843e74d95f2fff5da4d7d6ccf0c29ca341ba374 Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.1.nbc differ diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py311.2.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.2.nbc new file mode 100644 index 0000000000000000000000000000000000000000..5416c7461c6c86d56c1b66b70e86ab624fe9211e Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.2.nbc differ diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py311.nbi b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.nbi new file mode 100644 index 0000000000000000000000000000000000000000..9f7c8fd3edc8a1da3f66f1018a8f22699f8fcb38 Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py311.nbi differ diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py312.1.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.1.nbc new file mode 100644 index 0000000000000000000000000000000000000000..4969bd00d68697357f57bdda43414d52aeaaf8e3 Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.1.nbc differ diff 
--git a/cellpose/__pycache__/dynamics.map_coordinates-414.py312.2.nbc b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.2.nbc new file mode 100644 index 0000000000000000000000000000000000000000..3d771a418ddc10de51f7a48efb324757f8179b92 Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.2.nbc differ diff --git a/cellpose/__pycache__/dynamics.map_coordinates-414.py312.nbi b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.nbi new file mode 100644 index 0000000000000000000000000000000000000000..e6ee3eb368c8ee8273c2a70a5ee37a13449210d4 Binary files /dev/null and b/cellpose/__pycache__/dynamics.map_coordinates-414.py312.nbi differ diff --git a/cellpose/__pycache__/io.cpython-310.pyc b/cellpose/__pycache__/io.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ed96f8b731c319779100d8fa73bd40499947498 Binary files /dev/null and b/cellpose/__pycache__/io.cpython-310.pyc differ diff --git a/cellpose/__pycache__/io.cpython-311.pyc b/cellpose/__pycache__/io.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6142b4480f8621f72ba9b8acd7f2871683de0cff Binary files /dev/null and b/cellpose/__pycache__/io.cpython-311.pyc differ diff --git a/cellpose/__pycache__/io.cpython-312.pyc b/cellpose/__pycache__/io.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..891a7298846644d03cf85f355f228443594bbe70 Binary files /dev/null and b/cellpose/__pycache__/io.cpython-312.pyc differ diff --git a/cellpose/__pycache__/metrics.cpython-310.pyc b/cellpose/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd2d6f9b9450218669b43cd37544e09357039d53 Binary files /dev/null and b/cellpose/__pycache__/metrics.cpython-310.pyc differ diff --git a/cellpose/__pycache__/metrics.cpython-311.pyc b/cellpose/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c53ac1a8dcbf08920123dae8abb128af34405f9 Binary files /dev/null and b/cellpose/__pycache__/metrics.cpython-311.pyc differ diff --git a/cellpose/__pycache__/metrics.cpython-312.pyc b/cellpose/__pycache__/metrics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a861397f990129ee506a4b7b52344b62721a152 Binary files /dev/null and b/cellpose/__pycache__/metrics.cpython-312.pyc differ diff --git a/cellpose/__pycache__/models.cpython-310.pyc b/cellpose/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b16eb38aa02d62e6bbda76c6c5642d4137ab5aa5 Binary files /dev/null and b/cellpose/__pycache__/models.cpython-310.pyc differ diff --git a/cellpose/__pycache__/models.cpython-311.pyc b/cellpose/__pycache__/models.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ac605152416c732b961eccfe4387a6e814c62c0 Binary files /dev/null and b/cellpose/__pycache__/models.cpython-311.pyc differ diff --git a/cellpose/__pycache__/models.cpython-312.pyc b/cellpose/__pycache__/models.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c085d2d24f7ede30539e3cd15cd330966108bd24 Binary files /dev/null and b/cellpose/__pycache__/models.cpython-312.pyc differ diff --git a/cellpose/__pycache__/plot.cpython-310.pyc b/cellpose/__pycache__/plot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bea04fb0793a5e924ec2fab9d8d5feb66b6d02b9 Binary files /dev/null and 
b/cellpose/__pycache__/plot.cpython-310.pyc differ diff --git a/cellpose/__pycache__/plot.cpython-311.pyc b/cellpose/__pycache__/plot.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a93aed81d324ac260ece773da03a7c6fa1511d8f Binary files /dev/null and b/cellpose/__pycache__/plot.cpython-311.pyc differ diff --git a/cellpose/__pycache__/plot.cpython-312.pyc b/cellpose/__pycache__/plot.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3f38da5a63c1811a31e44a2699e4e4c5e159a4a Binary files /dev/null and b/cellpose/__pycache__/plot.cpython-312.pyc differ diff --git a/cellpose/__pycache__/resnet_torch.cpython-310.pyc b/cellpose/__pycache__/resnet_torch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24fb000cc1ed4529bcaffcb983390ea09d566d04 Binary files /dev/null and b/cellpose/__pycache__/resnet_torch.cpython-310.pyc differ diff --git a/cellpose/__pycache__/resnet_torch.cpython-311.pyc b/cellpose/__pycache__/resnet_torch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2758c2293455c7b70986aa6c8bb78c7b01c8cae Binary files /dev/null and b/cellpose/__pycache__/resnet_torch.cpython-311.pyc differ diff --git a/cellpose/__pycache__/resnet_torch.cpython-312.pyc b/cellpose/__pycache__/resnet_torch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35561f86b3574c71f7f818ddad2ed4a7d216ff28 Binary files /dev/null and b/cellpose/__pycache__/resnet_torch.cpython-312.pyc differ diff --git a/cellpose/__pycache__/train.cpython-310.pyc b/cellpose/__pycache__/train.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b6ce3df8c32e54ea01d887c24390955c1de2a3d Binary files /dev/null and b/cellpose/__pycache__/train.cpython-310.pyc differ diff --git a/cellpose/__pycache__/train.cpython-312.pyc b/cellpose/__pycache__/train.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45c9a129f7e6baf921b9bbdd16ff191d8f5c176e Binary files /dev/null and b/cellpose/__pycache__/train.cpython-312.pyc differ diff --git a/cellpose/__pycache__/transforms.cpython-310.pyc b/cellpose/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1731f5d9dfa4998fe91909fd42344823daabdff4 Binary files /dev/null and b/cellpose/__pycache__/transforms.cpython-310.pyc differ diff --git a/cellpose/__pycache__/transforms.cpython-311.pyc b/cellpose/__pycache__/transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0fc835a45f7f66d886178f500052b34407a8556 Binary files /dev/null and b/cellpose/__pycache__/transforms.cpython-311.pyc differ diff --git a/cellpose/__pycache__/transforms.cpython-312.pyc b/cellpose/__pycache__/transforms.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f39e7e28f94f97148fde5e9d99f7bca85b9bbd83 Binary files /dev/null and b/cellpose/__pycache__/transforms.cpython-312.pyc differ diff --git a/cellpose/__pycache__/utils.cpython-310.pyc b/cellpose/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18ab946e07853c6263cf66097bf59a06510210d1 Binary files /dev/null and b/cellpose/__pycache__/utils.cpython-310.pyc differ diff --git a/cellpose/__pycache__/utils.cpython-311.pyc b/cellpose/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05bbb4980db6fa0026a3bebf2883b3beb3a11186 
Binary files /dev/null and b/cellpose/__pycache__/utils.cpython-311.pyc differ diff --git a/cellpose/__pycache__/utils.cpython-312.pyc b/cellpose/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88a08e18a7343c4a14dc9e03b54903441ebe527b Binary files /dev/null and b/cellpose/__pycache__/utils.cpython-312.pyc differ diff --git a/cellpose/__pycache__/version.cpython-310.pyc b/cellpose/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eaf07082f1d5e1f23571e30a1ba9ea8a06a3fb1 Binary files /dev/null and b/cellpose/__pycache__/version.cpython-310.pyc differ diff --git a/cellpose/__pycache__/version.cpython-311.pyc b/cellpose/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51fd4cf427242d59d7f210b8d974859855079c02 Binary files /dev/null and b/cellpose/__pycache__/version.cpython-311.pyc differ diff --git a/cellpose/__pycache__/version.cpython-312.pyc b/cellpose/__pycache__/version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f6544dc4596e8b2c63931d06fdadc0bcfa4b3fe Binary files /dev/null and b/cellpose/__pycache__/version.cpython-312.pyc differ diff --git a/cellpose/cli.py b/cellpose/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..be31d06dbfd2970adac0decb050c4bdee4e91a81 --- /dev/null +++ b/cellpose/cli.py @@ -0,0 +1,231 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden. +""" + +import argparse + + +def get_arg_parser(): + """ Parses command line arguments for cellpose main function + + Note: this function has to be in a separate file to allow autodoc to work for CLI. 
+ The autodoc_mock_imports in conf.py does not work for sphinx-argparse sometimes, + see https://github.com/ashb/sphinx-argparse/issues/9#issue-1097057823 + """ + + parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters") + + # misc settings + parser.add_argument("--version", action="store_true", + help="show cellpose version info") + parser.add_argument( + "--verbose", action="store_true", + help="show information about running and settings and save to log") + parser.add_argument("--Zstack", action="store_true", help="run GUI in 3D mode") + + # settings for CPU vs GPU + hardware_args = parser.add_argument_group("Hardware Arguments") + hardware_args.add_argument("--use_gpu", action="store_true", + help="use gpu if torch with cuda installed") + hardware_args.add_argument( + "--gpu_device", required=False, default="0", type=str, + help="which gpu device to use, use an integer for torch, or mps for M1") + hardware_args.add_argument("--check_mkl", action="store_true", + help="check if mkl working") + + # settings for locating and formatting images + input_img_args = parser.add_argument_group("Input Image Arguments") + input_img_args.add_argument("--dir", default=[], type=str, + help="folder containing data to run or train on.") + input_img_args.add_argument( + "--image_path", default=[], type=str, help= + "if given and --dir not given, run on single image instead of folder (cannot train with this option)" + ) + input_img_args.add_argument( + "--look_one_level_down", action="store_true", + help="run processing on all subdirectories of current folder") + input_img_args.add_argument("--img_filter", default=[], type=str, + help="end string for images to run on") + input_img_args.add_argument( + "--channel_axis", default=None, type=int, + help="axis of image which corresponds to image channels") + input_img_args.add_argument("--z_axis", default=None, type=int, + help="axis of image which corresponds to Z dimension") + input_img_args.add_argument( + "--chan", default=0, type=int, help= + "channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s") + input_img_args.add_argument( + "--chan2", default=0, type=int, help= + "nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. 
Default: %(default)s" + ) + input_img_args.add_argument("--invert", action="store_true", + help="invert grayscale channel") + input_img_args.add_argument( + "--all_channels", action="store_true", help= + "use all channels in image if using own model and images with special channels") + + # model settings + model_args = parser.add_argument_group("Model Arguments") + model_args.add_argument("--pretrained_model", required=False, default="cyto3", + type=str, + help="model to use for running or starting training") + model_args.add_argument("--restore_type", required=False, default=None, type=str, + help="model to use for image restoration") + model_args.add_argument("--chan2_restore", action="store_true", + help="use nuclei restore model for second channel") + model_args.add_argument( + "--add_model", required=False, default=None, type=str, + help="model path to copy model to hidden .cellpose folder for using in GUI/CLI") + model_args.add_argument( + "--transformer", action="store_true", help= + "use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)") + model_args.add_argument("--pretrained_model_ortho", required=False, default=None, + type=str, + help="model to use for running 3D ortho views (ZY and ZX)") + + # algorithm settings + algorithm_args = parser.add_argument_group("Algorithm Arguments") + algorithm_args.add_argument( + "--no_resample", action="store_true", help= + "disable dynamics on full image (makes algorithm faster for images with large diameters)" + ) + algorithm_args.add_argument( + "--no_interp", action="store_true", + help="do not interpolate when running dynamics (was default)") + algorithm_args.add_argument("--no_norm", action="store_true", + help="do not normalize images (normalize=False)") + parser.add_argument( + '--norm_percentile', + nargs=2, # Require exactly two values + metavar=('VALUE1', 'VALUE2'), + help="Provide two float values to set norm_percentile (e.g., --norm_percentile 1 99)" + ) + algorithm_args.add_argument( + "--do_3D", action="store_true", + help="process images as 3D stacks of images (nplanes x nchan x Ly x Lx") + algorithm_args.add_argument( + "--diameter", required=False, default=30., type=float, help= + "cell diameter, if 0 will use the diameter of the training labels used in the model, or with built-in model will estimate diameter for each image" + ) + algorithm_args.add_argument( + "--stitch_threshold", required=False, default=0.0, type=float, + help="compute masks in 2D then stitch together masks with IoU>0.9 across planes" + ) + algorithm_args.add_argument( + "--min_size", required=False, default=15, type=int, + help="minimum number of pixels per mask, can turn off with -1") + algorithm_args.add_argument( + "--flow3D_smooth", required=False, default=0, type=float, + help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing") + + algorithm_args.add_argument( + "--flow_threshold", default=0.4, type=float, help= + "flow error threshold, 0 turns off this optional QC step. 
Default: %(default)s") + algorithm_args.add_argument( + "--cellprob_threshold", default=0, type=float, + help="cellprob threshold, default is 0, decrease to find more and larger masks") + algorithm_args.add_argument( + "--niter", default=0, type=int, help= + "niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs" + ) + + algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float, + help="anisotropy of volume in 3D") + algorithm_args.add_argument("--exclude_on_edges", action="store_true", + help="discard masks which touch edges of image") + algorithm_args.add_argument( + "--augment", action="store_true", + help="tiles image with overlapping tiles and flips overlapped regions to augment" + ) + + # output settings + output_args = parser.add_argument_group("Output Arguments") + output_args.add_argument( + "--save_png", action="store_true", + help="save masks as png") + output_args.add_argument( + "--save_tif", action="store_true", + help="save masks as tif") + output_args.add_argument( + "--output_name", default=None, type=str, + help="suffix for saved masks, default is _cp_masks, can be empty if `savedir` used and different of `dir`") + output_args.add_argument("--no_npy", action="store_true", + help="suppress saving of npy") + output_args.add_argument( + "--savedir", default=None, type=str, help= + "folder to which segmentation results will be saved (defaults to input image directory)" + ) + output_args.add_argument( + "--dir_above", action="store_true", help= + "save output folders adjacent to image folder instead of inside it (off by default)" + ) + output_args.add_argument("--in_folders", action="store_true", + help="flag to save output in folders (off by default)") + output_args.add_argument( + "--save_flows", action="store_true", help= + "whether or not to save RGB images of flows when masks are saved (disabled by default)" + ) + output_args.add_argument( + "--save_outlines", action="store_true", help= + "whether or not to save RGB outline images when masks are saved (disabled by default)" + ) + output_args.add_argument( + "--save_rois", action="store_true", + help="whether or not to save ImageJ compatible ROI archive (disabled by default)" + ) + output_args.add_argument( + "--save_txt", action="store_true", + help="flag to enable txt outlines for ImageJ (disabled by default)") + output_args.add_argument( + "--save_mpl", action="store_true", + help="save a figure of image/mask/flows using matplotlib (disabled by default). " + "This is slow, especially with large images.") + + # training settings + training_args = parser.add_argument_group("Training Arguments") + training_args.add_argument("--train", action="store_true", + help="train network using images in dir") + training_args.add_argument("--train_size", action="store_true", + help="train size network at end of training") + training_args.add_argument("--test_dir", default=[], type=str, + help="folder containing test data (optional)") + training_args.add_argument( + "--file_list", default=[], type=str, help= + "path to list of files for training and testing and probabilities for each image (optional)" + ) + training_args.add_argument( + "--mask_filter", default="_masks", type=str, help= + "end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. 
Default: %(default)s" + ) + training_args.add_argument( + "--diam_mean", default=30., type=float, help= + "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0" + ) + training_args.add_argument("--learning_rate", default=0.2, type=float, + help="learning rate. Default: %(default)s") + training_args.add_argument("--weight_decay", default=0.00001, type=float, + help="weight decay. Default: %(default)s") + training_args.add_argument("--n_epochs", default=500, type=int, + help="number of epochs. Default: %(default)s") + training_args.add_argument("--batch_size", default=8, type=int, + help="batch size. Default: %(default)s") + training_args.add_argument( + "--nimg_per_epoch", default=None, type=int, + help="number of train images per epoch. Default is to use all train images.") + training_args.add_argument( + "--nimg_test_per_epoch", default=None, type=int, + help="number of test images per epoch. Default is to use all test images.") + training_args.add_argument( + "--min_train_masks", default=5, type=int, help= + "minimum number of masks a training image must have to be used. Default: %(default)s" + ) + training_args.add_argument("--SGD", default=1, type=int, help="use SGD") + training_args.add_argument( + "--save_every", default=100, type=int, + help="number of epochs to skip between saves. Default: %(default)s") + training_args.add_argument( + "--model_name_out", default=None, type=str, + help="Name of model to save as, defaults to name describing model architecture. " + "Model is saved in the folder specified by --dir in models subfolder.") + + return parser diff --git a/cellpose/contrib/distributed_segmentation.py b/cellpose/contrib/distributed_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..42a21e3ad662622d2ee1e64e0a7a628a7011cc13 --- /dev/null +++ b/cellpose/contrib/distributed_segmentation.py @@ -0,0 +1,924 @@ +# stdlib imports +import os, getpass, datetime, pathlib, tempfile, functools, glob + +# non-stdlib core dependencies +import numpy as np +import scipy +import cellpose.io +import cellpose.models +import tifffile +import imagecodecs + +# distributed dependencies +import dask +import distributed +import dask_image.ndmeasure +import yaml +import zarr +import dask_jobqueue + + + + +######################## File format functions ################################ +def numpy_array_to_zarr(write_path, array, chunks): + """ + Store an in memory numpy array to disk as a chunked Zarr array + + Parameters + ---------- + write_path : string + Filepath where Zarr array will be created + + array : numpy.ndarray + The already loaded in-memory numpy array to store as zarr + + chunks : tuple, must be array.ndim length + How the array will be chunked in the Zarr array + + Returns + ------- + zarr.core.Array + A read+write reference to the zarr array on disk + """ + + zarr_array = zarr.open( + write_path, + mode='w', + shape=array.shape, + chunks=chunks, + dtype=array.dtype, + ) + zarr_array[...] = array + return zarr_array + + +def wrap_folder_of_tiffs( + filename_pattern, + block_index_pattern=r'_(Z)(\d+)(Y)(\d+)(X)(\d+)', +): + """ + Wrap a folder of tiff files with a zarr array without duplicating data. + Tiff files must all contain images with the same shape and data type. + Tiff file names must contain a pattern indicating where individual files + lie in the block grid. + + Distributed computing requires parallel access to small regions of your + image from different processes. 
This is best accomplished with chunked + file formats like Zarr and N5. This function can accommodate a folder of + tiff files, but it is not equivalent to reformatting your data as Zarr or N5. + If your individual tiff files/tiles are huge, distributed performance will + be poor or not work at all. + + It does not make sense to use this function if you have only one tiff file. + That tiff file will become the only chunk in the zarr array, which means all + workers will have to load the entire image to fetch their crop of data anyway. + If you have a single tiff image, you should just reformat it with the + numpy_array_to_zarr function. Single tiff files too large to fit into system + memory are not supported. + + Parameters + ---------- + filename_pattern : string + A glob pattern that will match all needed tif files + + block_index_pattern : regular expression string (default: r'_(Z)(\d+)(Y)(\d+)(X)(\d+)') + A regular expression pattern that indicates how to parse tiff filenames + to determine where each tiff file lies in the overall block grid + The default pattern assumes filenames like the following: + {any_prefix}_Z000Y000X000{any_suffix} + {any_prefix}_Z000Y000X001{any_suffix} + ... and so on + + Returns + ------- + zarr.core.Array + """ + + # define function to read individual files + def imread(fname): + with open(fname, 'rb') as fh: + return imagecodecs.tiff_decode(fh.read(), index=None) + + # create zarr store, open it as zarr array and return + store = tifffile.imread( + filename_pattern, + aszarr=True, + imread=imread, + pattern=block_index_pattern, + axestiled={x:x for x in range(3)}, + ) + return zarr.open(store=store) + + + + +######################## Cluster related functions ############################ + +#----------------------- config stuff ----------------------------------------# +DEFAULT_CONFIG_FILENAME = 'distributed_cellpose_dask_config.yaml' + +def _config_path(config_name): + """Add config directory path to config filename""" + return str(pathlib.Path.home()) + '/.config/dask/' + config_name + + +def _modify_dask_config( + config, + config_name=DEFAULT_CONFIG_FILENAME, +): + """ + Modifies dask config dictionary, but also dumps modified + config to disk as a yaml file in ~/.config/dask/. This + ensures that workers inherit config options. + """ + dask.config.set(config) + with open(_config_path(config_name), 'w') as f: + yaml.dump(dask.config.config, f, default_flow_style=False) + + +def _remove_config_file( + config_name=DEFAULT_CONFIG_FILENAME, +): + """Removes a config file from disk""" + config_path = _config_path(config_name) + if os.path.exists(config_path): os.remove(config_path) + + +#----------------------- clusters --------------------------------------------# +class myLocalCluster(distributed.LocalCluster): + """ + This is a thin wrapper extending dask.distributed.LocalCluster to set + configs before the cluster or workers are initialized. + + For a full list of arguments (how to specify your worker resources) see: + https://distributed.dask.org/en/latest/api.html#distributed.LocalCluster + You need to know how many cpu cores and how much RAM your machine has. + + Most users will only need to specify: + n_workers + ncpus (number of physical cpu cores per worker) + memory_limit (which is the limit per worker, should be a string like '16GB') + threads_per_worker (for most workflows this should be 1) + + You can also modify any dask configuration option through the + config argument.
+ + If your workstation has a GPU, one of the workers will have exclusive + access to it by default. That worker will be much faster than the others. + You may want to consider creating only one worker (which will have access + to the GPU) and letting that worker process all blocks serially. + """ + + def __init__( + self, + ncpus, + config={}, + config_name=DEFAULT_CONFIG_FILENAME, + persist_config=False, + **kwargs, + ): + # config + self.config_name = config_name + self.persist_config = persist_config + scratch_dir = f"{os.getcwd()}/" + scratch_dir += f".{getpass.getuser()}_distributed_cellpose/" + config_defaults = {'temporary-directory':scratch_dir} + config = {**config_defaults, **config} + _modify_dask_config(config, config_name) + + # construct + if "host" not in kwargs: kwargs["host"] = "" + super().__init__(**kwargs) + self.client = distributed.Client(self) + + # set environment variables for workers (threading) + environment_vars = { + 'MKL_NUM_THREADS':str(2*ncpus), + 'NUM_MKL_THREADS':str(2*ncpus), + 'OPENBLAS_NUM_THREADS':str(2*ncpus), + 'OPENMP_NUM_THREADS':str(2*ncpus), + 'OMP_NUM_THREADS':str(2*ncpus), + } + def set_environment_vars(): + for k, v in environment_vars.items(): + os.environ[k] = v + self.client.run(set_environment_vars) + + print("Cluster dashboard link: ", self.dashboard_link) + + def __enter__(self): return self + def __exit__(self, exc_type, exc_value, traceback): + if not self.persist_config: + _remove_config_file(self.config_name) + self.client.close() + super().__exit__(exc_type, exc_value, traceback) + + +class janeliaLSFCluster(dask_jobqueue.LSFCluster): + """ + This is a thin wrapper extending dask_jobqueue.LSFCluster, + which in turn extends dask.distributed.SpecCluster. This wrapper + sets configs before the cluster or workers are initialized. This is + an adaptive cluster and will scale the number of workers, between user + specified limits, based on the number of pending tasks. This wrapper + also enforces conventions specific to the Janelia LSF cluster. 
+ + For a full list of arguments see + https://jobqueue.dask.org/en/latest/generated/dask_jobqueue.LSFCluster.html + + Most users will only need to specify: + ncpus (the number of cpu cores per worker) + min_workers + max_workers + """ + + def __init__( + self, + ncpus, + min_workers, + max_workers, + config={}, + config_name=DEFAULT_CONFIG_FILENAME, + persist_config=False, + **kwargs + ): + + # store all args in case needed later + self.locals_store = {**locals()} + + # config + self.config_name = config_name + self.persist_config = persist_config + scratch_dir = f"/scratch/{getpass.getuser()}/" + config_defaults = { + 'temporary-directory':scratch_dir, + 'distributed.comm.timeouts.connect':'180s', + 'distributed.comm.timeouts.tcp':'360s', + } + config = {**config_defaults, **config} + _modify_dask_config(config, config_name) + + # threading is best in low level libraries + job_script_prologue = [ + f"export MKL_NUM_THREADS={2*ncpus}", + f"export NUM_MKL_THREADS={2*ncpus}", + f"export OPENBLAS_NUM_THREADS={2*ncpus}", + f"export OPENMP_NUM_THREADS={2*ncpus}", + f"export OMP_NUM_THREADS={2*ncpus}", + ] + + # set scratch and log directories + if "local_directory" not in kwargs: + kwargs["local_directory"] = scratch_dir + if "log_directory" not in kwargs: + log_dir = f"{os.getcwd()}/dask_worker_logs_{os.getpid()}/" + pathlib.Path(log_dir).mkdir(parents=False, exist_ok=True) + kwargs["log_directory"] = log_dir + + # graceful exit for lsf jobs (adds -d flag) + class quietLSFJob(dask_jobqueue.lsf.LSFJob): + cancel_command = "bkill -d" + + # construct + super().__init__( + ncpus=ncpus, + processes=1, + cores=1, + memory=str(15*ncpus)+'GB', + mem=int(15e9*ncpus), + job_script_prologue=job_script_prologue, + job_cls=quietLSFJob, + **kwargs, + ) + self.client = distributed.Client(self) + print("Cluster dashboard link: ", self.dashboard_link) + + # set adaptive cluster bounds + self.adapt_cluster(min_workers, max_workers) + + + def __enter__(self): return self + def __exit__(self, exc_type, exc_value, traceback): + if not self.persist_config: + _remove_config_file(self.config_name) + self.client.close() + super().__exit__(exc_type, exc_value, traceback) + + + def adapt_cluster(self, min_workers, max_workers): + _ = self.adapt( + minimum_jobs=min_workers, + maximum_jobs=max_workers, + interval='10s', + wait_count=6, + ) + + + def change_worker_attributes( + self, + min_workers, + max_workers, + **kwargs, + ): + """WARNING: this function is dangerous if you don't know what + you're doing. Don't call this unless you know exactly what + this does.""" + self.scale(0) + for k, v in kwargs.items(): + self.new_spec['options'][k] = v + self.adapt_cluster(min_workers, max_workers) + + +#----------------------- decorator -------------------------------------------# +def cluster(func): + """ + This decorator ensures a function will run inside a cluster + as a context manager. The decorated function, "func", must + accept "cluster" and "cluster_kwargs" as parameters. If + "cluster" is not None then the user has provided an existing + cluster and we just run func. If "cluster" is None then + "cluster_kwargs" are used to construct a new cluster, and + the function is run inside that cluster context. 
+ """ + @functools.wraps(func) + def create_or_pass_cluster(*args, **kwargs): + # TODO: this only checks if args are explicitly present in function call + # it does not check if they are set correctly in any way + assert 'cluster' in kwargs or 'cluster_kwargs' in kwargs, \ + "Either cluster or cluster_kwargs must be defined" + if not 'cluster' in kwargs: + cluster_constructor = myLocalCluster + F = lambda x: x in kwargs['cluster_kwargs'] + if F('ncpus') and F('min_workers') and F('max_workers'): + cluster_constructor = janeliaLSFCluster + with cluster_constructor(**kwargs['cluster_kwargs']) as cluster: + kwargs['cluster'] = cluster + return func(*args, **kwargs) + return func(*args, **kwargs) + return create_or_pass_cluster + + + + +######################## the function to run on each block #################### + +#----------------------- The main function -----------------------------------# +def process_block( + block_index, + crop, + input_zarr, + model_kwargs, + eval_kwargs, + blocksize, + overlap, + output_zarr, + preprocessing_steps=[], + worker_logs_directory=None, + test_mode=False, +): + """ + Preprocess and segment one block, of many, with eventual merger + of all blocks in mind. The block is processed as follows: + + (1) Read block from disk, preprocess, and segment. + (2) Remove overlaps. + (3) Get bounding boxes for every segment. + (4) Remap segment IDs to globally unique values. + (5) Write segments to disk. + (6) Get segmented block faces. + + A user may want to test this function on one block before running + the distributed function. When test_mode=True, steps (5) and (6) + are omitted and replaced with: + + (5) return remapped segments as a numpy array, boxes, and box_ids + + Parameters + ---------- + block_index : tuple + The (i, j, k, ...) index of the block in the overall block grid + + crop : tuple of slice objects + The bounding box of the data to read from the input_zarr array + + input_zarr : zarr.core.Array + The image data we want to segment + + preprocessing_steps : list of tuples (default: the empty list) + Optionally apply an arbitrary pipeline of preprocessing steps + to the image block before running cellpose. + + Must be in the following format: + [(f, {'arg1':val1, ...}), ...] + That is, each tuple must contain only two elements, a function + and a dictionary. The function must have the following signature: + def F(image, ..., crop=None) + That is, the first argument must be a numpy array, which will later + be populated by the image data. The function must also take a keyword + argument called crop, even if it is not used in the function itself. + All other arguments to the function are passed using the dictionary. + Here is an example: + + def F(image, sigma, crop=None): + return gaussian_filter(image, sigma) + def G(image, radius, crop=None): + return median_filter(image, radius) + preprocessing_steps = [(F, {'sigma':2.0}), (G, {'radius':4})] + + model_kwargs : dict + Arguments passed to cellpose.models.Cellpose + This is how you select and parameterize a model. + + eval_kwargs : dict + Arguments passed to the eval function of the Cellpose model + This is how you parameterize model evaluation. 
+ + blocksize : iterable (list, tuple, np.ndarray) + The number of voxels (the shape) of blocks without overlaps + + overlap : int + The number of voxels added to the blocksize to provide context + at the edges + + output_zarr : zarr.core.Array + A location where segments can be stored temporarily before + merger is complete + + worker_logs_directory : string (default: None) + A directory path where log files for each worker can be created + The directory must exist + + test_mode : bool (default: False) + The primary use case of this function is to be called by + distributed_eval (defined later in this same module). However + you may want to call this function manually to test what + happens to an individual block; this is a good idea before + ramping up to process big data and also useful for debugging. + + When test_mode is False (default) this function stores + the segments and returns objects needed for merging between + blocks. + + When test_mode is True this function does not store the + segments, and instead returns them to the caller as a numpy + array. The boxes and box IDs are also returned. When test_mode + is True, you can supply dummy values for many of the inputs, + such as: + + block_index = (0, 0, 0) + output_zarr=None + + Returns + ------- + If test_mode == False (the default), three things are returned: + faces : a list of numpy arrays - the faces of the block segments + boxes : a list of crops (tuples of slices), bounding boxes of segments + box_ids : 1D numpy array, parallel to boxes, the segment IDs of the + boxes + + If test_mode == True, three things are returned: + segments : np.ndarray containing the segments with globally unique IDs + boxes : a list of crops (tuples of slices), bounding boxes of segments + box_ids : 1D numpy array, parallel to boxes, the segment IDs of the + boxes + """ + print('RUNNING BLOCK: ', block_index, '\tREGION: ', crop, flush=True) + segmentation = read_preprocess_and_segment( + input_zarr, crop, preprocessing_steps, model_kwargs, eval_kwargs, + worker_logs_directory, + ) + segmentation, crop = remove_overlaps( + segmentation, crop, overlap, blocksize, + ) + boxes = bounding_boxes_in_global_coordinates(segmentation, crop) + nblocks = get_nblocks(input_zarr.shape, blocksize) + segmentation, remap = global_segment_ids(segmentation, block_index, nblocks) + if remap[0] == 0: remap = remap[1:] + + if test_mode: return segmentation, boxes, remap + output_zarr[tuple(crop)] = segmentation + faces = block_faces(segmentation) + return faces, boxes, remap + + +#----------------------- component functions ---------------------------------# +def read_preprocess_and_segment( + input_zarr, + crop, + preprocessing_steps, + model_kwargs, + eval_kwargs, + worker_logs_directory, +): + """Read block from zarr array, run all preprocessing steps, run cellpose""" + image = input_zarr[crop] + for pp_step in preprocessing_steps: + pp_step[1]['crop'] = crop + image = pp_step[0](image, **pp_step[1]) + log_file=None + if worker_logs_directory is not None: + log_file = f'dask_worker_{distributed.get_worker().name}.log' + log_file = pathlib.Path(worker_logs_directory).joinpath(log_file) + cellpose.io.logger_setup(stdout_file_replacement=log_file) + model = cellpose.models.CellposeModel(**model_kwargs) + return model.eval(image, **eval_kwargs)[0].astype(np.uint32) + + +def remove_overlaps(array, crop, overlap, blocksize): + """overlaps only there to provide context for boundary voxels + and can be removed after segmentation is complete + reslice array to remove the 
overlaps""" + crop_trimmed = list(crop) + for axis in range(array.ndim): + if crop[axis].start != 0: + slc = [slice(None),]*array.ndim + slc[axis] = slice(overlap, None) + array = array[tuple(slc)] + a, b = crop[axis].start, crop[axis].stop + crop_trimmed[axis] = slice(a + overlap, b) + if array.shape[axis] > blocksize[axis]: + slc = [slice(None),]*array.ndim + slc[axis] = slice(None, blocksize[axis]) + array = array[tuple(slc)] + a = crop_trimmed[axis].start + crop_trimmed[axis] = slice(a, a + blocksize[axis]) + return array, crop_trimmed + + +def bounding_boxes_in_global_coordinates(segmentation, crop): + """bounding boxes (tuples of slices) are super useful later + best to compute them now while things are distributed""" + boxes = scipy.ndimage.find_objects(segmentation) + boxes = [b for b in boxes if b is not None] + translate = lambda a, b: slice(a.start+b.start, a.start+b.stop) + for iii, box in enumerate(boxes): + boxes[iii] = tuple(translate(a, b) for a, b in zip(crop, box)) + return boxes + + +def get_nblocks(shape, blocksize): + """Given a shape and blocksize determine the number of blocks per axis""" + return np.ceil(np.array(shape) / blocksize).astype(int) + + +def global_segment_ids(segmentation, block_index, nblocks): + """pack the block index into the segment IDs so they are + globally unique. Everything gets remapped to [1..N] later. + A uint32 is split into 5 digits on left and 5 digits on right. + This creates limits: 42950 maximum number of blocks and + 99999 maximum number of segments per block""" + unique, unique_inverse = np.unique(segmentation, return_inverse=True) + p = str(np.ravel_multi_index(block_index, nblocks)) + remap = [np.uint32(p+str(x).zfill(5)) for x in unique] + if unique[0] == 0: remap[0] = np.uint32(0) # 0 should just always be 0 + segmentation = np.array(remap)[unique_inverse.reshape(segmentation.shape)] + return segmentation, remap + + +def block_faces(segmentation): + """slice faces along every axis""" + faces = [] + for iii in range(segmentation.ndim): + a = [slice(None),] * segmentation.ndim + a[iii] = slice(0, 1) + faces.append(segmentation[tuple(a)]) + a = [slice(None),] * segmentation.ndim + a[iii] = slice(-1, None) + faces.append(segmentation[tuple(a)]) + return faces + + + + +######################## Distributed Cellpose ################################# + +#----------------------- The main function -----------------------------------# +@cluster +def distributed_eval( + input_zarr, + blocksize, + write_path, + mask=None, + preprocessing_steps=[], + model_kwargs={}, + eval_kwargs={}, + cluster=None, + cluster_kwargs={}, + temporary_directory=None, +): + """ + Evaluate a cellpose model on overlapping blocks of a big image. + Distributed over workstation or cluster resources with Dask. + Optionally run preprocessing steps on the blocks before running cellpose. + Optionally use a mask to ignore background regions in image. + Either cluster or cluster_kwargs parameter must be set to a + non-default value; please read these parameter descriptions below. + If using cluster_kwargs, the workstation and Janelia LSF cluster cases + are distinguished by the arguments present in the dictionary. + + PC/Mac/Linux workstations and the Janelia LSF cluster are supported; + running on a different institute cluster will require implementing your + own dask cluster class. Look at the JaneliaLSFCluster class in this + module as an example, also look at the dask_jobqueue library. 
A PR with + a solid start is the right way to get help running this on your own + institute cluster. + + If running on a workstation, please read the docstring for the + LocalCluster class defined in this module. That will tell you what to + put in the cluster_kwargs dictionary. If using the Janelia cluster, + please read the docstring for the JaneliaLSFCluster class. + + Parameters + ---------- + input_zarr : zarr.core.Array + A zarr.core.Array instance containing the image data you want to + segment. + + blocksize : iterable + The size of blocks in voxels. E.g. [128, 256, 256] + + write_path : string + The location of a zarr file on disk where you'd like to write your results + + mask : numpy.ndarray (default: None) + A foreground mask for the image data; may be at a different resolution + (e.g. lower) than the image data. If given, only blocks that contain + foreground will be processed. This can save considerable time and + expense. It is assumed that the domain of the input_zarr image data + and the mask is the same in physical units, but they may be on + different sampling/voxel grids. + + preprocessing_steps : list of tuples (default: the empty list) + Optionally apply an arbitrary pipeline of preprocessing steps + to the image blocks before running cellpose. + + Must be in the following format: + [(f, {'arg1':val1, ...}), ...] + That is, each tuple must contain only two elements, a function + and a dictionary. The function must have the following signature: + def F(image, ..., crop=None) + That is, the first argument must be a numpy array, which will later + be populated by the image data. The function must also take a keyword + argument called crop, even if it is not used in the function itself. + All other arguments to the function are passed using the dictionary. + Here is an example: + + def F(image, sigma, crop=None): + return gaussian_filter(image, sigma) + def G(image, radius, crop=None): + return median_filter(image, radius) + preprocessing_steps = [(F, {'sigma':2.0}), (G, {'radius':4})] + + model_kwargs : dict (default: {}) + Arguments passed to cellpose.models.Cellpose + + eval_kwargs : dict (default: {}) + Arguments passed to cellpose.models.Cellpose.eval + + cluster : A dask cluster object (default: None) + Only set if you have constructed your own static cluster. The default + behavior is to construct a dask cluster for the duration of this function, + then close it when the function is finished. + + cluster_kwargs : dict (default: {}) + Arguments used to parameterize your cluster. + If you are running locally, see the docstring for the myLocalCluster + class in this module. If you are running on the Janelia LSF cluster, see + the docstring for the janeliaLSFCluster class in this module. If you are + running on a different institute cluster, you may need to implement + a dask cluster object that conforms to the requirements of your cluster. + + temporary_directory : string (default: None) + Temporary files are created during segmentation. The temporary files + will be in their own folder within the temporary_directory. The default + is the current directory. Temporary files are removed if the function + completes successfully. + + Returns + ------- + Two values are returned: + (1) A reference to the zarr array on disk containing the stitched cellpose + segments for your entire image + (2) Bounding boxes for every segment. This is a list of tuples of slices: + [(slice(z1, z2), slice(y1, y2), slice(x1, x2)), ...] + The list is sorted according to segment ID. 
That is the smallest segment + ID is the first tuple in the list, the largest segment ID is the last + tuple in the list. + """ + + timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') + worker_logs_dirname = f'dask_worker_logs_{timestamp}' + worker_logs_dir = pathlib.Path().absolute().joinpath(worker_logs_dirname) + worker_logs_dir.mkdir() + + if 'diameter' not in eval_kwargs.keys(): + eval_kwargs['diameter'] = 30 + overlap = eval_kwargs['diameter'] * 2 + block_indices, block_crops = get_block_crops( + input_zarr.shape, blocksize, overlap, mask, + ) + + # I hate indenting all that code just for the tempdir + # but context manager is the only way to really guarantee that + # the tempdir gets cleaned up even after unhandled exceptions + with tempfile.TemporaryDirectory( + prefix='.', suffix='_distributed_cellpose_tempdir', + dir=temporary_directory or os.getcwd(), + ) as temporary_directory: + + temp_zarr_path = temporary_directory + '/segmentation_unstitched.zarr' + temp_zarr = zarr.open( + temp_zarr_path, 'w', + shape=input_zarr.shape, + chunks=blocksize, + dtype=np.uint32, + ) + + futures = cluster.client.map( + process_block, + block_indices, + block_crops, + input_zarr=input_zarr, + preprocessing_steps=preprocessing_steps, + model_kwargs=model_kwargs, + eval_kwargs=eval_kwargs, + blocksize=blocksize, + overlap=overlap, + output_zarr=temp_zarr, + worker_logs_directory=str(worker_logs_dir), + ) + results = cluster.client.gather(futures) + if isinstance(cluster, dask_jobqueue.core.JobQueueCluster): + cluster.scale(0) + + faces, boxes_, box_ids_ = list(zip(*results)) + boxes = [box for sublist in boxes_ for box in sublist] + box_ids = np.concatenate(box_ids_).astype(int) # unsure how but without cast these are float64 + new_labeling = determine_merge_relabeling(block_indices, faces, box_ids) + debug_unique = np.unique(new_labeling) + new_labeling_path = temporary_directory + '/new_labeling.npy' + np.save(new_labeling_path, new_labeling) + + # stitching step is cheap, we should release gpus and use small workers + if isinstance(cluster, dask_jobqueue.core.JobQueueCluster): + cluster.change_worker_attributes( + min_workers=cluster.locals_store['min_workers'], + max_workers=cluster.locals_store['max_workers'], + ncpus=1, + memory="15GB", + mem=int(15e9), + queue=None, + job_extra_directives=[], + ) + + segmentation_da = dask.array.from_zarr(temp_zarr) + relabeled = dask.array.map_blocks( + lambda block: np.load(new_labeling_path)[block], + segmentation_da, + dtype=np.uint32, + chunks=segmentation_da.chunks, + ) + dask.array.to_zarr(relabeled, write_path, overwrite=True) + merged_boxes = merge_all_boxes(boxes, new_labeling[box_ids]) + return zarr.open(write_path, mode='r'), merged_boxes + + +#----------------------- component functions ---------------------------------# +def get_block_crops(shape, blocksize, overlap, mask): + """Given a voxel grid shape, blocksize, and overlap size, construct + tuples of slices for every block; optionally only include blocks + that contain foreground in the mask. 
Returns parallel lists, + the block indices and the slice tuples.""" + blocksize = np.array(blocksize) + if mask is not None: + ratio = np.array(mask.shape) / shape + mask_blocksize = np.round(ratio * blocksize).astype(int) + + indices, crops = [], [] + nblocks = get_nblocks(shape, blocksize) + for index in np.ndindex(*nblocks): + start = blocksize * index - overlap + stop = start + blocksize + 2 * overlap + start = np.maximum(0, start) + stop = np.minimum(shape, stop) + crop = tuple(slice(x, y) for x, y in zip(start, stop)) + + foreground = True + if mask is not None: + start = mask_blocksize * index + stop = start + mask_blocksize + stop = np.minimum(mask.shape, stop) + mask_crop = tuple(slice(x, y) for x, y in zip(start, stop)) + if not np.any(mask[mask_crop]): foreground = False + if foreground: + indices.append(index) + crops.append(crop) + return indices, crops + + +def determine_merge_relabeling(block_indices, faces, used_labels): + """Determine boundary segment mergers, remap all label IDs to merge + and put all label IDs in range [1..N] for N global segments found""" + faces = adjacent_faces(block_indices, faces) + # FIX float parameters + # print("Used labels:", used_labels, "Type:", type(used_labels)) + used_labels = used_labels.astype(int) + # print("Used labels:", used_labels, "Type:", type(used_labels)) + label_range = int(np.max(used_labels)) + + label_groups = block_face_adjacency_graph(faces, label_range) + new_labeling = scipy.sparse.csgraph.connected_components( + label_groups, directed=False)[1] + # XXX: new_labeling is returned as int32. Loses half range. Potentially a problem. + unused_labels = np.ones(label_range + 1, dtype=bool) + unused_labels[used_labels] = 0 + new_labeling[unused_labels] = 0 + unique, unique_inverse = np.unique(new_labeling, return_inverse=True) + new_labeling = np.arange(len(unique), dtype=np.uint32)[unique_inverse] + return new_labeling + + +def adjacent_faces(block_indices, faces): + """Find faces which touch and pair them together in new data structure""" + face_pairs = [] + faces_index_lookup = {a:b for a, b in zip(block_indices, faces)} + for block_index in block_indices: + for ax in range(len(block_index)): + neighbor_index = np.array(block_index) + neighbor_index[ax] += 1 + neighbor_index = tuple(neighbor_index) + try: + a = faces_index_lookup[block_index][2*ax + 1] + b = faces_index_lookup[neighbor_index][2*ax] + face_pairs.append( np.concatenate((a, b), axis=ax) ) + except KeyError: + continue + return face_pairs + + +def block_face_adjacency_graph(faces, nlabels): + """Shrink labels in face plane, then find which labels touch across the + face boundary""" + # FIX float parameters + # print("Initial nlabels:", nlabels, "Type:", type(nlabels)) + nlabels = int(nlabels) + # print("Final nlabels:", nlabels, "Type:", type(nlabels)) + + all_mappings = [] + structure = scipy.ndimage.generate_binary_structure(3, 1) + for face in faces: + sl0 = tuple(slice(0, 1) if d==2 else slice(None) for d in face.shape) + sl1 = tuple(slice(1, 2) if d==2 else slice(None) for d in face.shape) + a = shrink_labels(face[sl0], 1.0) + b = shrink_labels(face[sl1], 1.0) + face = np.concatenate((a, b), axis=np.argmin(a.shape)) + mapped = dask_image.ndmeasure._utils._label._across_block_label_grouping(face, structure) + all_mappings.append(mapped) + i, j = np.concatenate(all_mappings, axis=1) + v = np.ones_like(i) + return scipy.sparse.coo_matrix((v, (i, j)), shape=(nlabels+1, nlabels+1)).tocsr() + + +def shrink_labels(plane, threshold): + """Shrink labels in plane by 
some distance from their boundary""" + gradmag = np.linalg.norm(np.gradient(plane.squeeze()), axis=0) + shrunk_labels = np.copy(plane.squeeze()) + shrunk_labels[gradmag > 0] = 0 + distances = scipy.ndimage.distance_transform_edt(shrunk_labels) + shrunk_labels[distances <= threshold] = 0 + return shrunk_labels.reshape(plane.shape) + + +def merge_all_boxes(boxes, box_ids): + """Merge all boxes that map to the same box_ids""" + merged_boxes = [] + boxes_array = np.array(boxes, dtype=object) + # FIX float parameters + # print("Box IDs:", box_ids, "Type:", type(box_ids)) + box_ids = box_ids.astype(int) + # print("Box IDs:", box_ids, "Type:", type(box_ids)) + + for iii in np.unique(box_ids): + merge_indices = np.argwhere(box_ids == iii).squeeze() + if merge_indices.shape: + merged_box = merge_boxes(boxes_array[merge_indices]) + else: + merged_box = boxes_array[merge_indices] + merged_boxes.append(merged_box) + return merged_boxes + + +def merge_boxes(boxes): + """Take union of two or more parallelpipeds""" + box_union = boxes[0] + for iii in range(1, len(boxes)): + local_union = [] + for s1, s2 in zip(box_union, boxes[iii]): + start = min(s1.start, s2.start) + stop = max(s1.stop, s2.stop) + local_union.append(slice(start, stop)) + box_union = tuple(local_union) + return box_union + + diff --git a/cellpose/core.py b/cellpose/core.py new file mode 100644 index 0000000000000000000000000000000000000000..d4297ee727588699b5075aee6631ce07bc942f5f --- /dev/null +++ b/cellpose/core.py @@ -0,0 +1,331 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import os, sys, time, shutil, tempfile, datetime, pathlib, subprocess +import logging +import numpy as np +from tqdm import trange, tqdm +from urllib.parse import urlparse +import tempfile +import cv2 +from scipy.stats import mode +import fastremap +from . import transforms, dynamics, utils, plot, metrics, resnet_torch + +import torch +from torch import nn +from torch.utils import mkldnn as mkldnn_utils + +TORCH_ENABLED = True + +core_logger = logging.getLogger(__name__) +tqdm_out = utils.TqdmToLogger(core_logger, level=logging.INFO) + + +def use_gpu(gpu_number=0, use_torch=True): + """ + Check if GPU is available for use. + + Args: + gpu_number (int): The index of the GPU to be used. Default is 0. + use_torch (bool): Whether to use PyTorch for GPU check. Default is True. + + Returns: + bool: True if GPU is available, False otherwise. + + Raises: + ValueError: If use_torch is False, as cellpose only runs with PyTorch now. + """ + if use_torch: + return _use_gpu_torch(gpu_number) + else: + raise ValueError("cellpose only runs with PyTorch now") + + +def _use_gpu_torch(gpu_number=0): + """ + Checks if CUDA or MPS is available and working with PyTorch. + + Args: + gpu_number (int): The GPU device number to use (default is 0). + + Returns: + bool: True if CUDA or MPS is available and working, False otherwise. + """ + try: + device = torch.device("cuda:" + str(gpu_number)) + _ = torch.zeros((1,1)).to(device) + core_logger.info("** TORCH CUDA version installed and working. **") + return True + except: + pass + try: + device = torch.device('mps:' + str(gpu_number)) + _ = torch.zeros((1,1)).to(device) + core_logger.info('** TORCH MPS version installed and working. 
**') + return True + except: + core_logger.info('Neither TORCH CUDA nor MPS version installed/working.') + return False + + +def assign_device(use_torch=True, gpu=False, device=0): + """ + Assigns the device (CPU or GPU or mps) to be used for computation. + + Args: + use_torch (bool, optional): Whether to use torch for GPU detection. Defaults to True. + gpu (bool, optional): Whether to use GPU for computation. Defaults to False. + device (int or str, optional): The device index or name to be used. Defaults to 0. + + Returns: + torch.device, bool (True if GPU is used, False otherwise) + """ + + if isinstance(device, str): + if device != "mps" or not(gpu and torch.backends.mps.is_available()): + device = int(device) + if gpu and use_gpu(use_torch=True): + try: + if torch.cuda.is_available(): + device = torch.device(f'cuda:{device}') + core_logger.info(">>>> using GPU (CUDA)") + gpu = True + cpu = False + except: + gpu = False + cpu = True + try: + if torch.backends.mps.is_available(): + device = torch.device('mps') + core_logger.info(">>>> using GPU (MPS)") + gpu = True + cpu = False + except: + gpu = False + cpu = True + else: + device = torch.device('cpu') + core_logger.info('>>>> using CPU') + gpu = False + cpu = True + + if cpu: + device = torch.device("cpu") + core_logger.info(">>>> using CPU") + gpu = False + return device, gpu + + +def check_mkl(use_torch=True): + """ + Checks if MKL-DNN is enabled and working. + + Args: + use_torch (bool, optional): Whether to use torch. Defaults to True. + + Returns: + bool: True if MKL-DNN is enabled, False otherwise. + """ + mkl_enabled = torch.backends.mkldnn.is_available() + if mkl_enabled: + mkl_enabled = True + else: + core_logger.info( + "WARNING: MKL version on torch not working/installed - CPU version will be slightly slower." + ) + core_logger.info( + "see https://pytorch.org/docs/stable/backends.html?highlight=mkl") + return mkl_enabled + + +def _to_device(x, device): + """ + Converts the input tensor or numpy array to the specified device. + + Args: + x (torch.Tensor or numpy.ndarray): The input tensor or numpy array. + device (torch.device): The target device. + + Returns: + torch.Tensor: The converted tensor on the specified device. + """ + if not isinstance(x, torch.Tensor): + X = torch.from_numpy(x).to(device, dtype=torch.float32) + return X + else: + return x + + +def _from_device(X): + """ + Converts a PyTorch tensor from the device to a NumPy array on the CPU. + + Args: + X (torch.Tensor): The input PyTorch tensor. + + Returns: + numpy.ndarray: The converted NumPy array. + """ + x = X.detach().cpu().numpy() + return x + + +def _forward(net, x): + """Converts images to torch tensors, runs the network model, and returns numpy arrays. + + Args: + net (torch.nn.Module): The network model. + x (numpy.ndarray): The input images. + + Returns: + Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features. + """ + X = _to_device(x, net.device) + net.eval() + if net.mkldnn: + net = mkldnn_utils.to_mkldnn(net) + with torch.no_grad(): + y, style = net(X)[:2] + del X + y = _from_device(y) + style = _from_device(style) + return y, style + + +def run_net(net, imgi, batch_size=8, augment=False, tile_overlap=0.1, bsize=224, + rsz=None): + """ + Run network on stack of images. + + (faster if augment is False) + + Args: + net (class): cellpose network (model.net) + imgi (np.ndarray): The input image or stack of images of size [Lz x Ly x Lx x nchan].
+ batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8. + rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0. + augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False. + tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1. + bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224. + + Returns: + Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3]. + y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability. + style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles. + """ + # run network + nout = net.nout + Lz, Ly0, Lx0, nchan = imgi.shape + if rsz is not None: + if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray): + rsz = [rsz, rsz] + Lyr, Lxr = int(Ly0 * rsz[0]), int(Lx0 * rsz[1]) + else: + Lyr, Lxr = Ly0, Lx0 + ypad1, ypad2, xpad1, xpad2 = transforms.get_pad_yx(Lyr, Lxr) + pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]]) + Ly, Lx = Lyr + ypad1 + ypad2, Lxr + xpad1 + xpad2 + if augment: + ny = max(2, int(np.ceil(2. * Ly / bsize))) + nx = max(2, int(np.ceil(2. * Lx / bsize))) + ly, lx = bsize, bsize + else: + ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize)) + nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize)) + ly, lx = min(bsize, Ly), min(bsize, Lx) + yf = np.zeros((Lz, nout, Ly, Lx), "float32") + styles = np.zeros((Lz, 256), "float32") + + # run multiple slices at the same time + ntiles = ny * nx + nimgs = max(1, batch_size // ntiles) # number of imgs to run in the same batch + niter = int(np.ceil(Lz / nimgs)) + ziterator = (trange(niter, file=tqdm_out, mininterval=30) + if niter > 10 or Lz > 1 else range(niter)) + for k in ziterator: + inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs)) + IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32") + for i, b in enumerate(inds): + # pad image for net so Ly and Lx are divisible by 4 + imgb = transforms.resize_image(imgi[b], rsz=rsz) if rsz is not None else imgi[b].copy() + imgb = np.pad(imgb.transpose(2,0,1), pads, mode="constant") + IMG, ysub, xsub, Ly, Lx = transforms.make_tiles( + imgb, bsize=bsize, augment=augment, + tile_overlap=tile_overlap) + IMGa[i * ntiles : (i+1) * ntiles] = np.reshape(IMG, + (ny * nx, nchan, ly, lx)) + + ya = np.zeros((IMGa.shape[0], nout, ly, lx), "float32") + stylea = np.zeros((IMGa.shape[0], 256), "float32") + for j in range(0, IMGa.shape[0], batch_size): + bslc = slice(j, min(j + batch_size, IMGa.shape[0])) + ya[bslc], stylea[bslc] = _forward(net, IMGa[bslc]) + for i, b in enumerate(inds): + y = ya[i * ntiles : (i + 1) * ntiles] + if augment: + y = np.reshape(y, (ny, nx, 3, ly, lx)) + y = transforms.unaugment_tiles(y) + y = np.reshape(y, (-1, 3, ly, lx)) + yfi = transforms.average_tiles(y, ysub, xsub, Ly, Lx) + yf[b] = yfi[:, :imgb.shape[-2], :imgb.shape[-1]] + stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0) + stylei /= (stylei**2).sum()**0.5 + styles[b] = stylei + # slices from padding + yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2] + yf = yf.transpose(0,2,3,1) + return yf, np.array(styles) + + +def run_3D(net, imgs, batch_size=8, augment=False, + tile_overlap=0.1, bsize=224, net_ortho=None, + progress=None): + """ + Run network on image z-stack. 
+ + (faster if augment is False) + + Args: + imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan]. + batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8. + rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0. + anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None. + augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False. + tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1. + bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224. + net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes. Defaults to None. + progress (QProgressBar, optional): pyqt progress bar. Defaults to None. + + Returns: + Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3]. + y[...,0] is Z flow; y[...,1] is Y flow; y[...,2] is X flow; y[...,3] is cell probability. + style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles. + """ + sstr = ["YX", "ZY", "ZX"] + pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)] + ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)] + cp = [(1, 2), (0, 2), (0, 1)] + cpy = [(0, 1), (0, 1), (0, 1)] + shape = imgs.shape[:-1] + #cellprob = np.zeros(shape, "float32") + yf = np.zeros((*shape, 4), "float32") + for p in range(3): + xsl = imgs.transpose(pm[p]) + # per image + core_logger.info("running %s: %d planes of size (%d, %d)" % + (sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]])) + y, style = run_net(net if p==0 or net_ortho is None else net_ortho, + xsl, batch_size=batch_size, augment=augment, + bsize=bsize, tile_overlap=tile_overlap, + rsz=None) + yf[..., -1] += y[..., -1].transpose(ipm[p]) + for j in range(2): + yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p]) + y = None; del y + + if progress is not None: + progress.setValue(25 + 15 * p) + + return yf, style diff --git a/cellpose/denoise.py b/cellpose/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..0b3ef5d54bbca9be70a9b416151df17d407a0f5b --- /dev/null +++ b/cellpose/denoise.py @@ -0,0 +1,1484 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 
+""" + +import os, time, datetime +import numpy as np +from scipy.stats import mode +import cv2 +import torch +from torch import nn +from torch.nn.functional import conv2d, interpolate +from tqdm import trange +from pathlib import Path + +import logging + +denoise_logger = logging.getLogger(__name__) + +from cellpose import transforms, resnet_torch, utils, io +from cellpose.core import run_net +from cellpose.resnet_torch import CPnet +from cellpose.models import CellposeModel, model_path, normalize_default, assign_device, check_mkl + +MODEL_NAMES = [] +for ctype in ["cyto3", "cyto2", "nuclei"]: + for ntype in ["denoise", "deblur", "upsample", "oneclick"]: + MODEL_NAMES.append(f"{ntype}_{ctype}") + if ctype != "cyto3": + for ltype in ["per", "seg", "rec"]: + MODEL_NAMES.append(f"{ntype}_{ltype}_{ctype}") + if ctype != "cyto3": + MODEL_NAMES.append(f"aniso_{ctype}") + +criterion = nn.MSELoss(reduction="mean") +criterion2 = nn.BCEWithLogitsLoss(reduction="mean") + + +def deterministic(seed=0): + """ set random seeds to create test data """ + import random + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + np.random.seed(seed) # Numpy module. + random.seed(seed) # Python random module. + torch.manual_seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + +def loss_fn_rec(lbl, y): + """ loss function between true labels lbl and prediction y """ + loss = 80. * criterion(y, lbl) + return loss + + +def loss_fn_seg(lbl, y): + """ loss function between true labels lbl and prediction y """ + veci = 5. * lbl[:, 1:] + lbl = (lbl[:, 0] > .5).float() + loss = criterion(y[:, :2], veci) + loss /= 2. + loss2 = criterion2(y[:, 2], lbl) + loss = loss + loss2 + return loss + + +def get_sigma(Tdown): + """ Calculates the correlation matrices across channels for the perceptual loss. + + Args: + Tdown (list): List of tensors output by each downsampling block of network. + + Returns: + list: List of correlations for each input tensor. + """ + Tnorm = [x - x.mean((-2, -1), keepdim=True) for x in Tdown] + Tnorm = [x / x.std((-2, -1), keepdim=True) for x in Tnorm] + Sigma = [ + torch.einsum("bnxy, bmxy -> bnm", x, x) / (x.shape[-2] * x.shape[-1]) + for x in Tnorm + ] + return Sigma + + +def imstats(X, net1): + """ + Calculates the image correlation matrices for the perceptual loss. + + Args: + X (torch.Tensor): Input image tensor. + net1: Cellpose net. + + Returns: + list: A list of tensors of correlation matrices. + """ + _, _, Tdown = net1(X) + Sigma = get_sigma(Tdown) + Sigma = [x.detach() for x in Sigma] + return Sigma + + +def loss_fn_per(img, net1, yl): + """ + Calculates the perceptual loss function for image restoration. + + Args: + img (torch.Tensor): Input image tensor (noisy/blurry/downsampled). + net1 (torch.nn.Module): Perceptual loss net (Cellpose segmentation net). + yl (torch.Tensor): Clean image tensor. + + Returns: + torch.Tensor: Mean perceptual loss. + """ + Sigma = imstats(img, net1) + sd = [x.std((1, 2)) + 1e-6 for x in Sigma] + Sigma_test = get_sigma(yl) + losses = torch.zeros(len(Sigma[0]), device=img.device) + for k in range(len(Sigma)): + losses = losses + (((Sigma_test[k] - Sigma[k])**2).mean((1, 2)) / sd[k]**2) + return losses.mean() + + +def test_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]): + """ + Calculates the test loss for image restoration tasks. + + Args: + net0 (torch.nn.Module): The image restoration network. 
+ X (torch.Tensor): The input image tensor. + net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None. + img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None. + lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None. + lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.]. + + Returns: + tuple: A tuple containing the total loss and the perceptual loss. + """ + net0.eval() + if net1 is not None: + net1.eval() + loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device) + + with torch.no_grad(): + img_dn = net0(X)[0] + if lam[2] > 0.: + loss += lam[2] * loss_fn_rec(img, img_dn) + if lam[1] > 0. or lam[0] > 0.: + y, _, ydown = net1(img_dn) + if lam[1] > 0.: + loss += lam[1] * loss_fn_seg(lbl, y) + if lam[0] > 0.: + loss_per = loss_fn_per(img, net1, ydown) + loss += lam[0] * loss_per + return loss, loss_per + + +def train_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]): + """ + Calculates the train loss for image restoration tasks. + + Args: + net0 (torch.nn.Module): The image restoration network. + X (torch.Tensor): The input image tensor. + net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None. + img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None. + lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None. + lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.]. + + Returns: + tuple: A tuple containing the total loss and the perceptual loss. + """ + net0.train() + if net1 is not None: + net1.eval() + loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device) + + img_dn = net0(X)[0] + if lam[2] > 0.: + loss += lam[2] * loss_fn_rec(img, img_dn) + if lam[1] > 0. or lam[0] > 0.: + y, _, ydown = net1(img_dn) + if lam[1] > 0.: + loss += lam[1] * loss_fn_seg(lbl, y) + if lam[0] > 0.: + loss_per = loss_fn_per(img, net1, ydown) + loss += lam[0] * loss_per + return loss, loss_per + + +def img_norm(imgi): + """ + Normalizes the input image by subtracting the 1st percentile and dividing by the difference between the 99th and 1st percentiles. + + Args: + imgi (torch.Tensor): Input image tensor. + + Returns: + torch.Tensor: Normalized image tensor. + """ + shape = imgi.shape + imgi = imgi.reshape(imgi.shape[0], imgi.shape[1], -1) + perc = torch.quantile(imgi, torch.tensor([0.01, 0.99], device=imgi.device), dim=-1, + keepdim=True) + for k in range(imgi.shape[1]): + hask = (perc[1, :, k, 0] - perc[0, :, k, 0]) > 1e-3 + imgi[hask, k] -= perc[0, hask, k] + imgi[hask, k] /= (perc[1, hask, k] - perc[0, hask, k]) + imgi = imgi.reshape(shape) + return imgi + + +def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7, + ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None, + ds=None, uniform_blur=False, partial_blur=False): + """Adds noise to the input image. + + Args: + lbl (torch.Tensor): The input image tensor of shape (nimg, nchan, Ly, Lx). + alpha (float, optional): The shape parameter of the gamma distribution used for generating poisson noise. Defaults to 4. 
+ beta (float, optional): The rate parameter of the gamma distribution used for generating poisson noise. Defaults to 0.7. + poisson (float, optional): The probability of adding poisson noise to the image. Defaults to 0.7. + blur (float, optional): The probability of adding gaussian blur to the image. Defaults to 0.7. + gblur (float, optional): The scale factor for the gaussian blur. Defaults to 1.0. + downsample (float, optional): The probability of downsampling the image. Defaults to 0.7. + ds_max (int, optional): The maximum downsampling factor. Defaults to 7. + diams (torch.Tensor, optional): The diameter of the objects in the image. Defaults to None. + pscale (torch.Tensor, optional): The scale factor for the poisson noise, instead of sampling. Defaults to None. + iso (bool, optional): Whether to use isotropic gaussian blur. Defaults to True. + sigma0 (torch.Tensor, optional): The standard deviation of the gaussian filter for the Y axis, instead of sampling. Defaults to None. + sigma1 (torch.Tensor, optional): The standard deviation of the gaussian filter for the X axis, instead of sampling. Defaults to None. + ds (torch.Tensor, optional): The downsampling factor for each image, instead of sampling. Defaults to None. + + Returns: + torch.Tensor: The noisy image tensor of the same shape as the input image. + """ + device = lbl.device + imgi = torch.zeros_like(lbl) + Ly, Lx = lbl.shape[-2:] + + diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device) + #ds0 = 1 if ds is None else ds.item() + ds = ds * torch.ones( + (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds + + # downsample + ii = [] + idownsample = np.random.rand(len(lbl)) < downsample + if (ds is None and idownsample.sum() > 0.) or not iso: + ds = torch.ones(len(lbl), dtype=torch.long, device=device) + ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),), + device=device) + ii = torch.nonzero(ds > 1).flatten() + elif ds is not None and (ds > 1).sum(): + ii = torch.nonzero(ds > 1).flatten() + + # add gaussian blur + iblur = torch.rand(len(lbl), device=device) < blur + iblur[ii] = True + if iblur.sum() > 0: + if sigma0 is None: + if uniform_blur and iso: + xr = torch.rand(len(lbl), device=device) + if len(ii) > 0: + xr[ii] = ds[ii].float() / 2. / gblur + sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur]) + sigma1 = sigma0.clone() + elif not iso: + xr = torch.rand(len(lbl), device=device) + if len(ii) > 0: + xr[ii] = (ds[ii].float()) / gblur + xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35 + xr[ii] = torch.clip(xr[ii], 0.05, 1.5) + sigma0 = diams[iblur] / 30. * gblur * xr[iblur] + sigma1 = sigma0.clone() / 10. + else: + xrand = np.random.exponential(1, size=iblur.sum()) + xrand = np.clip(xrand * 0.5, 0.1, 1.0) + xrand *= gblur + sigma0 = diams[iblur] / 30. * 5. 
* torch.from_numpy(xrand).float().to( + device) + sigma1 = sigma0.clone() + else: + sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device) + sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device) + + # create gaussian filter + xr = max(8, sigma0.max().long() * 2) + gfilt0 = torch.exp(-torch.arange(-xr + 1, xr, device=device)**2 / + (2 * sigma0.unsqueeze(-1)**2)) + gfilt0 /= gfilt0.sum(axis=-1, keepdims=True) + gfilt1 = torch.zeros_like(gfilt0) + gfilt1[sigma1 == sigma0] = gfilt0[sigma1 == sigma0] + gfilt1[sigma1 != sigma0] = torch.exp( + -torch.arange(-xr + 1, xr, device=device)**2 / + (2 * sigma1[sigma1 != sigma0].unsqueeze(-1)**2)) + gfilt1[sigma1 == 0] = 0. + gfilt1[sigma1 == 0, xr] = 1. + gfilt1 /= gfilt1.sum(axis=-1, keepdims=True) + gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1) + gfilt /= gfilt.sum(axis=(1, 2), keepdims=True) + + lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1), + padding=gfilt.shape[-1] // 2, + groups=gfilt.shape[0]).transpose(1, 0) + if partial_blur: + #yc, xc = np.random.randint(100, Ly-100), np.random.randint(100, Lx-100) + imgi[iblur] = lbl[iblur].clone() + Lxc = int(Lx * 0.85) + ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32), + torch.arange(0, Lxc, dtype=torch.float32), + indexing="ij") + mask = torch.exp(-(ym**2 + xm**2) / 2*(0.001**2)) + mask -= mask.min() + mask /= mask.max() + lbl_blur_crop = lbl_blur[:, :, :, :Lxc] + imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask + + (1-mask) * imgi[iblur, :, :, :Lxc]) + else: + imgi[iblur] = lbl_blur + + imgi[~iblur] = lbl[~iblur] + + # apply downsample + for k in ii: + i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]] + imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear") + + # add poisson noise + ipoisson = np.random.rand(len(lbl)) < poisson + if ipoisson.sum() > 0: + if pscale is None: + pscale = torch.zeros(len(lbl)) + m = torch.distributions.gamma.Gamma(alpha, beta) + pscale = torch.clamp(m.rsample(sample_shape=(ipoisson.sum(),)), 1.) + #pscale = torch.clamp(20 * (torch.rand(size=(len(lbl),), device=lbl.device)), 1.5) + pscale = pscale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device) + else: + pscale = pscale * torch.ones((ipoisson.sum(), 1, 1, 1), device=device) + imgi[ipoisson] = torch.poisson(pscale * imgi[ipoisson]) + imgi[~ipoisson] = imgi[~ipoisson] + + # renormalize + imgi = img_norm(imgi) + + return imgi + + +def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7, + downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30, + ds_max=7, uniform_blur=False, iso=True, rotate=True, + device=torch.device("cuda"), xy=(224, 224), + nchan_noise=1, keep_raw=True): + """ + Applies random rotation, resizing, and noise to the input data. + + Args: + data (numpy.ndarray): The input data. + labels (numpy.ndarray, optional): The flow and cellprob labels associated with the data. Defaults to None. + diams (float, optional): The diameter of the objects. Defaults to None. + poisson (float, optional): The Poisson noise probability. Defaults to 0.7. + blur (float, optional): The blur probability. Defaults to 0.7. + downsample (float, optional): The downsample probability. Defaults to 0.0. + beta (float, optional): The beta value for the poisson noise distribution. Defaults to 0.7. + gblur (float, optional): The Gaussian blur level. Defaults to 1.0. + diam_mean (float, optional): The mean diameter. Defaults to 30. + ds_max (int, optional): The maximum downsample value. Defaults to 7. 
+ iso (bool, optional): Whether to apply isotropic augmentation. Defaults to True. + rotate (bool, optional): Whether to apply rotation augmentation. Defaults to True. + device (torch.device, optional): The device to use. Defaults to torch.device("cuda"). + xy (tuple, optional): The size of the output image. Defaults to (224, 224). + nchan_noise (int, optional): The number of channels to add noise to. Defaults to 1. + keep_raw (bool, optional): Whether to keep the raw image. Defaults to True. + + Returns: + torch.Tensor: The augmented image and augmented noisy/blurry/downsampled version of image. + torch.Tensor: The augmented labels. + float: The scale factor applied to the image. + """ + if device == None: + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None + + diams = 30 if diams is None else diams + random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1)) + random_rsc = diams / random_diam #/ random_diam + #rsc /= random_scale + xy0 = (340, 340) + nchan = data[0].shape[0] + data_new = np.zeros((len(data), (1 + keep_raw) * nchan, xy0[0], xy0[1]), "float32") + labels_new = np.zeros((len(data), 3, xy0[0], xy0[1]), "float32") + for i in range( + len(data)): #, (sc, img, lbl) in enumerate(zip(random_rsc, data, labels)): + sc = random_rsc[i] + img = data[i] + lbl = labels[i] if labels is not None else None + # create affine transform to resize + Ly, Lx = img.shape[-2:] + dxy = np.maximum(0, np.array([Lx / sc - xy0[1], Ly / sc - xy0[0]])) + dxy = (np.random.rand(2,) - .5) * dxy + cc = np.array([Lx / 2, Ly / 2]) + cc1 = cc - np.array([Lx - xy0[1], Ly - xy0[0]]) / 2 + dxy + pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])]) + pts2 = np.float32( + [cc1, cc1 + np.array([1, 0]) / sc, cc1 + np.array([0, 1]) / sc]) + M = cv2.getAffineTransform(pts1, pts2) + + # apply to image + for c in range(nchan): + img_rsz = cv2.warpAffine(img[c], M, xy0, flags=cv2.INTER_LINEAR) + #img_noise = add_noise(torch.from_numpy(img_rsz).to(device).unsqueeze(0)).cpu().numpy().squeeze(0) + data_new[i, c] = img_rsz + if keep_raw: + data_new[i, c + nchan] = img_rsz + + if lbl is not None: + # apply to labels + labels_new[i, 0] = cv2.warpAffine(lbl[0], M, xy0, flags=cv2.INTER_NEAREST) + labels_new[i, 1] = cv2.warpAffine(lbl[1], M, xy0, flags=cv2.INTER_LINEAR) + labels_new[i, 2] = cv2.warpAffine(lbl[2], M, xy0, flags=cv2.INTER_LINEAR) + + rsc = random_diam / diam_mean + + # add noise before augmentations + img = torch.from_numpy(data_new).to(device) + img = torch.clamp(img, 0.) + # just add noise to cyto if nchan_noise=1 + img[:, :nchan_noise] = add_noise( + img[:, :nchan_noise], poisson=poisson, blur=blur, ds_max=ds_max, iso=iso, + downsample=downsample, beta=beta, gblur=gblur, + diams=torch.from_numpy(random_diam).to(device).float()) + # img -= img.mean(dim=(-2,-1), keepdim=True) + # img /= img.std(dim=(-2,-1), keepdim=True) + 1e-3 + img = img.cpu().numpy() + + # augmentations + img, lbl, scale = transforms.random_rotate_and_resize( + img, + Y=labels_new, + xy=xy, + rotate=False if not iso else rotate, + #(iso and downsample==0), + rescale=rsc, + scale_range=0.5) + img = torch.from_numpy(img).to(device) + lbl = torch.from_numpy(lbl).to(device) + + return img, lbl, scale + + +def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None): + """ + Creates a Cellpose network with a single input channel. + + Args: + device (str): The device to run the network on. 
+ model_type (str, optional): The type of Cellpose model to use. Defaults to "cyto2". + pretrained_model (str, optional): The path to a pretrained model file. Defaults to None. + + Returns: + torch.nn.Module: The Cellpose network with a single input channel. + """ + if pretrained_model is not None and not os.path.exists(pretrained_model): + model_type = pretrained_model + pretrained_model = None + nbase = [32, 64, 128, 256] + nchan = 1 + net1 = resnet_torch.CPnet([nchan, *nbase], nout=3, sz=3).to(device) + filename = model_path(model_type, + 0) if pretrained_model is None else pretrained_model + weights = torch.load(filename, weights_only=True) + zp = 0 + print(filename) + for name in net1.state_dict(): + if ("res_down_0.conv.conv_0" not in name and + #"output" not in name and + "res_down_0.proj" not in name and name != "diam_mean" and + name != "diam_labels"): + net1.state_dict()[name].copy_(weights[name]) + elif "res_down_0" in name: + if len(weights[name].shape) > 0: + new_weight = torch.zeros_like(net1.state_dict()[name]) + if weights[name].shape[0] == 2: + new_weight[:] = weights[name][0] + elif len(weights[name].shape) > 1 and weights[name].shape[1] == 2: + new_weight[:, zp] = weights[name][:, 0] + else: + new_weight = weights[name] + else: + new_weight = weights[name] + net1.state_dict()[name].copy_(new_weight) + return net1 + + +class CellposeDenoiseModel(): + """ model to run Cellpose and Image restoration """ + + def __init__(self, gpu=False, pretrained_model=False, model_type=None, + restore_type="denoise_cyto3", nchan=2, + chan2_restore=False, device=None): + + self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore, + device=device) + self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan, + pretrained_model=pretrained_model, device=device) + + def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, + normalize=True, rescale=None, diameter=None, tile_overlap=0.1, + augment=False, resample=True, invert=False, flow_threshold=0.4, + cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0, + min_size=15, niter=None, interp=True, bsize=224, flow3D_smooth=0): + """ + Restore array or list of images using the image restoration model, and then segment. + + Args: + x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images + batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU + (can make smaller or bigger depending on GPU memory usage). Defaults to 8. + channels (list, optional): list of channels, either of length 2 or of length number of images by 2. + First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). + Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). + For instance, to segment grayscale images, input [0,0]. To segment images with cells + in green and nuclei in blue, input [2,3]. To segment one grayscale image and one + image with cells in green and nuclei in blue, input [[0,0], [2,3]]. + Defaults to None. + channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x. + if None, channels dimension is attempted to be automatically determined. Defaults to None. + z_axis (int, optional): z axis in element of list x, or of np.ndarray x. + if None, z dimension is attempted to be automatically determined. Defaults to None. 
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel; + can also pass dictionary of parameters (all keys are optional, default values shown): + - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored) + - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels + - "normalize"=True ; run normalization (if False, all following parameters ignored) + - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high] + - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100) + - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode. + Defaults to True. + rescale (float, optional): resize factor for each image, if None, set to 1.0; + (only used if diameter is None). Defaults to None. + diameter (float, optional): diameter for each image, + if diameter is None, set to diam_mean or diam_train if available. Defaults to None. + tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1. + augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False. + resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True. + invert (bool, optional): invert image pixel intensity before running network. Defaults to False. + flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4. + cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0. + do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False. + anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None. + stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0. + min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15. + flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0. + niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None. + interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True. + + Returns: + A tuple containing (masks, flows, styles, imgs); masks: labelled image(s), where 0=no masks; 1,2,...=mask labels; + flows: list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration; + styles: style vector summarizing each image of size 256; + imgs: Restored images. 
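+
+        Example:
+            A minimal, illustrative call (the ``imgs`` array and the parameter values
+            are assumptions, not taken from this repository):
+
+            >>> from cellpose.denoise import CellposeDenoiseModel
+            >>> model = CellposeDenoiseModel(gpu=True, model_type="cyto3",
+            ...                              restore_type="denoise_cyto3")
+            >>> masks, flows, styles, imgs_dn = model.eval(imgs, channels=[0, 0],
+            ...                                            diameter=30.)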
+ """ + + if isinstance(normalize, dict): + normalize_params = {**normalize_default, **normalize} + elif not isinstance(normalize, bool): + raise ValueError("normalize parameter must be a bool or a dict") + else: + normalize_params = normalize_default + normalize_params["normalize"] = normalize + normalize_params["invert"] = invert + + img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels, + channel_axis=channel_axis, z_axis=z_axis, + do_3D=do_3D, + normalize=normalize_params, rescale=rescale, + diameter=diameter, + tile_overlap=tile_overlap, bsize=bsize) + + # turn off special normalization for segmentation + normalize_params = normalize_default + + # change channels for segmentation + if channels is not None: + channels_new = [0, 0] if channels[0] == 0 else [1, 2] + else: + channels_new = None + # change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean) + diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter + masks, flows, styles = self.cp.eval( + img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1, + z_axis=0 if not isinstance(img_restore, list) and img_restore.ndim > 3 and img_restore.shape[0] > 0 else None, + normalize=normalize_params, rescale=rescale, diameter=diameter, + tile_overlap=tile_overlap, augment=augment, resample=resample, + invert=invert, flow_threshold=flow_threshold, + cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy, + stitch_threshold=stitch_threshold, min_size=min_size, niter=niter, + interp=interp, bsize=bsize) + + return masks, flows, styles, img_restore + + +class DenoiseModel(): + """ + DenoiseModel class for denoising images using Cellpose denoising model. + + Args: + gpu (bool, optional): Whether to use GPU for computation. Defaults to False. + pretrained_model (bool or str or Path, optional): Pretrained model to use for denoising. + Can be a string or path. Defaults to False. + nchan (int, optional): Number of channels in the input images, all Cellpose 3 models were trained with nchan=1. Defaults to 1. + model_type (str, optional): Type of pretrained model to use ("denoise_cyto3", "deblur_cyto3", "upsample_cyto3", ...). Defaults to None. + chan2 (bool, optional): Whether to use a separate model for the second channel. Defaults to False. + diam_mean (float, optional): Mean diameter of the objects in the images. Defaults to 30.0. + device (torch.device, optional): Device to use for computation. Defaults to None. + + Attributes: + nchan (int): Number of channels in the input images. + diam_mean (float): Mean diameter of the objects in the images. + net (CPnet): Cellpose network for denoising. + pretrained_model (bool or str or Path): Pretrained model path to use for denoising. + net_chan2 (CPnet or None): Cellpose network for the second channel, if applicable. + net_type (str): Type of the denoising network. + + Methods: + eval(x, batch_size=8, channels=None, channel_axis=None, z_axis=None, + normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1) + Denoise array or list of images using the denoising model. + + _eval(net, x, normalize=True, rescale=None, diameter=None, tile=True, + tile_overlap=0.1) + Run denoising model on a single channel. 
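+
+    Example:
+        Illustrative restore-only usage (``imgs`` is an assumed list of 2D arrays;
+        a sketch, not a prescribed workflow):
+
+        >>> from cellpose.denoise import DenoiseModel
+        >>> dn = DenoiseModel(gpu=True, model_type="denoise_cyto3")
+        >>> imgs_dn = dn.eval(imgs, channels=[0, 0], diameter=30.)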
+ """ + + def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None, + chan2=False, diam_mean=30., device=None): + self.nchan = nchan + if pretrained_model and (not isinstance(pretrained_model, str) and + not isinstance(pretrained_model, Path)): + raise ValueError("pretrained_model must be a string or path") + + self.diam_mean = diam_mean + builtin = True + if model_type is not None or (pretrained_model and + not os.path.exists(pretrained_model)): + pretrained_model_string = model_type if model_type is not None else "denoise_cyto3" + if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]): + pretrained_model_string = "denoise_cyto3" + pretrained_model = model_path(pretrained_model_string) + if (pretrained_model and not os.path.exists(pretrained_model)): + denoise_logger.warning("pretrained model has incorrect path") + denoise_logger.info(f">> {pretrained_model_string} << model set to be used") + self.diam_mean = 17. if "nuclei" in pretrained_model_string else 30. + else: + if pretrained_model: + builtin = False + pretrained_model_string = pretrained_model + denoise_logger.info(f">>>> loading model {pretrained_model_string}") + + # assign network device + self.mkldnn = None + if device is None: + sdevice, gpu = assign_device(use_torch=True, gpu=gpu) + self.device = device if device is not None else sdevice + if device is not None: + device_gpu = self.device.type == "cuda" + self.gpu = gpu if device is None else device_gpu + if not self.gpu: + self.mkldnn = check_mkl(True) + + # create network + self.nchan = nchan + self.nclasses = 1 + nbase = [32, 64, 128, 256] + self.nchan = nchan + self.nbase = [nchan, *nbase] + + self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn, + max_pool=True, diam_mean=diam_mean).to(self.device) + + self.pretrained_model = pretrained_model + self.net_chan2 = None + if self.pretrained_model: + self.net.load_model(self.pretrained_model, device=self.device) + denoise_logger.info( + f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)" + ) + if chan2 and builtin: + chan2_path = model_path( + os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei") + print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}") + self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3, + mkldnn=self.mkldnn, max_pool=True, + diam_mean=17.).to(self.device) + self.net_chan2.load_model(chan2_path, device=self.device) + self.net_type = "cellpose_denoise" + + def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, + normalize=True, rescale=None, diameter=None, tile=True, do_3D=False, + tile_overlap=0.1, bsize=224): + """ + Restore array or list of images using the image restoration model. + + Args: + x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images + batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU + (can make smaller or bigger depending on GPU memory usage). Defaults to 8. + channels (list, optional): list of channels, either of length 2 or of length number of images by 2. + First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). + Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). + For instance, to segment grayscale images, input [0,0]. To segment images with cells + in green and nuclei in blue, input [2,3]. To segment one grayscale image and one + image with cells in green and nuclei in blue, input [[0,0], [2,3]]. 
+ Defaults to None. + channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x. + if None, channels dimension is attempted to be automatically determined. Defaults to None. + z_axis (int, optional): z axis in element of list x, or of np.ndarray x. + if None, z dimension is attempted to be automatically determined. Defaults to None. + normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel; + can also pass dictionary of parameters (all keys are optional, default values shown): + - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored) + - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels + - "normalize"=True ; run normalization (if False, all following parameters ignored) + - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high] + - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100) + - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode. + Defaults to True. + rescale (float, optional): resize factor for each image, if None, set to 1.0; + (only used if diameter is None). Defaults to None. + diameter (float, optional): diameter for each image, + if diameter is None, set to diam_mean or diam_train if available. Defaults to None. + tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1. + + Returns: + list: A list of 2D/3D arrays of restored images + + """ + if isinstance(x, list) or x.squeeze().ndim == 5: + tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO) + nimg = len(x) + iterator = trange(nimg, file=tqdm_out, + mininterval=30) if nimg > 1 else range(nimg) + imgs = [] + for i in iterator: + imgi = self.eval( + x[i], batch_size=batch_size, + channels=channels[i] if channels is not None and + ((len(channels) == len(x) and + (isinstance(channels[i], list) or + isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2)) + else channels, channel_axis=channel_axis, z_axis=z_axis, + normalize=normalize, + do_3D=do_3D, + rescale=rescale[i] if isinstance(rescale, list) or + isinstance(rescale, np.ndarray) else rescale, + diameter=diameter[i] if isinstance(diameter, list) or + isinstance(diameter, np.ndarray) else diameter, + tile_overlap=tile_overlap, bsize=bsize) + imgs.append(imgi) + if isinstance(x, np.ndarray): + imgs = np.array(imgs) + return imgs + + else: + # reshape image + x = transforms.convert_image(x, channels, channel_axis=channel_axis, + z_axis=z_axis, do_3D=do_3D, nchan=None) + if x.ndim < 4: + squeeze = True + x = x[np.newaxis, ...] + else: + squeeze = False + + # may need to interpolate image before running upsampling + self.ratio = 1. 
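+            # For "upsample_*" models the image is resized below so that the requested
+            # object diameter matches the model's diam_mean; self.ratio keeps that
+            # resize factor so callers (e.g. CellposeDenoiseModel.eval) can reset the
+            # segmentation diameter on the upsampled output.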
+ if "upsample" in self.pretrained_model: + Ly, Lx = x.shape[-3:-1] + if diameter is not None and 3 <= diameter < self.diam_mean: + self.ratio = self.diam_mean / diameter + denoise_logger.info( + f"upsampling image to {self.diam_mean} pixel diameter ({self.ratio:0.2f} times)" + ) + Lyr, Lxr = int(Ly * self.ratio), int(Lx * self.ratio) + x = transforms.resize_image(x, Ly=Lyr, Lx=Lxr) + else: + denoise_logger.warning( + f"not interpolating image before upsampling because diameter is set >= {self.diam_mean}" + ) + #raise ValueError(f"diameter is set to {diameter}, needs to be >=3 and < {self.dn.diam_mean}") + + self.batch_size = batch_size + + if diameter is not None and diameter > 0: + rescale = self.diam_mean / diameter + elif rescale is None: + rescale = 1.0 + + if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0): + x = x[..., :1] + + for c in range(x.shape[-1]): + rescale0 = rescale * 30. / 17. if c == 1 else rescale + if c == 0 or self.net_chan2 is None: + x[..., + c] = self._eval(self.net, x[..., c:c + 1], batch_size=batch_size, + normalize=normalize, rescale=rescale0, + tile_overlap=tile_overlap, bsize=bsize)[...,0] + else: + x[..., + c] = self._eval(self.net_chan2, x[..., + c:c + 1], batch_size=batch_size, + normalize=normalize, rescale=rescale0, + tile_overlap=tile_overlap, bsize=bsize)[...,0] + x = x[0] if squeeze else x + return x + + def _eval(self, net, x, batch_size=8, normalize=True, rescale=None, + tile_overlap=0.1, bsize=224): + """ + Run image restoration model on a single channel. + + Args: + x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images + batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU + (can make smaller or bigger depending on GPU memory usage). Defaults to 8. + normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel; + can also pass dictionary of parameters (all keys are optional, default values shown): + - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored) + - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels + - "normalize"=True ; run normalization (if False, all following parameters ignored) + - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high] + - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100) + - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode. + Defaults to True. + rescale (float, optional): resize factor for each image, if None, set to 1.0; + (only used if diameter is None). Defaults to None. + tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1. 
+ + Returns: + list: A list of 2D/3D arrays of restored images + + """ + if isinstance(normalize, dict): + normalize_params = {**normalize_default, **normalize} + elif not isinstance(normalize, bool): + raise ValueError("normalize parameter must be a bool or a dict") + else: + normalize_params = normalize_default + normalize_params["normalize"] = normalize + + tic = time.time() + shape = x.shape + nimg = shape[0] + + do_normalization = True if normalize_params["normalize"] else False + + img = np.asarray(x) + if do_normalization: + img = transforms.normalize_img(img, **normalize_params) + if rescale != 1.0: + img = transforms.resize_image(img, rsz=rescale) + yf, style = run_net(self.net, img, bsize=bsize, + tile_overlap=tile_overlap) + yf = transforms.resize_image(yf, shape[1], shape[2]) + imgs = yf + del yf, style + + # imgs = np.zeros((*x.shape[:-1], 1), np.float32) + # for i in iterator: + # img = np.asarray(x[i]) + # if do_normalization: + # img = transforms.normalize_img(img, **normalize_params) + # if rescale != 1.0: + # img = transforms.resize_image(img, rsz=[rescale, rescale]) + # if img.ndim == 2: + # img = img[:, :, np.newaxis] + # yf, style = run_net(net, img, batch_size=batch_size, augment=False, + # tile=tile, tile_overlap=tile_overlap, bsize=bsize) + # img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2]) + + # if img.ndim == 2: + # img = img[:, :, np.newaxis] + # imgs[i] = img + # del yf, style + net_time = time.time() - tic + if nimg > 1: + denoise_logger.info("imgs denoised in %2.2fs" % (net_time)) + + return imgs + + +def train(net, train_data=None, train_labels=None, train_files=None, test_data=None, + test_labels=None, test_files=None, train_probs=None, test_probs=None, + lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None, + save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0, + iso=True, uniform_blur=False, downsample=0., ds_max=7, + learning_rate=0.005, n_epochs=500, + weight_decay=0.00001, batch_size=8, nimg_per_epoch=None, + nimg_test_per_epoch=None, model_name=None): + + # net properties + device = net.device + nchan = net.nchan + diam_mean = net.diam_mean.item() + + args = np.array([poisson, beta, blur, gblur, downsample]) + if args.ndim == 1: + args = args[:, np.newaxis] + poisson, beta, blur, gblur, downsample = args + nnoise = len(poisson) + + d = datetime.datetime.now() + if save_path is not None: + if model_name is None: + filename = "" + lstrs = ["per", "seg", "rec"] + for k, (l, s) in enumerate(zip(lam, lstrs)): + filename += f"{s}_{l:.2f}_" + if not iso: + filename += "aniso_" + if poisson.sum() > 0: + filename += "poisson_" + if blur.sum() > 0: + filename += "blur_" + if downsample.sum() > 0: + filename += "downsample_" + filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f") + filename = os.path.join(save_path, filename) + else: + filename = os.path.join(save_path, model_name) + print(filename) + for i in range(len(poisson)): + denoise_logger.info( + f"poisson: {poisson[i]: 0.2f}, beta: {beta[i]: 0.2f}, blur: {blur[i]: 0.2f}, gblur: {gblur[i]: 0.2f}, downsample: {downsample[i]: 0.2f}" + ) + net1 = one_chan_cellpose(device=device, pretrained_model=seg_model_type) + + learning_rate_const = learning_rate + LR = np.linspace(0, learning_rate_const, 10) + LR = np.append(LR, learning_rate_const * np.ones(n_epochs - 100)) + for i in range(10): + LR = np.append(LR, LR[-1] / 2 * np.ones(10)) + learning_rate = LR + + batch_size = 8 + optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate[0], + 
weight_decay=weight_decay) + if train_data is not None: + nimg = len(train_data) + diam_train = np.array( + [utils.diameters(train_labels[k])[0] for k in trange(len(train_labels))]) + diam_train[diam_train < 5] = 5. + if test_data is not None: + diam_test = np.array( + [utils.diameters(test_labels[k])[0] for k in trange(len(test_labels))]) + diam_test[diam_test < 5] = 5. + nimg_test = len(test_data) + else: + nimg = len(train_files) + denoise_logger.info(">>> using files instead of loading dataset") + train_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in train_files] + denoise_logger.info(">>> computing diameters") + diam_train = np.array([ + utils.diameters(io.imread(train_labels_files[k])[0])[0] + for k in trange(len(train_labels_files)) + ]) + diam_train[diam_train < 5] = 5. + if test_files is not None: + nimg_test = len(test_files) + test_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in test_files] + diam_test = np.array([ + utils.diameters(io.imread(test_labels_files[k])[0])[0] + for k in trange(len(test_labels_files)) + ]) + diam_test[diam_test < 5] = 5. + train_probs = 1. / nimg * np.ones(nimg, + "float64") if train_probs is None else train_probs + if test_files is not None or test_data is not None: + test_probs = 1. / nimg_test * np.ones( + nimg_test, "float64") if test_probs is None else test_probs + + tic = time.time() + + nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch + if test_files is not None or test_data is not None: + nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch + + nbatch = 0 + train_losses, test_losses = [], [] + for iepoch in range(n_epochs): + np.random.seed(iepoch) + rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,), + p=train_probs) + torch.manual_seed(iepoch) + np.random.seed(iepoch) + for param_group in optimizer.param_groups: + param_group["lr"] = learning_rate[iepoch] + lavg, lavg_per, nsum = 0, 0, 0 + for ibatch in range(0, nimg_per_epoch, batch_size * nnoise): + inds = rperm[ibatch : ibatch + batch_size * nnoise] + if train_data is None: + imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds] + lbls = [io.imread(train_labels_files[i])[1:] for i in inds] + else: + imgs = [train_data[i][:nchan] for i in inds] + lbls = [train_labels[i][1:] for i in inds] + #inoise = nbatch % nnoise + rnoise = np.random.permutation(nnoise) + for i, inoise in enumerate(rnoise): + if i * batch_size < len(imgs): + imgi, lbli, scale = random_rotate_and_resize_noise( + imgs[i * batch_size : (i + 1) * batch_size], + lbls[i * batch_size : (i + 1) * batch_size], + diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(), + poisson=poisson[inoise], + beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso, + downsample=downsample[inoise], uniform_blur=uniform_blur, + diam_mean=diam_mean, ds_max=ds_max, + device=device) + if i == 0: + img = imgi + lbl = lbli + else: + img = torch.cat((img, imgi), axis=0) + lbl = torch.cat((lbl, lbli), axis=0) + + if nnoise > 0: + iperm = np.random.permutation(img.shape[0]) + img, lbl = img[iperm], lbl[iperm] + + for i in range(nnoise): + optimizer.zero_grad() + imgi = img[i * batch_size: (i + 1) * batch_size] + lbli = lbl[i * batch_size: (i + 1) * batch_size] + if imgi.shape[0] > 0: + loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1, + img=imgi[:, nchan:], lbl=lbli, lam=lam) + loss.backward() + optimizer.step() + lavg += loss.item() * imgi.shape[0] + lavg_per += loss_per.item() * imgi.shape[0] + + nsum += len(img) + 
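+            # lavg / lavg_per / nsum accumulate this epoch's loss sums and sample
+            # counts; they are averaged and logged every few epochs below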
nbatch += 1 + + if iepoch % 5 == 0 or iepoch < 10: + lavg = lavg / nsum + lavg_per = lavg_per / nsum + if test_data is not None or test_files is not None: + lavgt, nsum = 0., 0 + np.random.seed(42) + rperm = np.random.choice(np.arange(0, nimg_test), + size=(nimg_test_per_epoch,), p=test_probs) + inoise = iepoch % nnoise + torch.manual_seed(inoise) + for ibatch in range(0, nimg_test_per_epoch, batch_size): + inds = rperm[ibatch:ibatch + batch_size] + if test_data is None: + imgs = [ + np.maximum(0, + io.imread(test_files[i])[:nchan]) for i in inds + ] + lbls = [io.imread(test_labels_files[i])[1:] for i in inds] + else: + imgs = [test_data[i][:nchan] for i in inds] + lbls = [test_labels[i][1:] for i in inds] + img, lbl, scale = random_rotate_and_resize_noise( + imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise], + beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise], + iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur, + diam_mean=diam_mean, ds_max=ds_max, device=device) + loss, loss_per = test_loss(net, img[:, :nchan], net1=net1, + img=img[:, nchan:], lbl=lbl, lam=lam) + + lavgt += loss.item() * img.shape[0] + nsum += len(img) + lavgt = lavgt / nsum + denoise_logger.info( + "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f" + % (iepoch, time.time() - tic, lavg, lavg_per, lavgt, + learning_rate[iepoch])) + test_losses.append(lavgt) + else: + denoise_logger.info( + "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" % + (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch])) + train_losses.append(lavg) + + if save_path is not None: + if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0): + if save_each: #separate files as model progresses + filename0 = str(filename) + f"_epoch_{iepoch:%04d}" + else: + filename0 = filename + denoise_logger.info(f"saving network parameters to {filename0}") + net.save_model(filename0) + else: + filename = save_path + + return filename, train_losses, test_losses + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="cellpose parameters") + + input_img_args = parser.add_argument_group("input image arguments") + input_img_args.add_argument("--dir", default=[], type=str, + help="folder containing data to run or train on.") + input_img_args.add_argument("--img_filter", default=[], type=str, + help="end string for images to run on") + + model_args = parser.add_argument_group("model arguments") + model_args.add_argument("--pretrained_model", default=[], type=str, + help="pretrained denoising model") + + training_args = parser.add_argument_group("training arguments") + training_args.add_argument("--test_dir", default=[], type=str, + help="folder containing test data (optional)") + training_args.add_argument("--file_list", default=[], type=str, + help="npy file containing list of train and test files") + training_args.add_argument("--seg_model_type", default="cyto2", type=str, + help="model to use for seg training loss") + training_args.add_argument( + "--noise_type", default=[], type=str, + help="noise type to use (if input, then other noise params are ignored)") + training_args.add_argument("--poisson", default=0.8, type=float, + help="fraction of images to add poisson noise to") + training_args.add_argument("--beta", default=0.7, type=float, + help="scale of poisson noise") + training_args.add_argument("--blur", default=0., type=float, + help="fraction of images to blur") + training_args.add_argument("--gblur", default=1.0, 
type=float, + help="scale of gaussian blurring stddev") + training_args.add_argument("--downsample", default=0., type=float, + help="fraction of images to downsample") + training_args.add_argument("--ds_max", default=7, type=int, + help="max downsampling factor") + training_args.add_argument("--lam_per", default=1.0, type=float, + help="weighting of perceptual loss") + training_args.add_argument("--lam_seg", default=1.5, type=float, + help="weighting of segmentation loss") + training_args.add_argument("--lam_rec", default=0., type=float, + help="weighting of reconstruction loss") + training_args.add_argument( + "--diam_mean", default=30., type=float, help= + "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0" + ) + training_args.add_argument("--learning_rate", default=0.001, type=float, + help="learning rate. Default: %(default)s") + training_args.add_argument("--n_epochs", default=2000, type=int, + help="number of epochs. Default: %(default)s") + training_args.add_argument( + "--save_each", default=False, action="store_true", + help="save each epoch as separate model") + training_args.add_argument( + "--nimg_per_epoch", default=0, type=int, + help="number of images per epoch. Default is length of training images") + training_args.add_argument( + "--nimg_test_per_epoch", default=0, type=int, + help="number of test images per epoch. Default is length of testing images") + + io.logger_setup() + + args = parser.parse_args() + lams = [args.lam_per, args.lam_seg, args.lam_rec] + print("lam", lams) + + if len(args.noise_type) > 0: + noise_type = args.noise_type + uniform_blur = False + iso = True + if noise_type == "poisson": + poisson = 0.8 + blur = 0. + downsample = 0. + beta = 0.7 + gblur = 1.0 + elif noise_type == "blur_expr": + poisson = 0.8 + blur = 0.8 + downsample = 0. + beta = 0.1 + gblur = 0.5 + elif noise_type == "blur": + poisson = 0.8 + blur = 0.8 + downsample = 0. 
+ beta = 0.1 + gblur = 10.0 + uniform_blur = True + elif noise_type == "downsample_expr": + poisson = 0.8 + blur = 0.8 + downsample = 0.8 + beta = 0.03 + gblur = 1.0 + elif noise_type == "downsample": + poisson = 0.8 + blur = 0.8 + downsample = 0.8 + beta = 0.03 + gblur = 5.0 + uniform_blur = True + elif noise_type == "all": + poisson = [0.8, 0.8, 0.8] + blur = [0., 0.8, 0.8] + downsample = [0., 0., 0.8] + beta = [0.7, 0.1, 0.03] + gblur = [0., 10.0, 5.0] + uniform_blur = True + elif noise_type == "aniso": + poisson = 0.8 + blur = 0.8 + downsample = 0.8 + beta = 0.1 + gblur = args.ds_max * 1.5 + iso = False + else: + raise ValueError(f"{noise_type} noise_type is not supported") + else: + poisson, beta = args.poisson, args.beta + blur, gblur = args.blur, args.gblur + downsample = args.downsample + + pretrained_model = None if len( + args.pretrained_model) == 0 else args.pretrained_model + model = DenoiseModel(gpu=True, nchan=1, diam_mean=args.diam_mean, + pretrained_model=pretrained_model) + + train_data, labels, train_files, train_probs = None, None, None, None + test_data, test_labels, test_files, test_probs = None, None, None, None + if len(args.file_list) == 0: + output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0) + images, labels, image_names, test_images, test_labels, image_names_test = output + train_data = [] + for i in range(len(images)): + img = images[i].astype("float32") + if img.ndim > 2: + img = img[0] + train_data.append( + np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :]) + if len(args.test_dir) > 0: + test_data = [] + for i in range(len(test_images)): + img = test_images[i].astype("float32") + if img.ndim > 2: + img = img[0] + test_data.append( + np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :]) + save_path = os.path.join(args.dir, "../models/") + else: + root = args.dir + denoise_logger.info( + ">>> using file_list (assumes images are normalized and have flows!)") + dat = np.load(args.file_list, allow_pickle=True).item() + train_files = dat["train_files"] + test_files = dat["test_files"] + train_probs = dat["train_probs"] if "train_probs" in dat else None + test_probs = dat["test_probs"] if "test_probs" in dat else None + if str(train_files[0])[:len(str(root))] != str(root): + for i in range(len(train_files)): + new_path = root / Path(*train_files[i].parts[-3:]) + if i == 0: + print(f"changing path from {train_files[i]} to {new_path}") + train_files[i] = new_path + + for i in range(len(test_files)): + new_path = root / Path(*test_files[i].parts[-3:]) + test_files[i] = new_path + save_path = os.path.join(args.dir, "models/") + + os.makedirs(save_path, exist_ok=True) + + nimg_per_epoch = None if args.nimg_per_epoch == 0 else args.nimg_per_epoch + nimg_test_per_epoch = None if args.nimg_test_per_epoch == 0 else args.nimg_test_per_epoch + + model_path = train( + model.net, train_data=train_data, train_labels=labels, train_files=train_files, + test_data=test_data, test_labels=test_labels, test_files=test_files, + train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta, + blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max, + iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs, + learning_rate=args.learning_rate, + lam=lams, + seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch, + nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path) + + +def seg_train_noisy(model, train_data, train_labels, test_data=None, test_labels=None, + poisson=0.8, blur=0.0, downsample=0.0, 
save_path=None, + save_every=100, save_each=False, learning_rate=0.2, n_epochs=500, + momentum=0.9, weight_decay=0.00001, SGD=True, batch_size=8, + nimg_per_epoch=None, diameter=None, rescale=True, z_masking=False, + model_name=None): + """ train function uses loss function model.loss_fn in models.py + + (data should already be normalized) + + """ + + d = datetime.datetime.now() + + model.n_epochs = n_epochs + if isinstance(learning_rate, (list, np.ndarray)): + if isinstance(learning_rate, np.ndarray) and learning_rate.ndim > 1: + raise ValueError("learning_rate.ndim must equal 1") + elif len(learning_rate) != n_epochs: + raise ValueError( + "if learning_rate given as list or np.ndarray it must have length n_epochs" + ) + model.learning_rate = learning_rate + model.learning_rate_const = mode(learning_rate)[0][0] + else: + model.learning_rate_const = learning_rate + # set learning rate schedule + if SGD: + LR = np.linspace(0, model.learning_rate_const, 10) + if model.n_epochs > 250: + LR = np.append( + LR, model.learning_rate_const * np.ones(model.n_epochs - 100)) + for i in range(10): + LR = np.append(LR, LR[-1] / 2 * np.ones(10)) + else: + LR = np.append( + LR, + model.learning_rate_const * np.ones(max(0, model.n_epochs - 10))) + else: + LR = model.learning_rate_const * np.ones(model.n_epochs) + model.learning_rate = LR + + model.batch_size = batch_size + model._set_optimizer(model.learning_rate[0], momentum, weight_decay, SGD) + model._set_criterion() + + nimg = len(train_data) + + # compute average cell diameter + if diameter is None: + diam_train = np.array( + [utils.diameters(train_labels[k][0])[0] for k in range(len(train_labels))]) + diam_train_mean = diam_train[diam_train > 0].mean() + model.diam_labels = diam_train_mean + if rescale: + diam_train[diam_train < 5] = 5. + if test_data is not None: + diam_test = np.array([ + utils.diameters(test_labels[k][0])[0] + for k in range(len(test_labels)) + ]) + diam_test[diam_test < 5] = 5. 
+ denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean) + elif rescale: + diam_train_mean = diameter + model.diam_labels = diameter + denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean) + diam_train = diameter * np.ones(len(train_labels), "float32") + if test_data is not None: + diam_test = diameter * np.ones(len(test_labels), "float32") + + denoise_logger.info( + f">>>> mean of training label mask diameters (saved to model) {diam_train_mean:.3f}" + ) + model.net.diam_labels.data = torch.ones(1, device=model.device) * diam_train_mean + + nchan = train_data[0].shape[0] + denoise_logger.info(">>>> training network with %d channel input <<<<" % nchan) + denoise_logger.info(">>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f" % + (model.learning_rate_const, model.batch_size, weight_decay)) + + if test_data is not None: + denoise_logger.info(f">>>> ntrain = {nimg}, ntest = {len(test_data)}") + else: + denoise_logger.info(f">>>> ntrain = {nimg}") + + tic = time.time() + + lavg, nsum = 0, 0 + + if save_path is not None: + _, file_label = os.path.split(save_path) + file_path = os.path.join(save_path, "models/") + + if not os.path.exists(file_path): + os.makedirs(file_path) + else: + denoise_logger.warning("WARNING: no save_path given, model not saving") + + ksave = 0 + + # cannot train with mkldnn + model.net.mkldnn = False + + # get indices for each epoch for training + np.random.seed(0) + inds_all = np.zeros((0,), "int32") + if nimg_per_epoch is None or nimg > nimg_per_epoch: + nimg_per_epoch = nimg + denoise_logger.info(f">>>> nimg_per_epoch = {nimg_per_epoch}") + while len(inds_all) < n_epochs * nimg_per_epoch: + rperm = np.random.permutation(nimg) + inds_all = np.hstack((inds_all, rperm)) + + for iepoch in range(model.n_epochs): + if SGD: + model._set_learning_rate(model.learning_rate[iepoch]) + np.random.seed(iepoch) + rperm = inds_all[iepoch * nimg_per_epoch:(iepoch + 1) * nimg_per_epoch] + for ibatch in range(0, nimg_per_epoch, batch_size): + inds = rperm[ibatch:ibatch + batch_size] + imgi, lbl, scale = random_rotate_and_resize_noise( + [train_data[i] for i in inds], [train_labels[i][1:] for i in inds], + poisson=poisson, blur=blur, downsample=downsample, + diams=diam_train[inds], diam_mean=model.diam_mean) + imgi = imgi[:, :1] # keep noisy only + if z_masking: + nc = imgi.shape[1] + nb = imgi.shape[0] + ncmin = (np.random.rand(nb) > 0.25) * (np.random.randint( + nc // 2 - 1, size=nb)) + ncmax = nc - (np.random.rand(nb) > 0.25) * (np.random.randint( + nc // 2 - 1, size=nb)) + for b in range(nb): + imgi[b, :ncmin[b]] = 0 + imgi[b, ncmax[b]:] = 0 + + train_loss = model._train_step(imgi, lbl) + lavg += train_loss + nsum += len(imgi) + + if iepoch % 10 == 0 or iepoch == 5: + lavg = lavg / nsum + if test_data is not None: + lavgt, nsum = 0., 0 + np.random.seed(42) + rperm = np.arange(0, len(test_data), 1, int) + for ibatch in range(0, len(test_data), batch_size): + inds = rperm[ibatch:ibatch + batch_size] + imgi, lbl, scale = random_rotate_and_resize_noise( + [test_data[i] for i in inds], + [test_labels[i][1:] for i in inds], poisson=poisson, blur=blur, + downsample=downsample, diams=diam_test[inds], + diam_mean=model.diam_mean) + imgi = imgi[:, :1] # keep noisy only + test_loss = model._test_eval(imgi, lbl) + lavgt += test_loss + nsum += len(imgi) + + denoise_logger.info( + "Epoch %d, Time %4.1fs, Loss %2.4f, Loss Test %2.4f, LR %2.4f" % + (iepoch, time.time() - tic, lavg, lavgt / nsum, + model.learning_rate[iepoch])) + else: + denoise_logger.info( + 
"Epoch %d, Time %4.1fs, Loss %2.4f, LR %2.4f" % + (iepoch, time.time() - tic, lavg, model.learning_rate[iepoch])) + + lavg, nsum = 0, 0 + + if save_path is not None: + if iepoch == model.n_epochs - 1 or iepoch % save_every == 1: + # save model at the end + if save_each: #separate files as model progresses + if model_name is None: + filename = "{}_{}_{}_{}".format( + model.net_type, file_label, + d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch)) + else: + filename = "{}_{}".format(model_name, "epoch_" + str(iepoch)) + else: + if model_name is None: + filename = "{}_{}_{}".format(model.net_type, file_label, + d.strftime("%Y_%m_%d_%H_%M_%S.%f")) + else: + filename = model_name + filename = os.path.join(file_path, filename) + ksave += 1 + denoise_logger.info(f"saving network parameters to {filename}") + model.net.save_model(filename) + else: + filename = save_path + + # reset to mkldnn if available + model.net.mkldnn = model.mkldnn + return filename diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py new file mode 100644 index 0000000000000000000000000000000000000000..db163ded6e76eac60ad8a2a563d7505083899f2b --- /dev/null +++ b/cellpose/dynamics.py @@ -0,0 +1,1041 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import time, os +from scipy.ndimage import maximum_filter1d, find_objects, center_of_mass +import torch +import numpy as np +import tifffile +from tqdm import trange +from numba import njit, prange, float32, int32, vectorize +import cv2 +import fastremap + +import logging + +dynamics_logger = logging.getLogger(__name__) + +from . import utils, metrics, transforms + +import torch +from torch import optim, nn +import torch.nn.functional as F +from . import resnet_torch + +@njit("(float64[:], int32[:], int32[:], int32, int32, int32, int32)", nogil=True) +def _extend_centers(T, y, x, ymed, xmed, Lx, niter): + """Run diffusion from the center of the mask on the mask pixels. + + Args: + T (numpy.ndarray): Array of shape (Ly * Lx) where diffusion is run. + y (numpy.ndarray): Array of y-coordinates of pixels inside the mask. + x (numpy.ndarray): Array of x-coordinates of pixels inside the mask. + ymed (int): Center of the mask in the y-coordinate. + xmed (int): Center of the mask in the x-coordinate. + Lx (int): Size of the x-dimension of the masks. + niter (int): Number of iterations to run diffusion. + + Returns: + numpy.ndarray: Array of shape (Ly * Lx) representing the amount of diffused particles at each pixel. + """ + for t in range(niter): + T[ymed * Lx + xmed] += 1 + T[y * Lx + + x] = 1 / 9. * (T[y * Lx + x] + T[(y - 1) * Lx + x] + T[(y + 1) * Lx + x] + + T[y * Lx + x - 1] + T[y * Lx + x + 1] + + T[(y - 1) * Lx + x - 1] + T[(y - 1) * Lx + x + 1] + + T[(y + 1) * Lx + x - 1] + T[(y + 1) * Lx + x + 1]) + return T + + +def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, + device=torch.device("cpu")): + """Runs diffusion on GPU to generate flows for training images or quality control. + + Args: + neighbors (torch.Tensor): 9 x pixels in masks. + meds (torch.Tensor): Mask centers. + isneighbor (torch.Tensor): Valid neighbor boolean 9 x pixels. + shape (tuple): Shape of the tensor. + n_iter (int, optional): Number of iterations. Defaults to 200. + device (torch.device, optional): Device to run the computation on. Defaults to torch.device("cpu"). + + Returns: + torch.Tensor: Generated flows. 
+ + """ + if torch.prod(torch.tensor(shape)) > 4e7 or device.type == "mps": + T = torch.zeros(shape, dtype=torch.float, device=device) + else: + T = torch.zeros(shape, dtype=torch.double, device=device) + + for i in range(n_iter): + T[tuple(meds.T)] += 1 + Tneigh = T[tuple(neighbors)] + Tneigh *= isneighbor + T[tuple(neighbors[:, 0])] = Tneigh.mean(axis=0) + del meds, isneighbor, Tneigh + + if T.ndim == 2: + grads = T[neighbors[0, [2, 1, 4, 3]], neighbors[1, [2, 1, 4, 3]]] + del neighbors + dy = grads[0] - grads[1] + dx = grads[2] - grads[3] + del grads + mu_torch = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2) + else: + grads = T[tuple(neighbors[:, 1:])] + del neighbors + dz = grads[0] - grads[1] + dy = grads[2] - grads[3] + dx = grads[4] - grads[5] + del grads + mu_torch = np.stack( + (dz.cpu().squeeze(0), dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2) + return mu_torch + +@njit(nogil=True) +def get_centers(masks, slices): + """ + Get the centers of the masks and their extents. + + Args: + masks (ndarray): The labeled masks. + slices (ndarray): The slices of the masks. + + Returns: + A tuple containing the centers of the masks and the extents of the masks. + """ + centers = np.zeros((len(slices), 2), "int32") + ext = np.zeros((len(slices),), "int32") + for p in prange(len(slices)): + si = slices[p] + i = si[0] + sr, sc = si[1:3], si[3:5] + # find center in slice around mask + yi, xi = np.nonzero(masks[sr[0]:sr[-1], sc[0]:sc[-1]] == (i + 1)) + ymed = yi.mean() + xmed = xi.mean() + # center is closest point to (ymed, xmed) within mask + imin = ((xi - xmed)**2 + (yi - ymed)**2).argmin() + ymed = yi[imin] + sr[0] + xmed = xi[imin] + sc[0] + centers[p] = np.array([ymed, xmed]) + ext[p] = (sr[-1] - sr[0]) + (sc[-1] - sc[0]) + 2 + return centers, ext + + +def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None): + """Convert masks to flows using diffusion from center pixel. + + Center of masks where diffusion starts is defined by pixel closest to median within the mask. + + Args: + masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels. + device (torch.device, optional): The device to run the computation on. Defaults to torch.device("cpu"). + niter (int, optional): Number of iterations for the diffusion process. Defaults to None. + + Returns: + np.ndarray: A 4D array representing the flows for each pixel in Z, X, and Y. + + + Returns: + A tuple containing (mu, meds_p). mu is float 3D or 4D array of flows in (Z)XY. + meds_p are cell centers. 
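+
+    Note:
+        This function is normally reached through masks_to_flows, which
+        dispatches to the GPU or CPU implementation based on the device type.
+        For 2D input the returned flow array has shape (2, Ly, Lx).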
+ """ + if device is None: + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None + + Ly0, Lx0 = masks.shape + Ly, Lx = Ly0 + 2, Lx0 + 2 + + masks_padded = torch.from_numpy(masks.astype("int64")).to(device) + masks_padded = F.pad(masks_padded, (1, 1, 1, 1)) + shape = masks_padded.shape + + ### get mask pixel neighbors + y, x = torch.nonzero(masks_padded, as_tuple=True) + y = y.int() + x = x.int() + neighbors = torch.zeros((2, 9, y.shape[0]), dtype=torch.long, device=device) + yxi = [[0, -1, 1, 0, 0, -1, -1, 1, 1], [0, 0, 0, -1, 1, -1, 1, -1, 1]] + for i in range(9): + neighbors[0, i] = y + yxi[0][i] + neighbors[1, i] = x + yxi[1][i] + isneighbor = torch.ones((9, y.shape[0]), dtype=torch.bool, device=device) + m0 = masks_padded[neighbors[0, 0], neighbors[1, 0]] + for i in range(1, 9): + isneighbor[i] = masks_padded[neighbors[0, i], neighbors[1, i]] == m0 + del m0, masks_padded + + ### get center-of-mass within cell + slices = find_objects(masks) + # turn slices into array + slices = np.array([ + np.array([i, si[0].start, si[0].stop, si[1].start, si[1].stop]) + for i, si in enumerate(slices) + if si is not None + ]) + centers, ext = get_centers(masks, slices) + meds_p = torch.from_numpy(centers).to(device).long() + meds_p += 1 # for padding + + ### run diffusion + n_iter = 2 * ext.max() if niter is None else niter + mu = _extend_centers_gpu(neighbors, meds_p, isneighbor, shape, n_iter=n_iter, + device=device) + mu = mu.astype("float64") + + # new normalization + mu /= (1e-60 + (mu**2).sum(axis=0)**0.5) + + # put into original image + mu0 = np.zeros((2, Ly0, Lx0)) + mu0[:, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu + + return mu0, meds_p.cpu().numpy() - 1 + + +def masks_to_flows_gpu_3d(masks, device=None, niter=None): + """Convert masks to flows using diffusion from center pixel. + + Args: + masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels. + device (torch.device, optional): The device to run the computation on. Defaults to None. + niter (int, optional): Number of iterations for the diffusion process. Defaults to None. + + Returns: + np.ndarray: A 4D array representing the flows for each pixel in Z, X, and Y. 
+ + """ + if device is None: + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None + + Lz0, Ly0, Lx0 = masks.shape + Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2 + + masks_padded = torch.from_numpy(masks.astype("int64")).to(device) + masks_padded = F.pad(masks_padded, (1, 1, 1, 1, 1, 1)) + + # get mask pixel neighbors + z, y, x = torch.nonzero(masks_padded).T + neighborsZ = torch.stack((z, z + 1, z - 1, z, z, z, z)) + neighborsY = torch.stack((y, y, y, y + 1, y - 1, y, y), axis=0) + neighborsX = torch.stack((x, x, x, x, x, x + 1, x - 1), axis=0) + + neighbors = torch.stack((neighborsZ, neighborsY, neighborsX), axis=0) + + # get mask centers + slices = find_objects(masks) + + centers = np.zeros((masks.max(), 3), "int") + for i, si in enumerate(slices): + if si is not None: + sz, sy, sx = si + #lz, ly, lx = sr.stop - sr.start + 1, sc.stop - sc.start + 1 + zi, yi, xi = np.nonzero(masks[sz, sy, sx] == (i + 1)) + zi = zi.astype(np.int32) + 1 # add padding + yi = yi.astype(np.int32) + 1 # add padding + xi = xi.astype(np.int32) + 1 # add padding + zmed = np.mean(zi) + ymed = np.mean(yi) + xmed = np.mean(xi) + imin = np.argmin((zi - zmed)**2 + (xi - xmed)**2 + (yi - ymed)**2) + zmed = zi[imin] + ymed = yi[imin] + xmed = xi[imin] + centers[i, 0] = zmed + sz.start + centers[i, 1] = ymed + sy.start + centers[i, 2] = xmed + sx.start + + # get neighbor validator (not all neighbors are in same mask) + neighbor_masks = masks_padded[tuple(neighbors)] + isneighbor = neighbor_masks == neighbor_masks[0] + ext = np.array( + [[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1] + for sz, sy, sx in slices]) + n_iter = 6 * (ext.sum(axis=1)).max() if niter is None else niter + + # run diffusion + shape = masks_padded.shape + mu = _extend_centers_gpu(neighbors, centers, isneighbor, shape, n_iter=n_iter, + device=device) + # normalize + mu /= (1e-60 + (mu**2).sum(axis=0)**0.5) + + # put into original image + mu0 = np.zeros((3, Lz0, Ly0, Lx0)) + mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu + return mu0 + + +def masks_to_flows_cpu(masks, niter=None, device=None): + """Convert masks to flows using diffusion from center pixel. + + Center of masks where diffusion starts is defined to be the closest pixel to the mean of all pixels that is inside the mask. + Result of diffusion is converted into flows by computing the gradients of the diffusion density map. + + Args: + masks (int, 2D or 3D array): Labelled masks 0=NO masks; 1,2,...=mask labels + niter (int, optional): Number of iterations for computing flows. Defaults to None. + + Returns: + A tuple containing (mu, meds_p). mu is float 3D or 4D array of flows in (Z)XY. + meds_p are cell centers. 
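+
+    Example:
+        A minimal illustrative call on a small synthetic mask (the array values
+        and shapes here are placeholders, not real data):
+
+            masks = np.zeros((64, 64), "int32")
+            masks[10:30, 10:30] = 1
+            mu, meds = masks_to_flows_cpu(masks)
+            # mu has shape (2, 64, 64); meds holds one (y, x) center per mask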
+ """ + Ly, Lx = masks.shape + mu = np.zeros((2, Ly, Lx), np.float64) + + slices = find_objects(masks) + meds = [] + for i in prange(len(slices)): + si = slices[i] + if si is not None: + sr, sc = si + ly, lx = sr.stop - sr.start + 2, sc.stop - sc.start + 2 + ### get center-of-mass within cell + y, x = np.nonzero(masks[sr, sc] == (i + 1)) + y = y.astype(np.int32) + 1 + x = x.astype(np.int32) + 1 + ymed = y.mean() + xmed = x.mean() + imin = ((x - xmed)**2 + (y - ymed)**2).argmin() + xmed = x[imin] + ymed = y[imin] + + n_iter = 2 * np.int32(ly + lx) if niter is None else niter + T = np.zeros((ly) * (lx), np.float64) + T = _extend_centers(T, y, x, ymed, xmed, np.int32(lx), np.int32(n_iter)) + dy = T[(y + 1) * lx + x] - T[(y - 1) * lx + x] + dx = T[y * lx + x + 1] - T[y * lx + x - 1] + mu[:, sr.start + y - 1, sc.start + x - 1] = np.stack((dy, dx)) + meds.append([ymed - 1, xmed - 1]) + + # new normalization + mu /= (1e-60 + (mu**2).sum(axis=0)**0.5) + + return mu, meds + + +def masks_to_flows(masks, device=torch.device("cpu"), niter=None): + """Convert masks to flows using diffusion from center pixel. + + Center of masks where diffusion starts is defined to be the closest pixel to the mean of all pixels that is inside the mask. + Result of diffusion is converted into flows by computing the gradients of the diffusion density map. + + Args: + masks (int, 2D or 3D array): Labelled masks 0=NO masks; 1,2,...=mask labels + + Returns: + np.ndarray: mu is float 3D or 4D array of flows in (Z)XY. + """ + if masks.max() == 0: + dynamics_logger.warning("empty masks!") + return np.zeros((2, *masks.shape), "float32") + + if device.type == "cuda" or device.type == "mps": + masks_to_flows_device = masks_to_flows_gpu + else: + masks_to_flows_device = masks_to_flows_cpu + + if masks.ndim == 3: + Lz, Ly, Lx = masks.shape + mu = np.zeros((3, Lz, Ly, Lx), np.float32) + for z in range(Lz): + mu0 = masks_to_flows_device(masks[z], device=device, niter=niter)[0] + mu[[1, 2], z] += mu0 + for y in range(Ly): + mu0 = masks_to_flows_device(masks[:, y], device=device, niter=niter)[0] + mu[[0, 2], :, y] += mu0 + for x in range(Lx): + mu0 = masks_to_flows_device(masks[:, :, x], device=device, niter=niter)[0] + mu[[0, 1], :, :, x] += mu0 + return mu + elif masks.ndim == 2: + mu, mu_c = masks_to_flows_device(masks, device=device, niter=niter) + return mu + + else: + raise ValueError("masks_to_flows only takes 2D or 3D arrays") + + +def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None, + return_flows=True): + """Converts labels (list of masks or flows) to flows for training model. + + Args: + labels (list of ND-arrays): The labels to convert. labels[k] can be 2D or 3D. If [3 x Ly x Lx], + it is assumed that flows were precomputed. Otherwise, labels[k][0] or labels[k] (if 2D) + is used to create flows and cell probabilities. + files (list of str, optional): The files to save the flows to. If provided, flows are saved to + files to be reused. Defaults to None. + device (str, optional): The device to use for computation. Defaults to None. + redo_flows (bool, optional): Whether to recompute the flows. Defaults to False. + niter (int, optional): The number of iterations for computing flows. Defaults to None. + + Returns: + list of [4 x Ly x Lx] arrays: The flows for training the model. flows[k][0] is labels[k], + flows[k][1] is cell distance transform, flows[k][2] is Y flow, flows[k][3] is X flow, + and flows[k][4] is heat distribution. 
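+
+    Example:
+        A minimal illustrative call on one small synthetic mask (shapes and
+        values are placeholders, not real training data):
+
+            masks = np.zeros((64, 64), "int32")
+            masks[8:24, 8:24] = 1
+            flows = labels_to_flows([masks], device=torch.device("cpu"))
+            # flows[0] has shape (4, 64, 64): labels, binary cell mask, Y flow, X flow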
+ """ + nimg = len(labels) + if labels[0].ndim < 3: + labels = [labels[n][np.newaxis, :, :] for n in range(nimg)] + + flows = [] + # flows need to be recomputed + if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows: + dynamics_logger.info("computing flows for labels") + + # compute flows; labels are fixed here to be unique, so they need to be passed back + # make sure labels are unique! + labels = [fastremap.renumber(label, in_place=True)[0] for label in labels] + iterator = trange if nimg > 1 else range + for n in iterator(nimg): + labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0] + vecn = masks_to_flows(labels[n][0].astype(int), device=device, niter=niter) + + # concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations) + flow = np.concatenate((labels[n], labels[n] > 0.5, vecn), + axis=0).astype(np.float32) + if files is not None: + file_name = os.path.splitext(files[n])[0] + tifffile.imwrite((file_name + "_flows.tif").replace("/y/", "/flows/"), flow) + if return_flows: + flows.append(flow) + else: + dynamics_logger.info("flows precomputed") + if return_flows: + flows = [labels[n].astype(np.float32) for n in range(nimg)] + return flows + + +@njit([ + "(int16[:,:,:], float32[:], float32[:], float32[:,:])", + "(float32[:,:,:], float32[:], float32[:], float32[:,:])" +], cache=True) +def map_coordinates(I, yc, xc, Y): + """ + Bilinear interpolation of image "I" in-place with y-coordinates yc and x-coordinates xc to Y. + + Args: + I (numpy.ndarray): Input image of shape (C, Ly, Lx). + yc (numpy.ndarray): New y-coordinates. + xc (numpy.ndarray): New x-coordinates. + Y (numpy.ndarray): Output array of shape (C, ni). + + Returns: + None + """ + C, Ly, Lx = I.shape + yc_floor = yc.astype(np.int32) + xc_floor = xc.astype(np.int32) + yc = yc - yc_floor + xc = xc - xc_floor + for i in range(yc_floor.shape[0]): + yf = min(Ly - 1, max(0, yc_floor[i])) + xf = min(Lx - 1, max(0, xc_floor[i])) + yf1 = min(Ly - 1, yf + 1) + xf1 = min(Lx - 1, xf + 1) + y = yc[i] + x = xc[i] + for c in range(C): + Y[c, i] = (np.float32(I[c, yf, xf]) * (1 - y) * (1 - x) + + np.float32(I[c, yf, xf1]) * (1 - y) * x + + np.float32(I[c, yf1, xf]) * y * (1 - x) + + np.float32(I[c, yf1, xf1]) * y * x) + + +def steps_interp(dP, inds, niter, device=torch.device("cpu")): + """ Run dynamics of pixels to recover masks in 2D/3D, with interpolation between pixel values. + + Euler integration of dynamics dP for niter steps. + + Args: + p (numpy.ndarray): Array of shape (n_points, 2 or 3) representing the initial pixel locations. + dP (numpy.ndarray): Array of shape (2, Ly, Lx) or (3, Lz, Ly, Lx) representing the flow field. + niter (int): Number of iterations to perform. + device (torch.device, optional): Device to use for computation. Defaults to None. + + Returns: + numpy.ndarray: Array of shape (n_points, 2) or (n_points, 3) representing the final pixel locations. 
+ + Raises: + None + + """ + + shape = dP.shape[1:] + ndim = len(shape) + if (device.type == "cuda" or device.type == "mps") or ndim==3: + pt = torch.zeros((*[1]*ndim, len(inds[0]), ndim), dtype=torch.float32, device=device) + im = torch.zeros((1, ndim, *shape), dtype=torch.float32, device=device) + # Y and X dimensions, flipped X-1, Y-1 + # pt is [1 1 1 3 n_points] + for n in range(ndim): + if ndim==3: + pt[0, 0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32) + else: + pt[0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32) + im[0, ndim - n - 1] = torch.from_numpy(dP[n]).to(device, dtype=torch.float32) + shape = np.array(shape)[::-1].astype("float") - 1 + + # normalize pt between 0 and 1, normalize the flow + for k in range(ndim): + im[:, k] *= 2. / shape[k] + pt[..., k] /= shape[k] + + # normalize to between -1 and 1 + pt *= 2 + pt -= 1 + + # dynamics + for t in range(niter): + dPt = torch.nn.functional.grid_sample(im, pt, align_corners=False) + for k in range(ndim): #clamp the final pixel locations + pt[..., k] = torch.clamp(pt[..., k] + dPt[:, k], -1., 1.) + + #undo the normalization from before, reverse order of operations + pt += 1 + pt *= 0.5 + for k in range(ndim): + pt[..., k] *= shape[k] + + if ndim==3: + pt = pt[..., [2, 1, 0]].squeeze() + pt = pt.unsqueeze(0) if pt.ndim==1 else pt + return pt.T + else: + pt = pt[..., [1, 0]].squeeze() + pt = pt.unsqueeze(0) if pt.ndim==1 else pt + return pt.T + + else: + p = np.zeros((ndim, len(inds[0])), "float32") + for n in range(ndim): + p[n] = inds[n] + dPt = np.zeros(p.shape, "float32") + for t in range(niter): + map_coordinates(dP, p[0], p[1], dPt) + for k in range(len(p)): + p[k] = np.minimum(shape[k] - 1, np.maximum(0, p[k] + dPt[k])) + return p + +@njit("(float32[:,:],float32[:,:,:,:], int32)", nogil=True) +def steps3D(p, dP, niter): + """ Run dynamics of pixels to recover masks in 3D. + + Euler integration of dynamics dP for niter steps. + + Args: + p (np.ndarray): Pixels with cellprob > cellprob_threshold [3 x npts]. + dP (np.ndarray): Flows [3 x Lz x Ly x Lx]. + niter (int): Number of iterations of dynamics to run. + + Returns: + np.ndarray: Final locations of each pixel after dynamics. + """ + shape = dP.shape[1:] + for t in range(niter): + for j in range(p.shape[1]): + p0, p1, p2 = int(p[0, j]), int(p[1, j]), int(p[2, j]) + step = dP[:, p0, p1, p2] + for k in range(3): + p[k, j] = min(shape[k] - 1, max(0, p[k, j] + step[k])) + return p + +@njit("(float32[:,:], float32[:,:,:], int32)", nogil=True) +def steps2D(p, dP, niter): + """Run dynamics of pixels to recover masks in 2D. + + Euler integration of dynamics dP for niter steps. + + Args: + p (np.ndarray): Pixels with cellprob > cellprob_threshold [2 x npts]. + dP (np.ndarray): Flows [2 x Ly x Lx]. + niter (int): Number of iterations of dynamics to run. + + Returns: + np.ndarray: Final locations of each pixel after dynamics. + """ + shape = dP.shape[1:] + for t in range(niter): + for j in range(p.shape[1]): + # starting coordinates + p0, p1 = int(p[0, j]), int(p[1, j]) + step = dP[:, p0, p1] + for k in range(p.shape[0]): + p[k, j] = min(shape[k] - 1, max(0, p[k, j] + step[k])) + return p + +def follow_flows(dP, inds, niter=200, interp=True, device=torch.device("cpu")): + """ Run dynamics to recover masks in 2D or 3D. + + Pixels are represented as a meshgrid. Only pixels with non-zero cell-probability + are used (as defined by inds). + + Args: + dP (np.ndarray): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. 
+ mask (np.ndarray, optional): Pixel mask to seed masks. Useful when flows have low magnitudes. + niter (int, optional): Number of iterations of dynamics to run. Default is 200. + interp (bool, optional): Interpolate during 2D dynamics (not available in 3D). Default is True. + device (torch.device, optional): Device to use for computation. Default is None. + + Returns: + A tuple containing (p, inds): p (np.ndarray): Final locations of each pixel after dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx]; + inds (np.ndarray): Indices of pixels used for dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. + """ + shape = np.array(dP.shape[1:]).astype(np.int32) + ndim = len(inds) + niter = np.uint32(niter) + + if interp: + p = steps_interp(dP, inds, niter, device=device) + else: + p = np.zeros((ndim, len(inds[0])), "float32") + for n in range(ndim): + p[n] = inds[n] + steps_fcn = steps2D if ndim == 2 else steps3D + p = steps_fcn(p, dP, niter) + + return p + + +def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu")): + """Remove masks which have inconsistent flows. + + Uses metrics.flow_error to compute flows from predicted masks + and compare flows to predicted flows from the network. Discards + masks with flow errors greater than the threshold. + + Args: + masks (int, 2D or 3D array): Labelled masks, 0=NO masks; 1,2,...=mask labels, + size [Ly x Lx] or [Lz x Ly x Lx]. + flows (float, 3D or 4D array): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. + threshold (float, optional): Masks with flow error greater than threshold are discarded. + Default is 0.4. + + Returns: + masks (int, 2D or 3D array): Masks with inconsistent flow masks removed, + 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx]. + """ + device0 = device + if masks.size > 10000 * 10000 and (device is not None and device.type == "cuda"): + + major_version, minor_version = torch.__version__.split(".")[:2] + torch.cuda.empty_cache() + if major_version == "1" and int(minor_version) < 10: + # for PyTorch version lower than 1.10 + def mem_info(): + total_mem = torch.cuda.get_device_properties(device0.index).total_memory + used_mem = torch.cuda.memory_allocated(device0.index) + free_mem = total_mem - used_mem + return total_mem, free_mem + else: + # for PyTorch version 1.10 and above + def mem_info(): + free_mem, total_mem = torch.cuda.mem_get_info(device0.index) + return total_mem, free_mem + total_mem, free_mem = mem_info() + if masks.size * 32 > free_mem: + dynamics_logger.warning( + "WARNING: image is very large, not using gpu to compute flows from masks for QC step flow_threshold" + ) + dynamics_logger.info("turn off QC step with flow_threshold=0 if too slow") + device0 = torch.device("cpu") + + merrors, _ = metrics.flow_error(masks, flows, device0) + badi = 1 + (merrors > threshold).nonzero()[0] + masks[np.isin(masks, badi)] = 0 + return masks + + +def max_pool3d(h, kernel_size=5): + """ memory efficient max_pool thanks to Mark Kittisopikul + + for stride=1, padding=kernel_size//2, requires odd kernel_size >= 3 + + """ + _, nd, ny, nx = h.shape + m = h.clone().detach() + kruns, k0 = kernel_size // 2, 1 + for k in range(kruns): + for d in range(-k0, k0+1): + for y in range(-k0, k0+1): + for x in range(-k0, k0+1): + mv = m[:, max(-d,0):min(nd-d,nd), max(-y,0):min(ny-y,ny), max(-x,0):min(nx-x,nx)] + hv = h[:, max(d,0):min(nd+d,nd), max(y,0):min(ny+y,ny), max(x,0):min(nx+x,nx)] + torch.maximum(mv, hv, out=mv) + return m + +def max_pool2d(h, kernel_size=5): + """ memory efficient max_pool 
thanks to Mark Kittisopikul """ + _, ny, nx = h.shape + m = h.clone().detach() + k0 = kernel_size // 2 + for y in range(-k0, k0+1): + for x in range(-k0, k0+1): + mv = m[:, max(-y,0):min(ny-y,ny), max(-x,0):min(nx-x,nx)] + hv = h[:, max(y,0):min(ny+y,ny), max(x,0):min(nx+x,nx)] + torch.maximum(mv, hv, out=mv) + return m + +def max_pool1d(h, kernel_size=5, axis=1, out=None): + """ memory efficient max_pool thanks to Mark Kittisopikul + + for stride=1, padding=kernel_size//2, requires odd kernel_size >= 3 + + """ + if out is None: + out = h.clone() + else: + out.copy_(h) + + nd = h.shape[axis] + k0 = kernel_size // 2 + for d in range(-k0, k0+1): + if axis==1: + mv = out[:, max(-d,0):min(nd-d,nd)] + hv = h[:, max(d,0):min(nd+d,nd)] + elif axis==2: + mv = out[:, :, max(-d,0):min(nd-d,nd)] + hv = h[:, :, max(d,0):min(nd+d,nd)] + elif axis==3: + mv = out[:, :, :, max(-d,0):min(nd-d,nd)] + hv = h[:, :, :, max(d,0):min(nd+d,nd)] + torch.maximum(mv, hv, out=mv) + return out + +def max_pool_nd(h, kernel_size=5): + """ memory efficient max_pool in 2d or 3d """ + ndim = h.ndim - 1 + hmax = max_pool1d(h, kernel_size=kernel_size, axis=1) + hmax2 = max_pool1d(hmax, kernel_size=kernel_size, axis=2) + if ndim==2: + del hmax + return hmax2 + else: + hmax = max_pool1d(hmax2, kernel_size=kernel_size, axis=3, out=hmax) + del hmax2 + return hmax + +# from torch.nn.functional import max_pool2d +def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4): + """Create masks using pixel convergence after running dynamics. + + Makes a histogram of final pixel locations p, initializes masks + at peaks of histogram and extends the masks from the peaks so that + they include all pixels with more than 2 final pixels p. Discards + masks with flow errors greater than the threshold. + + Parameters: + p (float32, 3D or 4D array): Final locations of each pixel after dynamics, + size [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. + iscell (bool, 2D or 3D array): If iscell is not None, set pixels that are + iscell False to stay in their original location. + rpad (int, optional): Histogram edge padding. Default is 20. + max_size_fraction (float, optional): Masks larger than max_size_fraction of + total image size are removed. Default is 0.4. + + Returns: + M0 (int, 2D or 3D array): Masks with inconsistent flow masks removed, + 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx]. 
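+
+    Note:
+        In this implementation, pt is the (ndim x npts) tensor of final pixel
+        locations returned by follow_flows, inds are the original coordinates of
+        those pixels, and shape0 is the spatial shape of the output label array.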
+ """ + + ndim = len(shape0) + device = pt.device + + rpad = 20 + pt += rpad + pt = torch.clamp(pt, min=0) + for i in range(len(pt)): + pt[i] = torch.clamp(pt[i], max=shape0[i]+rpad-1) + + # # add extra padding to make divisible by 5 + # shape = tuple((np.ceil((shape0 + 2*rpad)/5) * 5).astype(int)) + shape = tuple(np.array(shape0) + 2*rpad) + + # sparse coo torch + coo = torch.sparse_coo_tensor(pt, torch.ones(pt.shape[1], device=pt.device, dtype=torch.int), + shape) + h1 = coo.to_dense() + del coo + + hmax1 = max_pool_nd(h1.unsqueeze(0), kernel_size=5) + hmax1 = hmax1.squeeze() + seeds1 = torch.nonzero((h1 - hmax1 > -1e-6) * (h1 > 10)) + del hmax1 + if len(seeds1) == 0: + dynamics_logger.warning("no seeds found in get_masks_torch - no masks found.") + return np.zeros(shape0, dtype="uint16") + + npts = h1[tuple(seeds1.T)] + isort1 = npts.argsort() + seeds1 = seeds1[isort1] + + n_seeds = len(seeds1) + h_slc = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device) + for k in range(n_seeds): + slc = tuple([slice(seeds1[k][j]-5, seeds1[k][j]+6) for j in range(ndim)]) + h_slc[k] = h1[slc] + del h1 + seed_masks = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device) + if ndim==2: + seed_masks[:,5,5] = 1 + else: + seed_masks[:,5,5,5] = 1 + + for iter in range(5): + # extend + seed_masks = max_pool_nd(seed_masks, kernel_size=3) + seed_masks *= h_slc > 2 + del h_slc + seeds_new = [tuple((torch.nonzero(seed_masks[k]) + seeds1[k] - 5).T) + for k in range(n_seeds)] + del seed_masks + + dtype = torch.int32 if n_seeds < 2**16 else torch.int64 + M1 = torch.zeros(shape, dtype=dtype, device=device) + for k in range(n_seeds): + M1[seeds_new[k]] = 1 + k + + M1 = M1[tuple(pt.long())] + M1 = M1.cpu().numpy() + + dtype = "uint16" if n_seeds < 2**16 else "uint32" + M0 = np.zeros(shape0, dtype=dtype) + M0[inds] = M1 + + # remove big masks + uniq, counts = fastremap.unique(M0, return_counts=True) + big = np.prod(shape0) * max_size_fraction + bigc = uniq[counts > big] + if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0): + M0 = fastremap.mask(M0, bigc) + fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels + M0 = M0.reshape(tuple(shape0)) + + #print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9:.3f} gb") + return M0 + + +def resize_and_compute_masks(dP, cellprob, niter=200, cellprob_threshold=0.0, + flow_threshold=0.4, interp=True, do_3D=False, min_size=15, + max_size_fraction=0.4, resize=None, device=torch.device("cpu")): + """Compute masks using dynamics from dP and cellprob, and resizes masks if resize is not None. + + Args: + dP (numpy.ndarray): The dynamics flow field array. + cellprob (numpy.ndarray): The cell probability array. + p (numpy.ndarray, optional): The pixels on which to run dynamics. Defaults to None + niter (int, optional): The number of iterations for mask computation. Defaults to 200. + cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0. + flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4. + interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True. + do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False. + min_size (int, optional): The minimum size of the masks. Defaults to 15. + max_size_fraction (float, optional): Masks larger than max_size_fraction of + total image size are removed. Default is 0.4. 
+ resize (tuple, optional): The desired size for resizing the masks. Defaults to None. + device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu"). + + Returns: + tuple: A tuple containing the computed masks and the final pixel locations. + """ + mask = compute_masks(dP, cellprob, niter=niter, + cellprob_threshold=cellprob_threshold, + flow_threshold=flow_threshold, interp=interp, do_3D=do_3D, + max_size_fraction=max_size_fraction, + device=device) + + if resize is not None: + if len(resize) == 2: + mask = transforms.resize_image(mask, resize[0], resize[1], no_channels=True, + interpolation=cv2.INTER_NEAREST) + else: + Lz, Ly, Lx = resize + if mask.shape[0] != Lz or mask.shape[1] != Ly: + dynamics_logger.info("resizing 3D masks to original image size") + if mask.shape[1] != Ly: + mask = transforms.resize_image(mask, Ly=Ly, Lx=Lx, + no_channels=True, + interpolation=cv2.INTER_NEAREST) + if mask.shape[0] != Lz: + mask = transforms.resize_image(mask.transpose(1,0,2), + Ly=Lz, Lx=Lx, + no_channels=True, + interpolation=cv2.INTER_NEAREST).transpose(1,0,2) + + mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size) + + return mask + +def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, + flow_threshold=0.4, interp=True, do_3D=False, min_size=-1, + max_size_fraction=0.4, device=torch.device("cpu")): + """Compute masks using dynamics from dP and cellprob. + + Args: + dP (numpy.ndarray): The dynamics flow field array. + cellprob (numpy.ndarray): The cell probability array. + p (numpy.ndarray, optional): The pixels on which to run dynamics. Defaults to None + niter (int, optional): The number of iterations for mask computation. Defaults to 200. + cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0. + flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4. + interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True. + do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False. + min_size (int, optional): The minimum size of the masks. Defaults to 15. + max_size_fraction (float, optional): Masks larger than max_size_fraction of + total image size are removed. Default is 0.4. + device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu"). + + Returns: + tuple: A tuple containing the computed masks and the final pixel locations. 
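+
+    Example:
+        A minimal illustrative call with placeholder arrays standing in for the
+        network's flow and probability outputs (real outputs come from model
+        evaluation):
+
+            dP = np.zeros((2, 64, 64), "float32")
+            cellprob = -np.ones((64, 64), "float32")
+            masks = compute_masks(dP, cellprob, device=torch.device("cpu"))
+            # no pixel exceeds cellprob_threshold here, so an empty uint16 mask is returned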
+ """ + + if (cellprob > cellprob_threshold).sum(): #mask at this point is a cell cluster binary map, not labels + inds = np.nonzero(cellprob > cellprob_threshold) + if len(inds[0]) == 0: + dynamics_logger.info("No cell pixels found.") + shape = cellprob.shape + mask = np.zeros(shape, "uint16") + return mask + + p_final = follow_flows(dP * (cellprob > cellprob_threshold) / 5., + inds=inds, niter=niter, interp=interp, + device=device) + if not torch.is_tensor(p_final): + p_final = torch.from_numpy(p_final).to(device, dtype=torch.int) + else: + p_final = p_final.int() + # calculate masks + if device.type == "mps": + p_final = p_final.to(torch.device("cpu")) + mask = get_masks_torch(p_final, inds, dP.shape[1:], + max_size_fraction=max_size_fraction) + del p_final + # flow thresholding factored out of get_masks + if not do_3D: + if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0: + # make sure labels are unique at output of get_masks + mask = remove_bad_flow_masks(mask, dP, threshold=flow_threshold, + device=device) + + if mask.max() < 2**16 and mask.dtype != "uint16": + mask = mask.astype("uint16") + + else: # nothing to compute, just make it compatible + dynamics_logger.info("No cell pixels found.") + shape = cellprob.shape + mask = np.zeros(cellprob.shape, "uint16") + return mask + + if min_size > 0: + mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size) + + if mask.dtype == np.uint32: + dynamics_logger.warning( + "more than 65535 masks in image, masks returned as np.uint32") + + return mask + +def get_masks_orig(p, iscell=None, rpad=20, max_size_fraction=0.4): + """Create masks using pixel convergence after running dynamics. + + Original implementation on CPU with histogramdd + (histogramdd uses excessive memory with large images) + + Makes a histogram of final pixel locations p, initializes masks + at peaks of histogram and extends the masks from the peaks so that + they include all pixels with more than 2 final pixels p. Discards + masks with flow errors greater than the threshold. + + Parameters: + p (float32, 3D or 4D array): Final locations of each pixel after dynamics, + size [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. + iscell (bool, 2D or 3D array): If iscell is not None, set pixels that are + iscell False to stay in their original location. + rpad (int, optional): Histogram edge padding. Default is 20. + max_size_fraction (float, optional): Masks larger than max_size_fraction of + total image size are removed. Default is 0.4. + + Returns: + M0 (int, 2D or 3D array): Masks with inconsistent flow masks removed, + 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx]. 
+ """ + pflows = [] + edges = [] + shape0 = p.shape[1:] + dims = len(p) + if iscell is not None: + if dims == 3: + inds = np.meshgrid(np.arange(shape0[0]), np.arange(shape0[1]), + np.arange(shape0[2]), indexing="ij") + elif dims == 2: + inds = np.meshgrid(np.arange(shape0[0]), np.arange(shape0[1]), + indexing="ij") + for i in range(dims): + p[i, ~iscell] = inds[i][~iscell] + + for i in range(dims): + pflows.append(p[i].flatten().astype("int32")) + edges.append(np.arange(-.5 - rpad, shape0[i] + .5 + rpad, 1)) + + h, _ = np.histogramdd(tuple(pflows), bins=edges) + hmax = h.copy() + for i in range(dims): + hmax = maximum_filter1d(hmax, 5, axis=i) + + seeds = np.nonzero(np.logical_and(h - hmax > -1e-6, h > 10)) + Nmax = h[seeds] + isort = np.argsort(Nmax)[::-1] + for s in seeds: + s[:] = s[isort] + + pix = list(np.array(seeds).T) + + shape = h.shape + if dims == 3: + expand = np.nonzero(np.ones((3, 3, 3))) + else: + expand = np.nonzero(np.ones((3, 3))) + + for iter in range(5): + for k in range(len(pix)): + if iter == 0: + pix[k] = list(pix[k]) + newpix = [] + iin = [] + for i, e in enumerate(expand): + epix = e[:, np.newaxis] + np.expand_dims(pix[k][i], 0) - 1 + epix = epix.flatten() + iin.append(np.logical_and(epix >= 0, epix < shape[i])) + newpix.append(epix) + iin = np.all(tuple(iin), axis=0) + for p in newpix: + p = p[iin] + newpix = tuple(newpix) + igood = h[newpix] > 2 + for i in range(dims): + pix[k][i] = newpix[i][igood] + if iter == 4: + pix[k] = tuple(pix[k]) + + M = np.zeros(h.shape, np.uint32) + for k in range(len(pix)): + M[pix[k]] = 1 + k + + for i in range(dims): + pflows[i] = pflows[i] + rpad + M0 = M[tuple(pflows)] + + # remove big masks + uniq, counts = fastremap.unique(M0, return_counts=True) + big = np.prod(shape0) * max_size_fraction + bigc = uniq[counts > big] + if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0): + M0 = fastremap.mask(M0, bigc) + fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels + M0 = np.reshape(M0, shape0) + return M0 diff --git a/cellpose/export.py b/cellpose/export.py new file mode 100644 index 0000000000000000000000000000000000000000..6b240620206054310577a3ffa4a631070aa1fe56 --- /dev/null +++ b/cellpose/export.py @@ -0,0 +1,410 @@ +"""Auxiliary module for bioimageio format export + +Example usage: + +```bash +#!/bin/bash + +# Define default paths and parameters +DEFAULT_CHANNELS="1 0" +DEFAULT_PATH_PRETRAINED_MODEL="/home/qinyu/models/cp/cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995" +DEFAULT_PATH_README="/home/qinyu/models/cp/README.md" +DEFAULT_LIST_PATH_COVER_IMAGES="/home/qinyu/images/cp/cellpose_raw_and_segmentation.jpg /home/qinyu/images/cp/cellpose_raw_and_probability.jpg /home/qinyu/images/cp/cellpose_raw.jpg" +DEFAULT_MODEL_ID="philosophical-panda" +DEFAULT_MODEL_ICON="🐼" +DEFAULT_MODEL_VERSION="0.1.0" +DEFAULT_MODEL_NAME="My Cool Cellpose" +DEFAULT_MODEL_DOCUMENTATION="A cool Cellpose model trained for my cool dataset." 
+DEFAULT_MODEL_AUTHORS='[{"name": "Qin Yu", "affiliation": "EMBL", "github_user": "qin-yu", "orcid": "0000-0002-4652-0795"}]' +DEFAULT_MODEL_CITE='[{"text": "For more details of the model itself, see the manuscript", "doi": "10.1242/dev.202800", "url": null}]' +DEFAULT_MODEL_TAGS="cellpose 3d 2d" +DEFAULT_MODEL_LICENSE="MIT" +DEFAULT_MODEL_REPO="https://github.com/kreshuklab/go-nuclear" + +# Run the Python script with default parameters +python export.py \ + --channels $DEFAULT_CHANNELS \ + --path_pretrained_model "$DEFAULT_PATH_PRETRAINED_MODEL" \ + --path_readme "$DEFAULT_PATH_README" \ + --list_path_cover_images $DEFAULT_LIST_PATH_COVER_IMAGES \ + --model_version "$DEFAULT_MODEL_VERSION" \ + --model_name "$DEFAULT_MODEL_NAME" \ + --model_documentation "$DEFAULT_MODEL_DOCUMENTATION" \ + --model_authors "$DEFAULT_MODEL_AUTHORS" \ + --model_cite "$DEFAULT_MODEL_CITE" \ + --model_tags $DEFAULT_MODEL_TAGS \ + --model_license "$DEFAULT_MODEL_LICENSE" \ + --model_repo "$DEFAULT_MODEL_REPO" +``` +""" + +import os +import sys +import json +import argparse +from pathlib import Path +from urllib.parse import urlparse + +import torch +import numpy as np + +from cellpose.io import imread +from cellpose.utils import download_url_to_file +from cellpose.transforms import pad_image_ND, normalize_img, convert_image +from cellpose.resnet_torch import CPnetBioImageIO + +from bioimageio.spec.model.v0_5 import ( + ArchitectureFromFileDescr, + Author, + AxisId, + ChannelAxis, + CiteEntry, + Doi, + FileDescr, + Identifier, + InputTensorDescr, + IntervalOrRatioDataDescr, + LicenseId, + ModelDescr, + ModelId, + OrcidId, + OutputTensorDescr, + ParameterizedSize, + PytorchStateDictWeightsDescr, + SizeReference, + SpaceInputAxis, + SpaceOutputAxis, + TensorId, + TorchscriptWeightsDescr, + Version, + WeightsDescr, +) +# Define ARBITRARY_SIZE if it is not available in the module +try: + from bioimageio.spec.model.v0_5 import ARBITRARY_SIZE +except ImportError: + ARBITRARY_SIZE = ParameterizedSize(min=1, step=1) + +from bioimageio.spec.common import HttpUrl +from bioimageio.spec import save_bioimageio_package +from bioimageio.core import test_model + +DEFAULT_CHANNELS = [2, 1] +DEFAULT_NORMALIZE_PARAMS = { + "axis": -1, + "lowhigh": None, + "percentile": None, + "normalize": True, + "norm3D": False, + "sharpen_radius": 0, + "smooth_radius": 0, + "tile_norm_blocksize": 0, + "tile_norm_smooth3D": 1, + "invert": False, +} +IMAGE_URL = "http://www.cellpose.org/static/data/rgb_3D.tif" + + +def download_and_normalize_image(path_dir_temp, channels=DEFAULT_CHANNELS): + """ + Download and normalize image. 
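+
+    The test image is fetched from IMAGE_URL (downloaded into path_dir_temp if
+    not already present), converted to the requested channel layout, normalized,
+    transposed to (z, channel, y, x), and padded for the network.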
+ """ + filename = os.path.basename(urlparse(IMAGE_URL).path) + path_image = path_dir_temp / filename + if not path_image.exists(): + sys.stderr.write(f'Downloading: "{IMAGE_URL}" to {path_image}\n') + download_url_to_file(IMAGE_URL, path_image) + img = imread(path_image).astype(np.float32) + img = convert_image(img, channels, channel_axis=1, z_axis=0, do_3D=False, nchan=2) + img = normalize_img(img, **DEFAULT_NORMALIZE_PARAMS) + img = np.transpose(img, (0, 3, 1, 2)) + img, _, _ = pad_image_ND(img) + return img + + +def load_bioimageio_cpnet_model(path_model_weight, nchan=2): + cpnet_kwargs = { + "nbase": [nchan, 32, 64, 128, 256], + "nout": 3, + "sz": 3, + "mkldnn": False, + "conv_3D": False, + "max_pool": True, + } + cpnet_biio = CPnetBioImageIO(**cpnet_kwargs) + state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True) + cpnet_biio.load_state_dict(state_dict_cuda) + cpnet_biio.eval() # crucial for the prediction results + return cpnet_biio, cpnet_kwargs + + +def descr_gen_input(path_test_input, nchan=2): + input_axes = [ + SpaceInputAxis(id=AxisId("z"), size=ARBITRARY_SIZE), + ChannelAxis(channel_names=[Identifier(f"c{i+1}") for i in range(nchan)]), + SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=16, step=16)), + SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=16, step=16)), + ] + data_descr = IntervalOrRatioDataDescr(type="float32") + path_test_input = Path(path_test_input) + descr_input = InputTensorDescr( + id=TensorId("raw"), + axes=input_axes, + test_tensor=FileDescr(source=path_test_input), + data=data_descr, + ) + return descr_input + + +def descr_gen_output_flow(path_test_output): + output_axes_output_tensor = [ + SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))), + ChannelAxis(channel_names=[Identifier("flow1"), Identifier("flow2"), Identifier("flow3")]), + SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))), + SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"))), + ] + path_test_output = Path(path_test_output) + descr_output = OutputTensorDescr( + id=TensorId("flow"), + axes=output_axes_output_tensor, + test_tensor=FileDescr(source=path_test_output), + ) + return descr_output + + +def descr_gen_output_downsampled(path_dir_temp, nbase=None): + if nbase is None: + nbase = [32, 64, 128, 256] + + output_axes_downsampled_tensors = [ + [ + SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))), + ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(base)]), + SpaceOutputAxis( + id=AxisId("y"), + size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")), + scale=2**offset, + ), + SpaceOutputAxis( + id=AxisId("x"), + size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")), + scale=2**offset, + ), + ] + for offset, base in enumerate(nbase) + ] + path_downsampled_tensors = [ + Path(path_dir_temp / f"test_downsampled_{i}.npy") for i in range(len(output_axes_downsampled_tensors)) + ] + descr_output_downsampled_tensors = [ + OutputTensorDescr( + id=TensorId(f"downsampled_{i}"), + axes=axes, + test_tensor=FileDescr(source=path), + ) + for i, (axes, path) in enumerate(zip(output_axes_downsampled_tensors, path_downsampled_tensors)) + ] + return descr_output_downsampled_tensors + + +def descr_gen_output_style(path_test_style, nchannel=256): + output_axes_style_tensor = [ + SpaceOutputAxis(id=AxisId("z"), 
size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))), + ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(nchannel)]), + ] + path_style_tensor = Path(path_test_style) + descr_output_style_tensor = OutputTensorDescr( + id=TensorId("style"), + axes=output_axes_style_tensor, + test_tensor=FileDescr(source=path_style_tensor), + ) + return descr_output_style_tensor + + +def descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper=None): + if path_cpnet_wrapper is None: + path_cpnet_wrapper = Path(__file__).parent / "resnet_torch.py" + pytorch_architecture = ArchitectureFromFileDescr( + callable=Identifier("CPnetBioImageIO"), + source=Path(path_cpnet_wrapper), + kwargs=cpnet_kwargs, + ) + return pytorch_architecture + + +def descr_gen_documentation(path_doc, markdown_text): + with open(path_doc, "w") as f: + f.write(markdown_text) + + +def package_to_bioimageio( + path_pretrained_model, + path_save_trace, + path_readme, + list_path_cover_images, + descr_input, + descr_output, + descr_output_downsampled_tensors, + descr_output_style_tensor, + pytorch_version, + pytorch_architecture, + model_id, + model_icon, + model_version, + model_name, + model_documentation, + model_authors, + model_cite, + model_tags, + model_license, + model_repo, +): + """Package model description to BioImage.IO format.""" + my_model_descr = ModelDescr( + id=ModelId(model_id) if model_id is not None else None, + id_emoji=model_icon, + version=Version(model_version), + name=model_name, + description=model_documentation, + authors=[ + Author( + name=author["name"], + affiliation=author["affiliation"], + github_user=author["github_user"], + orcid=OrcidId(author["orcid"]), + ) + for author in model_authors + ], + cite=[CiteEntry(text=cite["text"], doi=Doi(cite["doi"]), url=cite["url"]) for cite in model_cite], + covers=[Path(img) for img in list_path_cover_images], + license=LicenseId(model_license), + tags=model_tags, + documentation=Path(path_readme), + git_repo=HttpUrl(model_repo), + inputs=[descr_input], + outputs=[descr_output, descr_output_style_tensor] + descr_output_downsampled_tensors, + weights=WeightsDescr( + pytorch_state_dict=PytorchStateDictWeightsDescr( + source=Path(path_pretrained_model), + architecture=pytorch_architecture, + pytorch_version=pytorch_version, + ), + torchscript=TorchscriptWeightsDescr( + source=Path(path_save_trace), + pytorch_version=pytorch_version, + parent="pytorch_state_dict", # these weights were converted from the pytorch_state_dict weights. 
+ ), + ), + ) + + return my_model_descr + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser(description="BioImage.IO model packaging for Cellpose") + parser.add_argument("--channels", nargs=2, default=[2, 1], type=int, help="Cyto-only = [2, 0], Cyto + Nuclei = [2, 1], Nuclei-only = [1, 0]") + parser.add_argument("--path_pretrained_model", required=True, type=str, help="Path to pretrained model file, e.g., cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995") + parser.add_argument("--path_readme", required=True, type=str, help="Path to README file") + parser.add_argument("--list_path_cover_images", nargs='+', required=True, type=str, help="List of paths to cover images") + parser.add_argument("--model_id", type=str, help="Model ID, provide if already exists", default=None) + parser.add_argument("--model_icon", type=str, help="Model icon, provide if already exists", default=None) + parser.add_argument("--model_version", required=True, type=str, help="Model version, new model should be 0.1.0") + parser.add_argument("--model_name", required=True, type=str, help="Model name, e.g., My Cool Cellpose") + parser.add_argument("--model_documentation", required=True, type=str, help="Model documentation, e.g., A cool Cellpose model trained for my cool dataset.") + parser.add_argument("--model_authors", required=True, type=str, help="Model authors in JSON format, e.g., '[{\"name\": \"Qin Yu\", \"affiliation\": \"EMBL\", \"github_user\": \"qin-yu\", \"orcid\": \"0000-0002-4652-0795\"}]'") + parser.add_argument("--model_cite", required=True, type=str, help="Model citation in JSON format, e.g., '[{\"text\": \"For more details of the model itself, see the manuscript\", \"doi\": \"10.1242/dev.202800\", \"url\": null}]'") + parser.add_argument("--model_tags", nargs='+', required=True, type=str, help="Model tags, e.g., cellpose 3d 2d") + parser.add_argument("--model_license", required=True, type=str, help="Model license, e.g., MIT") + parser.add_argument("--model_repo", required=True, type=str, help="Model repository URL") + return parser.parse_args() + # fmt: on + + +def main(): + args = parse_args() + + # Parse user-provided paths and arguments + channels = args.channels + model_cite = json.loads(args.model_cite) + model_authors = json.loads(args.model_authors) + + path_readme = Path(args.path_readme) + path_pretrained_model = Path(args.path_pretrained_model) + list_path_cover_images = [Path(path_image) for path_image in args.list_path_cover_images] + + # Auto-generated paths + path_cpnet_wrapper = Path(__file__).resolve().parent / "resnet_torch.py" + path_dir_temp = Path(__file__).resolve().parent.parent / "models" / path_pretrained_model.stem + path_dir_temp.mkdir(parents=True, exist_ok=True) + + path_save_trace = path_dir_temp / "cp_traced.pt" + path_test_input = path_dir_temp / "test_input.npy" + path_test_output = path_dir_temp / "test_output.npy" + path_test_style = path_dir_temp / "test_style.npy" + path_bioimageio_package = path_dir_temp / "cellpose_model.zip" + + # Download test input image + img_np = download_and_normalize_image(path_dir_temp, channels=channels) + np.save(path_test_input, img_np) + img = torch.tensor(img_np).float() + + # Load model + cpnet_biio, cpnet_kwargs = load_bioimageio_cpnet_model(path_pretrained_model) + + # Test model and save output + tuple_output_tensor = cpnet_biio(img) + np.save(path_test_output, tuple_output_tensor[0].detach().numpy()) + np.save(path_test_style, tuple_output_tensor[1].detach().numpy()) + for i, t in 
enumerate(tuple_output_tensor[2:]): + np.save(path_dir_temp / f"test_downsampled_{i}.npy", t.detach().numpy()) + + # Save traced model + model_traced = torch.jit.trace(cpnet_biio, img) + model_traced.save(path_save_trace) + + # Generate model description + descr_input = descr_gen_input(path_test_input) + descr_output = descr_gen_output_flow(path_test_output) + descr_output_downsampled_tensors = descr_gen_output_downsampled(path_dir_temp, nbase=cpnet_biio.nbase[1:]) + descr_output_style_tensor = descr_gen_output_style(path_test_style, cpnet_biio.nbase[-1]) + pytorch_version = Version(torch.__version__) + pytorch_architecture = descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper) + + # Package model + my_model_descr = package_to_bioimageio( + path_pretrained_model, + path_save_trace, + path_readme, + list_path_cover_images, + descr_input, + descr_output, + descr_output_downsampled_tensors, + descr_output_style_tensor, + pytorch_version, + pytorch_architecture, + args.model_id, + args.model_icon, + args.model_version, + args.model_name, + args.model_documentation, + model_authors, + model_cite, + args.model_tags, + args.model_license, + args.model_repo, + ) + + # Test model + summary = test_model(my_model_descr, weight_format="pytorch_state_dict") + summary.display() + summary = test_model(my_model_descr, weight_format="torchscript") + summary.display() + + # Save BioImage.IO package + package_path = save_bioimageio_package(my_model_descr, output_path=Path(path_bioimageio_package)) + print("package path:", package_path) + + +if __name__ == "__main__": + main() diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py new file mode 100644 index 0000000000000000000000000000000000000000..614b877af9402e63f045a8ee0f32a8007f0f7787 --- /dev/null +++ b/cellpose/gui/gui.py @@ -0,0 +1,2526 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import sys, os, pathlib, warnings, datetime, time, copy + +from qtpy import QtGui, QtCore +from superqt import QRangeSlider, QCollapsible +from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox +import pyqtgraph as pg + +import numpy as np +from scipy.stats import mode +import cv2 + +from . import guiparts, menus, io +from .. 
import models, core, dynamics, version, denoise, train +from ..utils import download_url_to_file, masks_to_outlines, diameters +from ..io import get_image_files, imsave, imread +from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img +from ..models import normalize_default +from ..plot import disk + +try: + import matplotlib.pyplot as plt + MATPLOTLIB = True +except: + MATPLOTLIB = False + +try: + from google.cloud import storage + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "key/cellpose-data-writer.json") + SERVER_UPLOAD = True +except: + SERVER_UPLOAD = False + +Horizontal = QtCore.Qt.Orientation.Horizontal + + +class Slider(QRangeSlider): + + def __init__(self, parent, name, color): + super().__init__(Horizontal) + self.setEnabled(False) + self.valueChanged.connect(lambda: self.levelChanged(parent)) + self.name = name + + self.setStyleSheet(""" QSlider{ + background-color: transparent; + } + """) + self.show() + + def levelChanged(self, parent): + parent.level_change(self.name) + + +class QHLine(QFrame): + + def __init__(self): + super(QHLine, self).__init__() + self.setFrameShape(QFrame.HLine) + #self.setFrameShadow(QFrame.Sunken) + self.setLineWidth(8) + + +def make_bwr(): + # make a bwr colormap + b = np.append(255 * np.ones(128), np.linspace(0, 255, 128)[::-1])[:, np.newaxis] + r = np.append(np.linspace(0, 255, 128), 255 * np.ones(128))[:, np.newaxis] + g = np.append(np.linspace(0, 255, 128), + np.linspace(0, 255, 128)[::-1])[:, np.newaxis] + color = np.concatenate((r, g, b), axis=-1).astype(np.uint8) + bwr = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) + return bwr + + +def make_spectral(): + # make spectral colormap + r = np.array([ + 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, + 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 120, 112, 104, 96, 88, + 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 7, 11, 15, 19, 23, + 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103, + 107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167, + 171, 175, 179, 183, 187, 191, 195, 199, 203, 207, 211, 215, 219, 223, 227, 231, + 235, 239, 243, 247, 251, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255 + ]) + g = np.array([ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 5, 4, 4, 3, 3, + 2, 2, 1, 1, 0, 0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, + 119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, + 247, 255, 247, 239, 231, 223, 215, 207, 199, 191, 183, 175, 167, 159, 151, 143, + 135, 128, 129, 131, 132, 134, 135, 137, 139, 140, 142, 143, 145, 147, 148, 150, + 151, 153, 154, 156, 158, 159, 161, 162, 164, 166, 167, 169, 170, 172, 174, 175, + 177, 178, 180, 181, 183, 185, 186, 188, 189, 191, 193, 194, 196, 197, 199, 201, + 202, 
204, 205, 207, 208, 210, 212, 213, 215, 216, 218, 220, 221, 223, 224, 226, + 228, 229, 231, 232, 234, 235, 237, 239, 240, 242, 243, 245, 247, 248, 250, 251, + 253, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, + 195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, + 131, 127, 123, 119, 115, 111, 107, 103, 99, 95, 91, 87, 83, 79, 75, 71, 67, 63, + 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 0, 8, 16, 24, 32, 41, + 49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, + 189, 197, 205, 213, 222, 230, 238, 246, 254 + ]) + b = np.array([ + 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143, + 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 251, 247, + 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183, + 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 128, 126, 124, + 122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98, 96, 94, 92, 90, + 88, 86, 84, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, + 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, + 8, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74, + 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205, + 213, 222, 230, 238, 246, 254 + ]) + color = (np.vstack((r, g, b)).T).astype(np.uint8) + spectral = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) + return spectral + + +def make_cmap(cm=0): + # make a single channel colormap + r = np.arange(0, 256) + color = np.zeros((256, 3)) + color[:, cm] = r + color = color.astype(np.uint8) + cmap = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) + return cmap + + +def run(image=None): + from ..io import logger_setup + logger, log_file = logger_setup() + # Always start by initializing Qt (only once per application) + warnings.filterwarnings("ignore") + app = QApplication(sys.argv) + icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png") + guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png") + if not icon_path.is_file(): + cp_dir = pathlib.Path.home().joinpath(".cellpose") + cp_dir.mkdir(exist_ok=True) + print("downloading logo") + download_url_to_file( + "https://www.cellpose.org/static/images/cellpose_transparent.png", + icon_path, progress=True) + if not guip_path.is_file(): + print("downloading help window image") + download_url_to_file("https://www.cellpose.org/static/images/cellpose_gui.png", + guip_path, progress=True) + icon_path = str(icon_path.resolve()) + app_icon = QtGui.QIcon() + app_icon.addFile(icon_path, QtCore.QSize(16, 16)) + app_icon.addFile(icon_path, QtCore.QSize(24, 24)) + app_icon.addFile(icon_path, QtCore.QSize(32, 32)) + app_icon.addFile(icon_path, QtCore.QSize(48, 48)) + app_icon.addFile(icon_path, QtCore.QSize(64, 64)) + app_icon.addFile(icon_path, QtCore.QSize(256, 256)) + app.setWindowIcon(app_icon) + app.setStyle("Fusion") + app.setPalette(guiparts.DarkPalette()) + #app.setStyleSheet("QLineEdit { color: yellow }") + + # models.download_model_weights() # does not exist + MainW(image=image, logger=logger) + 
ret = app.exec_() + sys.exit(ret) + + +class MainW(QMainWindow): + + def __init__(self, image=None, logger=None): + super(MainW, self).__init__() + + self.logger = logger + pg.setConfigOptions(imageAxisOrder="row-major") + self.setGeometry(50, 50, 1200, 1000) + self.setWindowTitle(f"cellpose v{version}") + self.cp_path = os.path.dirname(os.path.realpath(__file__)) + app_icon = QtGui.QIcon() + icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png") + icon_path = str(icon_path.resolve()) + app_icon.addFile(icon_path, QtCore.QSize(16, 16)) + app_icon.addFile(icon_path, QtCore.QSize(24, 24)) + app_icon.addFile(icon_path, QtCore.QSize(32, 32)) + app_icon.addFile(icon_path, QtCore.QSize(48, 48)) + app_icon.addFile(icon_path, QtCore.QSize(64, 64)) + app_icon.addFile(icon_path, QtCore.QSize(256, 256)) + self.setWindowIcon(app_icon) + # rgb(150,255,150) + self.setStyleSheet(guiparts.stylesheet()) + + menus.mainmenu(self) + menus.editmenu(self) + menus.modelmenu(self) + menus.helpmenu(self) + + self.stylePressed = """QPushButton {Text-align: center; + background-color: rgb(150,50,150); + border-color: white; + color:white;} + QToolTip { + background-color: black; + color: white; + border: black solid 1px + }""" + self.styleUnpressed = """QPushButton {Text-align: center; + background-color: rgb(50,50,50); + border-color: white; + color:white;} + QToolTip { + background-color: black; + color: white; + border: black solid 1px + }""" + self.loaded = False + + # ---- MAIN WIDGET LAYOUT ---- # + self.cwidget = QWidget(self) + self.lmain = QGridLayout() + self.cwidget.setLayout(self.lmain) + self.setCentralWidget(self.cwidget) + self.lmain.setVerticalSpacing(0) + self.lmain.setContentsMargins(0, 0, 0, 10) + + self.imask = 0 + self.scrollarea = QScrollArea() + self.scrollarea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn) + self.scrollarea.setStyleSheet("""QScrollArea { border: none }""") + self.scrollarea.setWidgetResizable(True) + self.swidget = QWidget(self) + self.scrollarea.setWidget(self.swidget) + self.l0 = QGridLayout() + self.swidget.setLayout(self.l0) + b = self.make_buttons() + self.lmain.addWidget(self.scrollarea, 0, 0, 39, 9) + + # ---- drawing area ---- # + self.win = pg.GraphicsLayoutWidget() + + self.lmain.addWidget(self.win, 0, 9, 40, 30) + + self.win.scene().sigMouseClicked.connect(self.plot_clicked) + self.win.scene().sigMouseMoved.connect(self.mouse_moved) + self.make_viewbox() + self.lmain.setColumnStretch(10, 1) + bwrmap = make_bwr() + self.bwr = bwrmap.getLookupTable(start=0.0, stop=255.0, alpha=False) + self.cmap = [] + # spectral colormap + self.cmap.append(make_spectral().getLookupTable(start=0.0, stop=255.0, + alpha=False)) + # single channel colormaps + for i in range(3): + self.cmap.append( + make_cmap(i).getLookupTable(start=0.0, stop=255.0, alpha=False)) + + if MATPLOTLIB: + self.colormap = (plt.get_cmap("gist_ncar")(np.linspace(0.0, .9, 1000000)) * + 255).astype(np.uint8) + np.random.seed(42) # make colors stable + self.colormap = self.colormap[np.random.permutation(1000000)] + else: + np.random.seed(42) # make colors stable + self.colormap = ((np.random.rand(1000000, 3) * 0.8 + 0.1) * 255).astype( + np.uint8) + self.NZ = 1 + self.restore = None + self.ratio = 1. 
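+ # Reset all image, mask, and UI state to their defaults before any file is loaded.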
+ self.reset() + + # if called with image, load it + if image is not None: + self.filename = image + io._load_image(self, self.filename) + + # training settings + d = datetime.datetime.now() + self.training_params = { + "model_index": 0, + "learning_rate": 0.1, + "weight_decay": 0.0001, + "n_epochs": 100, + "SGD": True, + "model_name": "CP" + d.strftime("_%Y%m%d_%H%M%S"), + } + + self.load_3D = False + self.stitch_threshold = 0. + self.flow3D_smooth = 0. + self.anisotropy = 1. + self.min_size = 15 + self.resample = True + + self.setAcceptDrops(True) + self.win.show() + self.show() + + def help_window(self): + HW = guiparts.HelpWindow(self) + HW.show() + + def train_help_window(self): + THW = guiparts.TrainHelpWindow(self) + THW.show() + + def gui_window(self): + EG = guiparts.ExampleGUI(self) + EG.show() + + def make_buttons(self): + self.boldfont = QtGui.QFont("Arial", 11, QtGui.QFont.Bold) + self.boldmedfont = QtGui.QFont("Arial", 9, QtGui.QFont.Bold) + self.medfont = QtGui.QFont("Arial", 9) + self.smallfont = QtGui.QFont("Arial", 8) + + b = 0 + self.satBox = QGroupBox("Views") + self.satBox.setFont(self.boldfont) + self.satBoxG = QGridLayout() + self.satBox.setLayout(self.satBoxG) + self.l0.addWidget(self.satBox, b, 0, 1, 9) + + b0 = 0 + self.view = 0 # 0=image, 1=flowsXY, 2=flowsZ, 3=cellprob + self.color = 0 # 0=RGB, 1=gray, 2=R, 3=G, 4=B + self.RGBDropDown = QComboBox() + self.RGBDropDown.addItems( + ["RGB", "red=R", "green=G", "blue=B", "gray", "spectral"]) + self.RGBDropDown.setFont(self.medfont) + self.RGBDropDown.currentIndexChanged.connect(self.color_choose) + self.satBoxG.addWidget(self.RGBDropDown, b0, 0, 1, 3) + + label = QLabel("
[↑ / ↓ or W/S]
") + label.setFont(self.smallfont) + self.satBoxG.addWidget(label, b0, 3, 1, 3) + label = QLabel("[R / G / B \n toggles color ]") + label.setFont(self.smallfont) + self.satBoxG.addWidget(label, b0, 6, 1, 3) + + b0 += 1 + self.ViewDropDown = QComboBox() + self.ViewDropDown.addItems(["image", "gradXY", "cellprob", "restored"]) + self.ViewDropDown.setFont(self.medfont) + self.ViewDropDown.model().item(3).setEnabled(False) + self.ViewDropDown.currentIndexChanged.connect(self.update_plot) + self.satBoxG.addWidget(self.ViewDropDown, b0, 0, 2, 3) + + label = QLabel("[pageup / pagedown]") + label.setFont(self.smallfont) + self.satBoxG.addWidget(label, b0, 3, 1, 5) + + b0 += 2 + label = QLabel("") + label.setToolTip( + "NOTE: manually changing the saturation bars does not affect normalization in segmentation" + ) + self.satBoxG.addWidget(label, b0, 0, 1, 5) + + self.autobtn = QCheckBox("auto-adjust saturation") + self.autobtn.setToolTip("sets scale-bars as normalized for segmentation") + self.autobtn.setFont(self.medfont) + self.autobtn.setChecked(True) + self.satBoxG.addWidget(self.autobtn, b0, 1, 1, 8) + + b0 += 1 + self.sliders = [] + colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [100, 100, 100]] + colornames = ["red", "Chartreuse", "DodgerBlue"] + names = ["red", "green", "blue"] + for r in range(3): + b0 += 1 + if r == 0: + label = QLabel('gray/
red') + else: + label = QLabel(names[r] + ":") + label.setStyleSheet(f"color: {colornames[r]}") + label.setFont(self.boldmedfont) + self.satBoxG.addWidget(label, b0, 0, 1, 2) + self.sliders.append(Slider(self, names[r], colors[r])) + self.sliders[-1].setMinimum(-.1) + self.sliders[-1].setMaximum(255.1) + self.sliders[-1].setValue([0, 255]) + self.sliders[-1].setToolTip( + "NOTE: manually changing the saturation bars does not affect normalization in segmentation" + ) + #self.sliders[-1].setTickPosition(QSlider.TicksRight) + self.satBoxG.addWidget(self.sliders[-1], b0, 2, 1, 7) + + b += 1 + self.drawBox = QGroupBox("Drawing") + self.drawBox.setFont(self.boldfont) + self.drawBoxG = QGridLayout() + self.drawBox.setLayout(self.drawBoxG) + self.l0.addWidget(self.drawBox, b, 0, 1, 9) + self.autosave = True + + b0 = 0 + self.brush_size = 3 + self.BrushChoose = QComboBox() + self.BrushChoose.addItems(["1", "3", "5", "7", "9"]) + self.BrushChoose.currentIndexChanged.connect(self.brush_choose) + self.BrushChoose.setFixedWidth(40) + self.BrushChoose.setFont(self.medfont) + self.drawBoxG.addWidget(self.BrushChoose, b0, 3, 1, 2) + label = QLabel("brush size:") + label.setFont(self.medfont) + self.drawBoxG.addWidget(label, b0, 0, 1, 3) + + b0 += 1 + # turn off masks + self.layer_off = False + self.masksOn = True + self.MCheckBox = QCheckBox("MASKS ON [X]") + self.MCheckBox.setFont(self.medfont) + self.MCheckBox.setChecked(True) + self.MCheckBox.toggled.connect(self.toggle_masks) + self.drawBoxG.addWidget(self.MCheckBox, b0, 0, 1, 5) + + b0 += 1 + # turn off outlines + self.outlinesOn = False # turn off by default + self.OCheckBox = QCheckBox("outlines on [Z]") + self.OCheckBox.setFont(self.medfont) + self.drawBoxG.addWidget(self.OCheckBox, b0, 0, 1, 5) + self.OCheckBox.setChecked(False) + self.OCheckBox.toggled.connect(self.toggle_masks) + + b0 += 1 + self.SCheckBox = QCheckBox("single stroke") + self.SCheckBox.setFont(self.medfont) + self.SCheckBox.setChecked(True) + self.SCheckBox.toggled.connect(self.autosave_on) + self.SCheckBox.setEnabled(True) + self.drawBoxG.addWidget(self.SCheckBox, b0, 0, 1, 5) + + # buttons for deleting multiple cells + self.deleteBox = QGroupBox("delete multiple ROIs") + self.deleteBox.setStyleSheet("color: rgb(200, 200, 200)") + self.deleteBox.setFont(self.medfont) + self.deleteBoxG = QGridLayout() + self.deleteBox.setLayout(self.deleteBoxG) + self.drawBoxG.addWidget(self.deleteBox, 0, 5, 4, 4) + self.MakeDeletionRegionButton = QPushButton("region-select") + self.MakeDeletionRegionButton.clicked.connect(self.remove_region_cells) + self.deleteBoxG.addWidget(self.MakeDeletionRegionButton, 0, 0, 1, 4) + self.MakeDeletionRegionButton.setFont(self.smallfont) + self.MakeDeletionRegionButton.setFixedWidth(70) + self.DeleteMultipleROIButton = QPushButton("click-select") + self.DeleteMultipleROIButton.clicked.connect(self.delete_multiple_cells) + self.deleteBoxG.addWidget(self.DeleteMultipleROIButton, 1, 0, 1, 4) + self.DeleteMultipleROIButton.setFont(self.smallfont) + self.DeleteMultipleROIButton.setFixedWidth(70) + self.DoneDeleteMultipleROIButton = QPushButton("done") + self.DoneDeleteMultipleROIButton.clicked.connect( + self.done_remove_multiple_cells) + self.deleteBoxG.addWidget(self.DoneDeleteMultipleROIButton, 2, 0, 1, 2) + self.DoneDeleteMultipleROIButton.setFont(self.smallfont) + self.DoneDeleteMultipleROIButton.setFixedWidth(35) + self.CancelDeleteMultipleROIButton = QPushButton("cancel") + self.CancelDeleteMultipleROIButton.clicked.connect(self.cancel_remove_multiple) + 
self.deleteBoxG.addWidget(self.CancelDeleteMultipleROIButton, 2, 2, 1, 2) + self.CancelDeleteMultipleROIButton.setFont(self.smallfont) + self.CancelDeleteMultipleROIButton.setFixedWidth(35) + + b += 1 + b0 = 0 + self.segBox = QGroupBox("Segmentation") + self.segBoxG = QGridLayout() + self.segBox.setLayout(self.segBoxG) + self.l0.addWidget(self.segBox, b, 0, 1, 9) + self.segBox.setFont(self.boldfont) + + self.diameter = 30 + label = QLabel("diameter (pixels):") + label.setFont(self.medfont) + label.setToolTip( + 'you can manually enter the approximate diameter for your cells, \nor press β€œcalibrate” to let the model estimate it. \nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking β€œscale disk on”)' + ) + self.segBoxG.addWidget(label, b0, 0, 1, 4) + self.Diameter = QLineEdit() + self.Diameter.setToolTip( + 'you can manually enter the approximate diameter for your cells, \nor press β€œcalibrate” to let the "cyto3" model estimate it. \nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking β€œscale disk on”)' + ) + self.Diameter.setText(str(self.diameter)) + self.Diameter.setFont(self.medfont) + self.Diameter.returnPressed.connect(self.update_scale) + self.Diameter.setFixedWidth(50) + self.segBoxG.addWidget(self.Diameter, b0, 4, 1, 2) + + # compute diameter + self.SizeButton = QPushButton("calibrate") + self.SizeButton.setFont(self.medfont) + self.SizeButton.clicked.connect(self.calibrate_size) + self.segBoxG.addWidget(self.SizeButton, b0, 6, 1, 3) + #self.SizeButton.setFixedWidth(65) + self.SizeButton.setEnabled(False) + self.SizeButton.setToolTip( + 'you can manually enter the approximate diameter for your cells, \nor press β€œcalibrate” to let the cyto3 model estimate it. 
\nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking β€œscale disk on”)' + ) + + b0 += 1 + # choose channel + self.ChannelChoose = [QComboBox(), QComboBox()] + self.ChannelChoose[0].addItems(["0: gray", "1: red", "2: green", "3: blue"]) + self.ChannelChoose[1].addItems(["0: none", "1: red", "2: green", "3: blue"]) + cstr = ["chan to segment:", "chan2 (optional): "] + for i in range(2): + self.ChannelChoose[i].setFont(self.medfont) + label = QLabel(cstr[i]) + label.setFont(self.medfont) + if i == 0: + label.setToolTip( + "this is the channel in which the cytoplasm or nuclei exist that you want to segment" + ) + self.ChannelChoose[i].setToolTip( + "this is the channel in which the cytoplasm or nuclei exist that you want to segment" + ) + else: + label.setToolTip( + "if cytoplasm model is chosen, and you also have a nuclear channel, then choose the nuclear channel for this option" + ) + self.ChannelChoose[i].setToolTip( + "if cytoplasm model is chosen, and you also have a nuclear channel, then choose the nuclear channel for this option" + ) + self.segBoxG.addWidget(label, b0 + i, 0, 1, 4) + self.segBoxG.addWidget(self.ChannelChoose[i], b0 + i, 4, 1, 5) + + b0 += 2 + + # use GPU + self.useGPU = QCheckBox("use GPU") + self.useGPU.setToolTip( + "if you have specially installed the cuda version of torch, then you can activate this" + ) + self.useGPU.setFont(self.medfont) + self.check_gpu() + self.segBoxG.addWidget(self.useGPU, b0, 0, 1, 3) + + # compute segmentation with general models + self.net_text = ["run cyto3"] + nett = ["cellpose super-generalist model"] + + #label = QLabel("Run:") + #label.setFont(self.boldfont) + #label.setFont(self.medfont) + #self.segBoxG.addWidget(label, b0, 0, 1, 2) + self.StyleButtons = [] + jj = 4 + for j in range(len(self.net_text)): + self.StyleButtons.append( + guiparts.ModelButton(self, self.net_text[j], self.net_text[j])) + w = 5 + self.segBoxG.addWidget(self.StyleButtons[-1], b0, jj, 1, w) + jj += w + #self.StyleButtons[-1].setFixedWidth(140) + self.StyleButtons[-1].setToolTip(nett[j]) + + b0 += 1 + self.roi_count = QLabel("0 ROIs") + self.roi_count.setFont(self.boldfont) + self.roi_count.setAlignment(QtCore.Qt.AlignLeft) + self.segBoxG.addWidget(self.roi_count, b0, 0, 1, 4) + + self.progress = QProgressBar(self) + self.segBoxG.addWidget(self.progress, b0, 4, 1, 5) + + b0 += 1 + self.segaBox = QCollapsible("additional settings") + self.segaBox.setFont(self.medfont) + self.segaBox._toggle_btn.setFont(self.medfont) + self.segaBoxG = QGridLayout() + _content = QWidget() + _content.setLayout(self.segaBoxG) + _content.setMaximumHeight(0) + _content.setMinimumHeight(0) + #_content.layout().setContentsMargins(QtCore.QMargins(0, -20, -20, -20)) + self.segaBox.setContent(_content) + self.segBoxG.addWidget(self.segaBox, b0, 0, 1, 9) + + b0 = 0 + # post-hoc paramater tuning + label = QLabel("flow\nthreshold:") + label.setToolTip( + "threshold on flow error to accept a mask (set higher to get more cells, e.g. 
in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run" + ) + label.setFont(self.medfont) + self.segaBoxG.addWidget(label, b0, 0, 1, 2) + self.flow_threshold = QLineEdit() + self.flow_threshold.setText("0.4") + self.flow_threshold.returnPressed.connect(self.compute_cprob) + self.flow_threshold.setFixedWidth(40) + self.flow_threshold.setFont(self.medfont) + self.segaBoxG.addWidget(self.flow_threshold, b0, 2, 1, 2) + self.flow_threshold.setToolTip( + "threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run" + ) + + label = QLabel("cellprob\nthreshold:") + label.setToolTip( + "threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run" + ) + label.setFont(self.medfont) + self.segaBoxG.addWidget(label, b0, 4, 1, 2) + self.cellprob_threshold = QLineEdit() + self.cellprob_threshold.setText("0.0") + self.cellprob_threshold.returnPressed.connect(self.compute_cprob) + self.cellprob_threshold.setFixedWidth(40) + self.cellprob_threshold.setFont(self.medfont) + self.cellprob_threshold.setToolTip( + "threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run" + ) + self.segaBoxG.addWidget(self.cellprob_threshold, b0, 6, 1, 2) + + b0 += 1 + label = QLabel("norm percentiles:") + label.setToolTip( + "sets normalization percentiles for segmentation and denoising\n(pixels at lower percentile set to 0.0 and at upper set to 1.0 for network)" + ) + label.setFont(self.medfont) + self.segaBoxG.addWidget(label, b0, 0, 1, 8) + + b0 += 1 + self.norm_vals = [1., 99.] 
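+ # Default normalization percentiles: pixels at the 1st percentile map to 0 and at the 99th percentile to 1 for the network.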
+ self.norm_edits = [] + labels = ["lower", "upper"] + tooltips = [ + "pixels at this percentile set to 0 (default 1.0)", + "pixels at this percentile set to 1 (default 99.0)" + ] + for p in range(2): + label = QLabel(f"{labels[p]}:") + label.setToolTip(tooltips[p]) + label.setFont(self.medfont) + self.segaBoxG.addWidget(label, b0, 4 * (p % 2), 1, 2) + self.norm_edits.append(QLineEdit()) + self.norm_edits[p].setText(str(self.norm_vals[p])) + self.norm_edits[p].setFixedWidth(40) + self.norm_edits[p].setFont(self.medfont) + self.segaBoxG.addWidget(self.norm_edits[p], b0, 4 * (p % 2) + 2, 1, 2) + self.norm_edits[p].setToolTip(tooltips[p]) + + b0 += 1 + label = QLabel("niter dynamics:") + label.setFont(self.medfont) + label.setToolTip( + "number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria" + ) + self.segaBoxG.addWidget(label, b0, 0, 1, 4) + self.niter = QLineEdit() + self.niter.setText("0") + self.niter.setFixedWidth(40) + self.niter.setFont(self.medfont) + self.niter.setToolTip( + "number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria" + ) + self.segaBoxG.addWidget(self.niter, b0, 4, 1, 2) + + b += 1 + b0 = 0 + self.modelBox = QGroupBox("Other models") + self.modelBoxG = QGridLayout() + self.modelBox.setLayout(self.modelBoxG) + self.l0.addWidget(self.modelBox, b, 0, 1, 9) + self.modelBox.setFont(self.boldfont) + # choose models + self.ModelChooseC = QComboBox() + self.ModelChooseC.setFont(self.medfont) + current_index = 0 + self.ModelChooseC.addItems(["custom models"]) + if len(self.model_strings) > 0: + self.ModelChooseC.addItems(self.model_strings) + self.ModelChooseC.setFixedWidth(175) + self.ModelChooseC.setCurrentIndex(current_index) + tipstr = 'add or train your own models in the "Models" file menu and choose model here' + self.ModelChooseC.setToolTip(tipstr) + self.ModelChooseC.activated.connect(lambda: self.model_choose(custom=True)) + self.modelBoxG.addWidget(self.ModelChooseC, b0, 0, 1, 8) + + # compute segmentation w/ custom model + self.ModelButtonC = QPushButton(u"run") + self.ModelButtonC.setFont(self.medfont) + self.ModelButtonC.setFixedWidth(35) + self.ModelButtonC.clicked.connect( + lambda: self.compute_segmentation(custom=True)) + self.modelBoxG.addWidget(self.ModelButtonC, b0, 8, 1, 1) + self.ModelButtonC.setEnabled(False) + + self.net_names = [ + "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3", + "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", + "cyto", "cyto2", "CPx"] + + nett = [ + "nuclei", "cellpose (cyto2_cp3)", "tissuenet_cp3", "livecell_cp3", + "yeast_PhC_cp3", "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", + "deepbacs_cp3", "cyto", "cyto2", + "CPx (from Cellpose2)" + ] + b0 += 1 + self.ModelChooseB = QComboBox() + self.ModelChooseB.setFont(self.medfont) + self.ModelChooseB.addItems(["dataset-specific models"]) + self.ModelChooseB.addItems(nett) + self.ModelChooseB.setFixedWidth(175) + tipstr = "dataset-specific models" + self.ModelChooseB.setToolTip(tipstr) + self.ModelChooseB.activated.connect(lambda: self.model_choose(custom=False)) + self.modelBoxG.addWidget(self.ModelChooseB, b0, 0, 1, 8) + + # compute segmentation w/ cp model + self.ModelButtonB = QPushButton(u"run") + self.ModelButtonB.setFont(self.medfont) + self.ModelButtonB.setFixedWidth(35) + self.ModelButtonB.clicked.connect( + lambda: self.compute_segmentation(custom=False)) + self.modelBoxG.addWidget(self.ModelButtonB, b0, 8, 1, 1) + self.ModelButtonB.setEnabled(False) + + b 
+= 1 + self.denoiseBox = QGroupBox("Image restoration") + self.denoiseBox.setFont(self.boldfont) + self.denoiseBoxG = QGridLayout() + self.denoiseBox.setLayout(self.denoiseBoxG) + self.l0.addWidget(self.denoiseBox, b, 0, 1, 9) + + b0 = 0 + + # DENOISING + self.DenoiseButtons = [] + nett = [ + "clear restore/filter", + "filter image (settings below)", + "denoise (please set cell diameter first)", + "deblur (please set cell diameter first)", + "upsample to 30. diameter (cyto3) or 17. diameter (nuclei) (please set cell diameter first) (disabled in 3D)", + "one-click model trained to denoise+deblur+upsample (please set cell diameter first)" + ] + self.denoise_text = ["none", "filter", "denoise", "deblur", "upsample", "one-click"] + self.restore = None + self.ratio = 1. + jj = 0 + w = 3 + for j in range(len(self.denoise_text)): + self.DenoiseButtons.append( + guiparts.DenoiseButton(self, self.denoise_text[j])) + self.denoiseBoxG.addWidget(self.DenoiseButtons[-1], b0, jj, 1, w) + self.DenoiseButtons[-1].setFixedWidth(75) + self.DenoiseButtons[-1].setToolTip(nett[j]) + self.DenoiseButtons[-1].setFont(self.medfont) + b0 += 1 if j%2==1 else 0 + jj = 0 if j%2==1 else jj + w + + # b0+=1 + self.save_norm = QCheckBox("save restored/filtered image") + self.save_norm.setFont(self.medfont) + self.save_norm.setToolTip("save restored/filtered image in _seg.npy file") + self.save_norm.setChecked(True) + # self.denoiseBoxG.addWidget(self.save_norm, b0, 0, 1, 8) + + b0 -= 3 + label = QLabel("restore-dataset:") + label.setToolTip( + "choose dataset and click [denoise], [deblur], [upsample], or [one-click]") + label.setFont(self.medfont) + self.denoiseBoxG.addWidget(label, b0, 6, 1, 3) + + b0 += 1 + self.DenoiseChoose = QComboBox() + self.DenoiseChoose.setFont(self.medfont) + self.DenoiseChoose.addItems(["cyto3", "cyto2", "nuclei"]) + self.DenoiseChoose.setFixedWidth(85) + tipstr = "choose model type and click [denoise], [deblur], or [upsample]" + self.DenoiseChoose.setToolTip(tipstr) + self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 6, 1, 3) + + b0 += 2 + # FILTERING + self.filtBox = QCollapsible("custom filter settings") + self.filtBox._toggle_btn.setFont(self.medfont) + self.filtBoxG = QGridLayout() + _content = QWidget() + _content.setLayout(self.filtBoxG) + _content.setMaximumHeight(0) + _content.setMinimumHeight(0) + #_content.layout().setContentsMargins(QtCore.QMargins(0, -20, -20, -20)) + self.filtBox.setContent(_content) + self.denoiseBoxG.addWidget(self.filtBox, b0, 0, 1, 9) + + self.filt_vals = [0., 0., 0., 0.] 
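+ # Custom filter settings (sharpen radius, smooth radius, tile-norm blocksize, tile-norm smooth3D) all default to 0, i.e. disabled.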
+ self.filt_edits = [] + labels = [ + "sharpen\nradius", "smooth\nradius", "tile_norm\nblocksize", + "tile_norm\nsmooth3D" + ] + tooltips = [ + "set size of surround-subtraction filter for sharpening image", + "set size of gaussian filter for smoothing image", + "set size of tiles to use to normalize image", + "set amount of smoothing of normalization values across planes" + ] + + for p in range(4): + label = QLabel(f"{labels[p]}:") + label.setToolTip(tooltips[p]) + label.setFont(self.medfont) + self.filtBoxG.addWidget(label, b0 + p // 2, 4 * (p % 2), 1, 2) + self.filt_edits.append(QLineEdit()) + self.filt_edits[p].setText(str(self.filt_vals[p])) + self.filt_edits[p].setFixedWidth(40) + self.filt_edits[p].setFont(self.medfont) + self.filtBoxG.addWidget(self.filt_edits[p], b0 + p // 2, 4 * (p % 2) + 2, 1, + 2) + self.filt_edits[p].setToolTip(tooltips[p]) + + b0 += 3 + self.norm3D_cb = QCheckBox("norm3D") + self.norm3D_cb.setFont(self.medfont) + self.norm3D_cb.setChecked(True) + self.norm3D_cb.setToolTip("run same normalization across planes") + self.filtBoxG.addWidget(self.norm3D_cb, b0, 0, 1, 3) + + self.invert_cb = QCheckBox("invert") + self.invert_cb.setFont(self.medfont) + self.invert_cb.setToolTip("invert image") + self.filtBoxG.addWidget(self.invert_cb, b0, 3, 1, 3) + + b += 1 + self.l0.addWidget(QLabel(""), b, 0, 1, 9) + self.l0.setRowStretch(b, 100) + + b += 1 + # scale toggle + self.scale_on = True + self.ScaleOn = QCheckBox("scale disk on") + self.ScaleOn.setFont(self.medfont) + self.ScaleOn.setStyleSheet("color: rgb(150,50,150);") + self.ScaleOn.setChecked(True) + self.ScaleOn.setToolTip("see current diameter as red disk at bottom") + self.ScaleOn.toggled.connect(self.toggle_scale) + self.l0.addWidget(self.ScaleOn, b, 0, 1, 5) + + return b + + def level_change(self, r): + r = ["red", "green", "blue"].index(r) + if self.loaded: + sval = self.sliders[r].value() + self.saturation[r][self.currentZ] = sval + if not self.autobtn.isChecked(): + for r in range(3): + for i in range(len(self.saturation[r])): + self.saturation[r][i] = self.saturation[r][self.currentZ] + self.update_plot() + + def keyPressEvent(self, event): + if self.loaded: + if not (event.modifiers() & + (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier | + QtCore.Qt.AltModifier) or self.in_stroke): + updated = False + if len(self.current_point_set) > 0: + if event.key() == QtCore.Qt.Key_Return: + self.add_set() + else: + nviews = self.ViewDropDown.count() - 1 + nviews += int( + self.ViewDropDown.model().item(self.ViewDropDown.count() - + 1).isEnabled()) + if event.key() == QtCore.Qt.Key_X: + self.MCheckBox.toggle() + if event.key() == QtCore.Qt.Key_Z: + self.OCheckBox.toggle() + if event.key() == QtCore.Qt.Key_Left or event.key( + ) == QtCore.Qt.Key_A: + self.get_prev_image() + elif event.key() == QtCore.Qt.Key_Right or event.key( + ) == QtCore.Qt.Key_D: + self.get_next_image() + elif event.key() == QtCore.Qt.Key_PageDown: + self.view = (self.view + 1) % (nviews) + self.ViewDropDown.setCurrentIndex(self.view) + elif event.key() == QtCore.Qt.Key_PageUp: + self.view = (self.view - 1) % (nviews) + self.ViewDropDown.setCurrentIndex(self.view) + + # can change background or stroke size if cell not finished + if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W: + self.color = (self.color - 1) % (6) + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_Down or event.key( + ) == QtCore.Qt.Key_S: + self.color = (self.color + 1) % (6) + self.RGBDropDown.setCurrentIndex(self.color) + 
elif event.key() == QtCore.Qt.Key_R: + if self.color != 1: + self.color = 1 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_G: + if self.color != 2: + self.color = 2 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_B: + if self.color != 3: + self.color = 3 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif (event.key() == QtCore.Qt.Key_Comma or + event.key() == QtCore.Qt.Key_Period): + count = self.BrushChoose.count() + gci = self.BrushChoose.currentIndex() + if event.key() == QtCore.Qt.Key_Comma: + gci = max(0, gci - 1) + else: + gci = min(count - 1, gci + 1) + self.BrushChoose.setCurrentIndex(gci) + self.brush_choose() + if not updated: + self.update_plot() + if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal: + self.p0.keyPressEvent(event) + + def autosave_on(self): + if self.SCheckBox.isChecked(): + self.autosave = True + else: + self.autosave = False + + def check_gpu(self, torch=True): + # also decide whether or not to use torch + self.useGPU.setChecked(False) + self.useGPU.setEnabled(False) + if core.use_gpu(use_torch=True): + self.useGPU.setEnabled(True) + self.useGPU.setChecked(True) + else: + self.useGPU.setStyleSheet("color: rgb(80,80,80);") + + def get_channels(self): + channels = [ + self.ChannelChoose[0].currentIndex(), self.ChannelChoose[1].currentIndex() + ] + if hasattr(self, "current_model"): + if self.current_model == "nuclei": + channels[1] = 0 + if channels[0] == 0: + channels[1] = 0 + if self.nchan == 1: + channels = [0, 0] + elif self.nchan == 2: + if channels[0] == 3: + channels[0] = 1 if channels[1] != 1 else 2 + print( + f"GUI_WARNING: only two channels in image, cannot use blue channel, changing channels" + ) + if channels[1] == 3: + channels[1] = 1 if channels[0] != 1 else 2 + print( + f"GUI_WARNING: only two channels in image, cannot use blue channel, changing channels" + ) + self.ChannelChoose[0].setCurrentIndex(channels[0]) + self.ChannelChoose[1].setCurrentIndex(channels[1]) + return channels + + def model_choose(self, custom=False): + index = self.ModelChooseC.currentIndex( + ) if custom else self.ModelChooseB.currentIndex() + if index > 0: + if custom: + model_name = self.ModelChooseC.currentText() + else: + model_name = self.net_names[index - 1] + print(f"GUI_INFO: selected model {model_name}, loading now") + self.initialize_model(model_name=model_name, custom=custom) + self.diameter = self.model.diam_labels + self.Diameter.setText("%0.2f" % self.diameter) + print( + f"GUI_INFO: diameter set to {self.diameter: 0.2f} (but can be changed)") + + def calibrate_size(self): + self.initialize_model(model_name="cyto3") + diams, _ = self.model.sz.eval(self.stack[self.currentZ].copy(), + channels=self.get_channels(), + progress=self.progress) + diams = np.maximum(5.0, diams) + self.logger.info("estimated diameter of cells using %s model = %0.1f pixels" % + (self.current_model, diams)) + self.Diameter.setText("%0.1f" % diams) + self.diameter = diams + self.update_scale() + self.progress.setValue(100) + + def toggle_scale(self): + if self.scale_on: + self.p0.removeItem(self.scale) + self.scale_on = False + else: + self.p0.addItem(self.scale) + self.scale_on = True + + def enable_buttons(self): + if len(self.model_strings) > 0: + self.ModelButtonC.setEnabled(True) + for i in range(len(self.StyleButtons)): + self.StyleButtons[i].setEnabled(True) + for i in range(len(self.DenoiseButtons)): + 
self.DenoiseButtons[i].setEnabled(True) + if self.load_3D: + self.DenoiseButtons[-2].setEnabled(False) + self.ModelButtonB.setEnabled(True) + self.SizeButton.setEnabled(True) + self.newmodel.setEnabled(True) + self.loadMasks.setEnabled(True) + + for n in range(self.nchan): + self.sliders[n].setEnabled(True) + for n in range(self.nchan, 3): + self.sliders[n].setEnabled(True) + + self.toggle_mask_ops() + + self.update_plot() + self.setWindowTitle(self.filename) + + def disable_buttons_removeROIs(self): + if len(self.model_strings) > 0: + self.ModelButtonC.setEnabled(False) + for i in range(len(self.StyleButtons)): + self.StyleButtons[i].setEnabled(False) + self.ModelButtonB.setEnabled(False) + self.SizeButton.setEnabled(False) + self.newmodel.setEnabled(False) + self.loadMasks.setEnabled(False) + self.saveSet.setEnabled(False) + self.savePNG.setEnabled(False) + self.saveFlows.setEnabled(False) + self.saveOutlines.setEnabled(False) + self.saveROIs.setEnabled(False) + + self.MakeDeletionRegionButton.setEnabled(False) + self.DeleteMultipleROIButton.setEnabled(False) + self.DoneDeleteMultipleROIButton.setEnabled(True) + self.CancelDeleteMultipleROIButton.setEnabled(True) + + def toggle_mask_ops(self): + self.update_layer() + self.toggle_saving() + self.toggle_removals() + + def toggle_saving(self): + if self.ncells > 0: + self.saveSet.setEnabled(True) + self.savePNG.setEnabled(True) + self.saveFlows.setEnabled(True) + self.saveOutlines.setEnabled(True) + self.saveROIs.setEnabled(True) + else: + self.saveSet.setEnabled(False) + self.savePNG.setEnabled(False) + self.saveFlows.setEnabled(False) + self.saveOutlines.setEnabled(False) + self.saveROIs.setEnabled(False) + + def toggle_removals(self): + if self.ncells > 0: + self.ClearButton.setEnabled(True) + self.remcell.setEnabled(True) + self.undo.setEnabled(True) + self.MakeDeletionRegionButton.setEnabled(True) + self.DeleteMultipleROIButton.setEnabled(True) + self.DoneDeleteMultipleROIButton.setEnabled(False) + self.CancelDeleteMultipleROIButton.setEnabled(False) + else: + self.ClearButton.setEnabled(False) + self.remcell.setEnabled(False) + self.undo.setEnabled(False) + self.MakeDeletionRegionButton.setEnabled(False) + self.DeleteMultipleROIButton.setEnabled(False) + self.DoneDeleteMultipleROIButton.setEnabled(False) + self.CancelDeleteMultipleROIButton.setEnabled(False) + + def remove_action(self): + if self.selected > 0: + self.remove_cell(self.selected) + + def undo_action(self): + if (len(self.strokes) > 0 and self.strokes[-1][0][0] == self.currentZ): + self.remove_stroke() + else: + # remove previous cell + if self.ncells > 0: + self.remove_cell(self.ncells) + + def undo_remove_action(self): + self.undo_remove_cell() + + def get_files(self): + folder = os.path.dirname(self.filename) + mask_filter = "_masks" + images = get_image_files(folder, mask_filter) + fnames = [os.path.split(images[k])[-1] for k in range(len(images))] + f0 = os.path.split(self.filename)[-1] + idx = np.nonzero(np.array(fnames) == f0)[0][0] + return images, idx + + def get_prev_image(self): + images, idx = self.get_files() + idx = (idx - 1) % len(images) + io._load_image(self, filename=images[idx]) + + def get_next_image(self, load_seg=True): + images, idx = self.get_files() + idx = (idx + 1) % len(images) + io._load_image(self, filename=images[idx], load_seg=load_seg) + + def dragEnterEvent(self, event): + if event.mimeData().hasUrls(): + event.accept() + else: + event.ignore() + + def dropEvent(self, event): + files = [u.toLocalFile() for u in event.mimeData().urls()] + 
if os.path.splitext(files[0])[-1] == ".npy": + io._load_seg(self, filename=files[0], load_3D=self.load_3D) + else: + io._load_image(self, filename=files[0], load_seg=True, load_3D=self.load_3D) + + def toggle_masks(self): + if self.MCheckBox.isChecked(): + self.masksOn = True + else: + self.masksOn = False + if self.OCheckBox.isChecked(): + self.outlinesOn = True + else: + self.outlinesOn = False + if not self.masksOn and not self.outlinesOn: + self.p0.removeItem(self.layer) + self.layer_off = True + else: + if self.layer_off: + self.p0.addItem(self.layer) + self.draw_layer() + self.update_layer() + if self.loaded: + self.update_plot() + self.update_layer() + + def make_viewbox(self): + self.p0 = guiparts.ViewBoxNoRightDrag(parent=self, lockAspect=True, + name="plot1", border=[100, 100, + 100], invertY=True) + self.p0.setCursor(QtCore.Qt.CrossCursor) + self.brush_size = 3 + self.win.addItem(self.p0, 0, 0, rowspan=1, colspan=1) + self.p0.setMenuEnabled(False) + self.p0.setMouseEnabled(x=True, y=True) + self.img = pg.ImageItem(viewbox=self.p0, parent=self) + self.img.autoDownsample = False + self.layer = guiparts.ImageDraw(viewbox=self.p0, parent=self) + self.layer.setLevels([0, 255]) + self.scale = pg.ImageItem(viewbox=self.p0, parent=self) + self.scale.setLevels([0, 255]) + self.p0.scene().contextMenuItem = self.p0 + #self.p0.setMouseEnabled(x=False,y=False) + self.Ly, self.Lx = 512, 512 + self.p0.addItem(self.img) + self.p0.addItem(self.layer) + self.p0.addItem(self.scale) + + def reset(self): + # ---- start sets of points ---- # + self.selected = 0 + self.nchan = 3 + self.loaded = False + self.channel = [0, 1] + self.current_point_set = [] + self.in_stroke = False + self.strokes = [] + self.stroke_appended = True + self.resize = False + self.ncells = 0 + self.zdraw = [] + self.removed_cell = [] + self.cellcolors = np.array([255, 255, 255])[np.newaxis, :] + + # -- zero out image stack -- # + self.opacity = 128 # how opaque masks should be + self.outcolor = [200, 200, 255, 200] + self.NZ, self.Ly, self.Lx = 1, 224, 224 + self.saturation = [] + for r in range(3): + self.saturation.append([[0, 255] for n in range(self.NZ)]) + self.sliders[r].setValue([0, 255]) + self.sliders[r].setEnabled(False) + self.sliders[r].show() + self.currentZ = 0 + self.flows = [[], [], [], [], [[]]] + # masks matrix + # image matrix with a scale disk + self.stack = np.zeros((1, self.Ly, self.Lx, 3)) + self.Lyr, self.Lxr = self.Ly, self.Lx + self.Ly0, self.Lx0 = self.Ly, self.Lx + self.radii = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) + self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) + self.cellpix = np.zeros((1, self.Ly, self.Lx), np.uint16) + self.outpix = np.zeros((1, self.Ly, self.Lx), np.uint16) + if self.restore and "upsample" in self.restore: + self.cellpix_resize = self.cellpix + self.cellpix_orig = self.cellpix + self.outpix_resize = self.cellpix + self.outpix_orig = self.cellpix + self.ismanual = np.zeros(0, "bool") + + # -- set menus to default -- # + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + self.view = 0 + self.ViewDropDown.setCurrentIndex(0) + self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False) + self.delete_restore() + + self.clear_all() + + #self.update_plot() + self.filename = [] + self.loaded = False + self.recompute_masks = False + + self.deleting_multiple = False + self.removing_cells_list = [] + self.removing_region = False + self.remove_roi_obj = None + + def delete_restore(self): + """ delete restored imgs but don't reset 
settings """ + if hasattr(self, "stack_filtered"): + del self.stack_filtered + if hasattr(self, "cellpix_orig"): + self.cellpix = self.cellpix_orig.copy() + self.outpix = self.outpix_orig.copy() + del self.outpix_orig, self.outpix_resize + del self.cellpix_orig, self.cellpix_resize + + def clear_restore(self): + """ delete restored imgs and reset settings """ + print("GUI_INFO: clearing restored image") + self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False) + if self.ViewDropDown.currentIndex() == self.ViewDropDown.count() - 1: + self.ViewDropDown.setCurrentIndex(0) + self.delete_restore() + self.restore = None + self.ratio = 1. + self.set_normalize_params(self.get_normalize_params()) + + def brush_choose(self): + self.brush_size = self.BrushChoose.currentIndex() * 2 + 1 + if self.loaded: + self.layer.setDrawKernel(kernel_size=self.brush_size) + self.update_layer() + + def clear_all(self): + self.prev_selected = 0 + self.selected = 0 + if self.restore and "upsample" in self.restore: + self.layerz = 0 * np.ones((self.Lyr, self.Lxr, 4), np.uint8) + self.cellpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16) + self.outpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16) + self.cellpix_resize = self.cellpix.copy() + self.outpix_resize = self.outpix.copy() + self.cellpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16) + self.outpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16) + else: + self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) + self.cellpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16) + self.outpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16) + + self.cellcolors = np.array([255, 255, 255])[np.newaxis, :] + self.ncells = 0 + self.toggle_removals() + self.update_scale() + self.update_layer() + + def select_cell(self, idx): + self.prev_selected = self.selected + self.selected = idx + if self.selected > 0: + z = self.currentZ + self.layerz[self.cellpix[z] == idx] = np.array( + [255, 255, 255, self.opacity]) + self.update_layer() + + def select_cell_multi(self, idx): + if idx > 0: + z = self.currentZ + self.layerz[self.cellpix[z] == idx] = np.array( + [255, 255, 255, self.opacity]) + self.update_layer() + + def unselect_cell(self): + if self.selected > 0: + idx = self.selected + if idx < self.ncells + 1: + z = self.currentZ + self.layerz[self.cellpix[z] == idx] = np.append( + self.cellcolors[idx], self.opacity) + if self.outlinesOn: + self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype( + np.uint8) + #[0,0,0,self.opacity]) + self.update_layer() + self.selected = 0 + + def unselect_cell_multi(self, idx): + z = self.currentZ + self.layerz[self.cellpix[z] == idx] = np.append(self.cellcolors[idx], + self.opacity) + if self.outlinesOn: + self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype( + np.uint8) + # [0,0,0,self.opacity]) + self.update_layer() + + def remove_cell(self, idx): + if isinstance(idx, (int, np.integer)): + idx = [idx] + # because the function remove_single_cell updates the state of the cellpix and outpix arrays + # by reindexing cells to avoid gaps in the indices, we need to remove the cells in reverse order + # so that the indices are correct + idx.sort(reverse=True) + for i in idx: + self.remove_single_cell(i) + self.ncells -= len(idx) # _save_sets uses ncells + + if self.ncells == 0: + self.ClearButton.setEnabled(False) + if self.NZ == 1: + io._save_sets_with_check(self) + + self.update_layer() + + def remove_single_cell(self, idx): + # remove from manual array + 
self.selected = 0 + if self.NZ > 1: + zextent = ((self.cellpix == idx).sum(axis=(1, 2)) > 0).nonzero()[0] + else: + zextent = [0] + for z in zextent: + cp = self.cellpix[z] == idx + op = self.outpix[z] == idx + # remove from self.cellpix and self.outpix + self.cellpix[z, cp] = 0 + self.outpix[z, op] = 0 + if z == self.currentZ: + # remove from mask layer + self.layerz[cp] = np.array([0, 0, 0, 0]) + + # reduce other pixels by -1 + self.cellpix[self.cellpix > idx] -= 1 + self.outpix[self.outpix > idx] -= 1 + + if self.NZ == 1: + self.removed_cell = [ + self.ismanual[idx - 1], self.cellcolors[idx], + np.nonzero(cp), + np.nonzero(op) + ] + self.redo.setEnabled(True) + ar, ac = self.removed_cell[2] + d = datetime.datetime.now() + self.track_changes.append( + [d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]]) + # remove cell from lists + self.ismanual = np.delete(self.ismanual, idx - 1) + self.cellcolors = np.delete(self.cellcolors, [idx], axis=0) + del self.zdraw[idx - 1] + print("GUI_INFO: removed cell %d" % (idx - 1)) + + def remove_region_cells(self): + if self.removing_cells_list: + for idx in self.removing_cells_list: + self.unselect_cell_multi(idx) + self.removing_cells_list.clear() + self.disable_buttons_removeROIs() + self.removing_region = True + + self.clear_multi_selected_cells() + + # make roi region here in center of view, making ROI half the size of the view + roi_width = self.p0.viewRect().width() / 2 + x_loc = self.p0.viewRect().x() + (roi_width / 2) + roi_height = self.p0.viewRect().height() / 2 + y_loc = self.p0.viewRect().y() + (roi_height / 2) + + pos = [x_loc, y_loc] + roi = pg.RectROI(pos, [roi_width, roi_height], pen=pg.mkPen("y", width=2), + removable=True) + roi.sigRemoveRequested.connect(self.remove_roi) + roi.sigRegionChangeFinished.connect(self.roi_changed) + self.p0.addItem(roi) + self.remove_roi_obj = roi + self.roi_changed(roi) + + def delete_multiple_cells(self): + self.unselect_cell() + self.disable_buttons_removeROIs() + self.DoneDeleteMultipleROIButton.setEnabled(True) + self.MakeDeletionRegionButton.setEnabled(True) + self.CancelDeleteMultipleROIButton.setEnabled(True) + self.deleting_multiple = True + + def done_remove_multiple_cells(self): + self.deleting_multiple = False + self.removing_region = False + self.DoneDeleteMultipleROIButton.setEnabled(False) + self.MakeDeletionRegionButton.setEnabled(False) + self.CancelDeleteMultipleROIButton.setEnabled(False) + + if self.removing_cells_list: + self.removing_cells_list = list(set(self.removing_cells_list)) + display_remove_list = [i - 1 for i in self.removing_cells_list] + print(f"GUI_INFO: removing cells: {display_remove_list}") + self.remove_cell(self.removing_cells_list) + self.removing_cells_list.clear() + self.unselect_cell() + self.enable_buttons() + + if self.remove_roi_obj is not None: + self.remove_roi(self.remove_roi_obj) + + def merge_cells(self, idx): + self.prev_selected = self.selected + self.selected = idx + if self.selected != self.prev_selected: + for z in range(self.NZ): + ar0, ac0 = np.nonzero(self.cellpix[z] == self.prev_selected) + ar1, ac1 = np.nonzero(self.cellpix[z] == self.selected) + touching = np.logical_and((ar0[:, np.newaxis] - ar1) < 3, + (ac0[:, np.newaxis] - ac1) < 3).sum() + ar = np.hstack((ar0, ar1)) + ac = np.hstack((ac0, ac1)) + vr0, vc0 = np.nonzero(self.outpix[z] == self.prev_selected) + vr1, vc1 = np.nonzero(self.outpix[z] == self.selected) + self.outpix[z, vr0, vc0] = 0 + self.outpix[z, vr1, vc1] = 0 + if touching > 0: + mask = np.zeros((np.ptp(ar) + 4, 
np.ptp(ac) + 4), np.uint8) + mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1 + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0].squeeze().T + vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2 + + else: + vr = np.hstack((vr0, vr1)) + vc = np.hstack((vc0, vc1)) + color = self.cellcolors[self.prev_selected] + self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected) + self.remove_cell(self.selected) + print("GUI_INFO: merged two cells") + self.update_layer() + io._save_sets_with_check(self) + self.undo.setEnabled(False) + self.redo.setEnabled(False) + + def undo_remove_cell(self): + if len(self.removed_cell) > 0: + z = 0 + ar, ac = self.removed_cell[2] + vr, vc = self.removed_cell[3] + color = self.removed_cell[1] + self.draw_mask(z, ar, ac, vr, vc, color) + self.toggle_mask_ops() + self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], axis=0) + self.ncells += 1 + self.ismanual = np.append(self.ismanual, self.removed_cell[0]) + self.zdraw.append([]) + print(">>> added back removed cell") + self.update_layer() + io._save_sets_with_check(self) + self.removed_cell = [] + self.redo.setEnabled(False) + + def remove_stroke(self, delete_points=True, stroke_ind=-1): + stroke = np.array(self.strokes[stroke_ind]) + cZ = self.currentZ + inZ = stroke[0, 0] == cZ + if inZ: + outpix = self.outpix[cZ, stroke[:, 1], stroke[:, 2]] > 0 + self.layerz[stroke[~outpix, 1], stroke[~outpix, 2]] = np.array([0, 0, 0, 0]) + cellpix = self.cellpix[cZ, stroke[:, 1], stroke[:, 2]] + ccol = self.cellcolors.copy() + if self.selected > 0: + ccol[self.selected] = np.array([255, 255, 255]) + col2mask = ccol[cellpix] + if self.masksOn: + col2mask = np.concatenate( + (col2mask, self.opacity * (cellpix[:, np.newaxis] > 0)), axis=-1) + else: + col2mask = np.concatenate((col2mask, 0 * (cellpix[:, np.newaxis] > 0)), + axis=-1) + self.layerz[stroke[:, 1], stroke[:, 2], :] = col2mask + if self.outlinesOn: + self.layerz[stroke[outpix, 1], stroke[outpix, + 2]] = np.array(self.outcolor) + if delete_points: + # self.current_point_set = self.current_point_set[:-1*(stroke[:,-1]==1).sum()] + del self.current_point_set[stroke_ind] + self.update_layer() + + del self.strokes[stroke_ind] + + def plot_clicked(self, event): + if event.button()==QtCore.Qt.LeftButton \ + and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\ + and not self.removing_region: + if event.double(): + try: + self.p0.setYRange(0, self.Ly + self.pr) + except: + self.p0.setYRange(0, self.Ly) + self.p0.setXRange(0, self.Lx) + + def cancel_remove_multiple(self): + self.clear_multi_selected_cells() + self.done_remove_multiple_cells() + + def clear_multi_selected_cells(self): + # unselect all previously selected cells: + for idx in self.removing_cells_list: + self.unselect_cell_multi(idx) + self.removing_cells_list.clear() + + def add_roi(self, roi): + self.p0.addItem(roi) + self.remove_roi_obj = roi + + def remove_roi(self, roi): + self.clear_multi_selected_cells() + assert roi == self.remove_roi_obj + self.remove_roi_obj = None + self.p0.removeItem(roi) + self.removing_region = False + + def roi_changed(self, roi): + # find the overlapping cells and make them selected + pos = roi.pos() + size = roi.size() + x0 = int(pos.x()) + y0 = int(pos.y()) + x1 = int(pos.x() + size.x()) + y1 = int(pos.y() + size.y()) + if x0 < 0: + x0 = 0 + if y0 < 0: + y0 = 0 + if x1 > self.Lx: + x1 = self.Lx + if y1 > self.Ly: + y1 = self.Ly + + # find cells in that region + cell_idxs = 
np.unique(self.cellpix[self.currentZ, y0:y1, x0:x1]) + cell_idxs = np.trim_zeros(cell_idxs) + # deselect cells not in region by deselecting all and then selecting the ones in the region + self.clear_multi_selected_cells() + + for idx in cell_idxs: + self.select_cell_multi(idx) + self.removing_cells_list.append(idx) + + self.update_layer() + + def mouse_moved(self, pos): + items = self.win.scene().items(pos) + + def color_choose(self): + self.color = self.RGBDropDown.currentIndex() + self.view = 0 + self.ViewDropDown.setCurrentIndex(self.view) + self.update_plot() + + def update_plot(self): + self.view = self.ViewDropDown.currentIndex() + self.Ly, self.Lx, _ = self.stack[self.currentZ].shape + + if self.restore and "upsample" in self.restore: + if self.view != 0: + if self.view == 3: + self.resize = True + elif len(self.flows[0]) > 0 and self.flows[0].shape[1] == self.Lyr: + self.resize = True + else: + self.resize = False + else: + self.resize = False + self.draw_layer() + self.update_scale() + self.update_layer() + + if self.view == 0 or self.view == self.ViewDropDown.count() - 1: + image = self.stack[ + self.currentZ] if self.view == 0 else self.stack_filtered[self.currentZ] + if self.nchan == 1: + # show single channel + image = image[..., 0] + if self.color == 0: + self.img.setImage(image, autoLevels=False, lut=None) + if self.nchan > 1: + levels = np.array([ + self.saturation[0][self.currentZ], + self.saturation[1][self.currentZ], + self.saturation[2][self.currentZ] + ]) + self.img.setLevels(levels) + else: + self.img.setLevels(self.saturation[0][self.currentZ]) + elif self.color > 0 and self.color < 4: + if self.nchan > 1: + image = image[:, :, self.color - 1] + self.img.setImage(image, autoLevels=False, lut=self.cmap[self.color]) + if self.nchan > 1: + self.img.setLevels(self.saturation[self.color - 1][self.currentZ]) + else: + self.img.setLevels(self.saturation[0][self.currentZ]) + elif self.color == 4: + if self.nchan > 1: + image = image.mean(axis=-1) + self.img.setImage(image, autoLevels=False, lut=None) + self.img.setLevels(self.saturation[0][self.currentZ]) + elif self.color == 5: + if self.nchan > 1: + image = image.mean(axis=-1) + self.img.setImage(image, autoLevels=False, lut=self.cmap[0]) + self.img.setLevels(self.saturation[0][self.currentZ]) + else: + image = np.zeros((self.Ly, self.Lx), np.uint8) + if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0: + image = self.flows[self.view - 1][self.currentZ] + if self.view > 1: + self.img.setImage(image, autoLevels=False, lut=self.bwr) + else: + self.img.setImage(image, autoLevels=False, lut=None) + self.img.setLevels([0.0, 255.0]) + + for r in range(3): + self.sliders[r].setValue([ + self.saturation[r][self.currentZ][0], + self.saturation[r][self.currentZ][1] + ]) + self.win.show() + self.show() + + def update_layer(self): + if self.masksOn or self.outlinesOn: + #self.draw_layer() + self.layer.setImage(self.layerz, autoLevels=False) + self.update_roi_count() + self.win.show() + self.show() + + def update_roi_count(self): + self.roi_count.setText(f"{self.ncells} ROIs") + + def add_set(self): + if len(self.current_point_set) > 0: + while len(self.strokes) > 0: + self.remove_stroke(delete_points=False) + if len(self.current_point_set[0]) > 8: + color = self.colormap[self.ncells, :3] + median = self.add_mask(points=self.current_point_set, color=color) + if median is not None: + self.removed_cell = [] + self.toggle_mask_ops() + self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], + axis=0) + 
self.ncells += 1 + self.ismanual = np.append(self.ismanual, True) + if self.NZ == 1: + # only save after each cell if single image + io._save_sets_with_check(self) + else: + print("GUI_ERROR: cell too small, not drawn") + self.current_stroke = [] + self.strokes = [] + self.current_point_set = [] + self.update_layer() + + def add_mask(self, points=None, color=(100, 200, 50), dense=True): + # points is list of strokes + points_all = np.concatenate(points, axis=0) + + # loop over z values + median = [] + zdraw = np.unique(points_all[:, 0]) + z = 0 + ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros( + 0, "int"), np.zeros(0, "int") + for stroke in points: + stroke = np.concatenate(stroke, axis=0).reshape(-1, 4) + vr = stroke[:, 1] + vc = stroke[:, 2] + # get points inside drawn points + mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8) + pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2), + axis=-1)[:, np.newaxis, :] + mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) + ar, ac = np.nonzero(mask) + ar, ac = ar + vr.min() - 2, ac + vc.min() - 2 + # get dense outline + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0][:,0].T + vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2 + # concatenate all points + ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac)))) + # if these pixels are overlapping with another cell, reassign them + ioverlap = self.cellpix[z][ar, ac] > 0 + if (~ioverlap).sum() < 10: + print("GUI_ERROR: cell < 10 pixels without overlaps, not drawn") + return None + elif ioverlap.sum() > 0: + ar, ac = ar[~ioverlap], ac[~ioverlap] + # compute outline of new mask + mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8) + mask[ar - vr.min() + 2, ac - vc.min() + 2] = 1 + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0][:,0].T + vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2 + ars = np.concatenate((ars, ar), axis=0) + acs = np.concatenate((acs, ac), axis=0) + vrs = np.concatenate((vrs, vr), axis=0) + vcs = np.concatenate((vcs, vc), axis=0) + + self.draw_mask(z, ars, acs, vrs, vcs, color) + median.append(np.array([np.median(ars), np.median(acs)])) + + self.zdraw.append(zdraw) + d = datetime.datetime.now() + self.track_changes.append( + [d.strftime("%m/%d/%Y, %H:%M:%S"), "added mask", [ar, ac]]) + return median + + def draw_mask(self, z, ar, ac, vr, vc, color, idx=None): + """ draw single mask using outlines and area """ + if idx is None: + idx = self.ncells + 1 + self.cellpix[z, vr, vc] = idx + self.cellpix[z, ar, ac] = idx + self.outpix[z, vr, vc] = idx + if self.restore and "upsample" in self.restore: + if self.resize: + self.cellpix_resize[z, vr, vc] = idx + self.cellpix_resize[z, ar, ac] = idx + self.outpix_resize[z, vr, vc] = idx + self.cellpix_orig[z, (vr / self.ratio).astype(int), + (vc / self.ratio).astype(int)] = idx + self.cellpix_orig[z, (ar / self.ratio).astype(int), + (ac / self.ratio).astype(int)] = idx + self.outpix_orig[z, (vr / self.ratio).astype(int), + (vc / self.ratio).astype(int)] = idx + else: + self.cellpix_orig[z, vr, vc] = idx + self.cellpix_orig[z, ar, ac] = idx + self.outpix_orig[z, vr, vc] = idx + + # get upsampled mask + vrr = (vr.copy() * self.ratio).astype(int) + vcr = (vc.copy() * self.ratio).astype(int) + mask = np.zeros((np.ptp(vrr) + 4, np.ptp(vcr) + 4), np.uint8) + pts = np.stack((vcr - vcr.min() + 2, vrr - vrr.min() + 2), + axis=-1)[:, np.newaxis, :] + mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) + arr, 
acr = np.nonzero(mask) + arr, acr = arr + vrr.min() - 2, acr + vcr.min() - 2 + # get dense outline + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0].squeeze().T + vrr, vcr = pvr + vrr.min() - 2, pvc + vcr.min() - 2 + # concatenate all points + arr, acr = np.hstack((np.vstack((vrr, vcr)), np.vstack((arr, acr)))) + self.cellpix_resize[z, vrr, vcr] = idx + self.cellpix_resize[z, arr, acr] = idx + self.outpix_resize[z, vrr, vcr] = idx + + if z == self.currentZ: + self.layerz[ar, ac, :3] = color + if self.masksOn: + self.layerz[ar, ac, -1] = self.opacity + if self.outlinesOn: + self.layerz[vr, vc] = np.array(self.outcolor) + + def compute_scale(self): + self.diameter = float(self.Diameter.text()) + self.pr = int(float(self.Diameter.text())) + self.radii_padding = int(self.pr * 1.25) + self.radii = np.zeros((self.Ly + self.radii_padding, self.Lx, 4), np.uint8) + yy, xx = disk([self.Ly + self.radii_padding / 2 - 1, self.pr / 2 + 1], + self.pr / 2, self.Ly + self.radii_padding, self.Lx) + # rgb(150,50,150) + self.radii[yy, xx, 0] = 150 + self.radii[yy, xx, 1] = 50 + self.radii[yy, xx, 2] = 150 + self.radii[yy, xx, 3] = 255 + self.p0.setYRange(0, self.Ly + self.radii_padding) + self.p0.setXRange(0, self.Lx) + + def update_scale(self): + self.compute_scale() + self.scale.setImage(self.radii, autoLevels=False) + self.scale.setLevels([0.0, 255.0]) + self.win.show() + self.show() + + def redraw_masks(self, masks=True, outlines=True, draw=True): + self.draw_layer() + + def draw_masks(self): + self.draw_layer() + + def draw_layer(self): + if self.resize: + self.Ly, self.Lx = self.Lyr, self.Lxr + else: + self.Ly, self.Lx = self.Ly0, self.Lx0 + + if self.masksOn or self.outlinesOn: + if self.restore and "upsample" in self.restore: + if self.resize: + self.cellpix = self.cellpix_resize.copy() + self.outpix = self.outpix_resize.copy() + else: + self.cellpix = self.cellpix_orig.copy() + self.outpix = self.outpix_orig.copy() + + #print(self.cellpix.shape, self.outpix.shape, self.cellpix.max(), self.outpix.max()) + self.layerz = np.zeros((self.Ly, self.Lx, 4), np.uint8) + if self.masksOn: + self.layerz[..., :3] = self.cellcolors[self.cellpix[self.currentZ], :] + self.layerz[..., 3] = self.opacity * (self.cellpix[self.currentZ] + > 0).astype(np.uint8) + if self.selected > 0: + self.layerz[self.cellpix[self.currentZ] == self.selected] = np.array( + [255, 255, 255, self.opacity]) + cZ = self.currentZ + stroke_z = np.array([s[0][0] for s in self.strokes]) + inZ = np.nonzero(stroke_z == cZ)[0] + if len(inZ) > 0: + for i in inZ: + stroke = np.array(self.strokes[i]) + self.layerz[stroke[:, 1], stroke[:, + 2]] = np.array([255, 0, 255, 100]) + else: + self.layerz[..., 3] = 0 + + if self.outlinesOn: + self.layerz[self.outpix[self.currentZ] > 0] = np.array( + self.outcolor).astype(np.uint8) + + def set_restore_button(self): + keys = self.denoise_text + for i, key in enumerate(keys): + if key != "none" and (self.restore and key in self.restore): + self.DenoiseButtons[i].setStyleSheet(self.stylePressed) + elif key == "none" and self.restore is None: + self.DenoiseButtons[i].setStyleSheet(self.stylePressed) + else: + if self.DenoiseButtons[i].isEnabled(): + self.DenoiseButtons[i].setStyleSheet(self.styleUnpressed) + + def set_normalize_params(self, normalize_params): + from cellpose.models import normalize_default + if self.restore != "filter": + keys = list(normalize_params.keys()).copy() + for key in keys: + if key != "percentile": + normalize_params[key] = 
normalize_default[key] + normalize_params = {**normalize_default, **normalize_params} + percentile = self.check_percentile_params(normalize_params["percentile"]) + out = self.check_filter_params(normalize_params["sharpen_radius"], + normalize_params["smooth_radius"], + normalize_params["tile_norm_blocksize"], + normalize_params["tile_norm_smooth3D"], + normalize_params["norm3D"], + normalize_params["invert"]) + + def check_percentile_params(self, percentile): + # check normalization params + if percentile is not None and not (percentile[0] >= 0 and percentile[1] > 0 and + percentile[0] < 100 and percentile[1] <= 100 + and percentile[1] > percentile[0]): + print( + "GUI_ERROR: percentiles need be between 0 and 100, and upper > lower, using defaults" + ) + self.norm_edits[0].setText("1.") + self.norm_edits[1].setText("99.") + percentile = [1., 99.] + elif percentile is None: + percentile = [1., 99.] + self.norm_edits[0].setText(str(percentile[0])) + self.norm_edits[1].setText(str(percentile[1])) + return percentile + + def check_filter_params(self, sharpen, smooth, tile_norm, smooth3D, norm3D, invert): + tile_norm = 0 if tile_norm < 0 else tile_norm + sharpen = 0 if sharpen < 0 else sharpen + smooth = 0 if smooth < 0 else smooth + smooth3D = 0 if smooth3D < 0 else smooth3D + norm3D = bool(norm3D) + invert = bool(invert) + if tile_norm > self.Ly and tile_norm > self.Lx: + print( + "GUI_ERROR: tile size (tile_norm) bigger than both image dimensions, disabling" + ) + tile_norm = 0 + self.filt_edits[0].setText(str(sharpen)) + self.filt_edits[1].setText(str(smooth)) + self.filt_edits[2].setText(str(tile_norm)) + self.filt_edits[3].setText(str(smooth3D)) + self.norm3D_cb.setChecked(norm3D) + self.invert_cb.setChecked(invert) + return sharpen, smooth, tile_norm, smooth3D, norm3D, invert + + def get_normalize_params(self): + percentile = [ + float(self.norm_edits[0].text()), + float(self.norm_edits[1].text()) + ] + self.check_percentile_params(percentile) + normalize_params = {"percentile": percentile} + norm3D = self.norm3D_cb.isChecked() + normalize_params["norm3D"] = norm3D + if self.restore == "filter": + sharpen = float(self.filt_edits[0].text()) + smooth = float(self.filt_edits[1].text()) + tile_norm = float(self.filt_edits[2].text()) + smooth3D = float(self.filt_edits[3].text()) + invert = self.invert_cb.isChecked() + out = self.check_filter_params(sharpen, smooth, tile_norm, smooth3D, norm3D, + invert) + sharpen, smooth, tile_norm, smooth3D, norm3D, invert = out + normalize_params["sharpen_radius"] = sharpen + normalize_params["smooth_radius"] = smooth + normalize_params["tile_norm_blocksize"] = tile_norm + normalize_params["tile_norm_smooth3D"] = smooth3D + normalize_params["invert"] = invert + + from cellpose.models import normalize_default + normalize_params = {**normalize_default, **normalize_params} + + return normalize_params + + def compute_saturation(self, return_img=False): + norm = self.get_normalize_params() + print(norm) + sharpen, smooth = norm["sharpen_radius"], norm["smooth_radius"] + percentile = norm["percentile"] + tile_norm = norm["tile_norm_blocksize"] + invert = norm["invert"] + norm3D = norm["norm3D"] + smooth3D = norm["tile_norm_smooth3D"] + tile_norm = norm["tile_norm_blocksize"] + + # if grayscale, use gray img + channels = self.get_channels() + if channels[0] == 0: + img_norm = self.stack.mean(axis=-1, keepdims=True) + elif sharpen > 0 or smooth > 0 or tile_norm > 0: + img_norm = self.stack.copy() + else: + img_norm = self.stack + + if sharpen > 0 or smooth > 0 or 
tile_norm > 0: + self.clear_restore() + self.restore = "filter" + print( + "GUI_INFO: computing filtered image because sharpen > 0 or tile_norm > 0" + ) + print( + "GUI_WARNING: will use memory to create filtered image -- make sure to have RAM for this" + ) + img_norm = self.stack.copy() + if sharpen > 0 or smooth > 0: + img_norm = smooth_sharpen_img(self.stack, sharpen_radius=sharpen, + smooth_radius=smooth) + + if tile_norm > 0: + img_norm = normalize99_tile(img_norm, blocksize=tile_norm, + lower=percentile[0], upper=percentile[1], + smooth3D=smooth3D, norm3D=norm3D) + # convert to 0->255 + img_norm_min = img_norm.min() + img_norm_max = img_norm.max() + for c in range(img_norm.shape[-1]): + if np.ptp(img_norm[..., c]) > 1e-3: + img_norm[..., c] -= img_norm_min + img_norm[..., c] /= (img_norm_max - img_norm_min) + img_norm *= 255 + self.stack_filtered = img_norm + self.ViewDropDown.model().item(self.ViewDropDown.count() - + 1).setEnabled(True) + self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1) + elif invert: + img_norm = self.stack.copy() + else: + img_norm = self.stack if self.restore is None or self.restore == "filter" else self.stack_filtered + + self.saturation = [] + for c in range(img_norm.shape[-1]): + self.saturation.append([]) + if np.ptp(img_norm[..., c]) > 1e-3: + if norm3D: + x01 = np.percentile(img_norm[..., c], percentile[0]) + x99 = np.percentile(img_norm[..., c], percentile[1]) + if invert: + x01i = 255. - x99 + x99i = 255. - x01 + x01, x99 = x01i, x99i + for n in range(self.NZ): + self.saturation[-1].append([x01, x99]) + else: + for z in range(self.NZ): + if self.NZ > 1: + x01 = np.percentile(img_norm[z, :, :, c], percentile[0]) + x99 = np.percentile(img_norm[z, :, :, c], percentile[1]) + else: + x01 = np.percentile(img_norm[..., c], percentile[0]) + x99 = np.percentile(img_norm[..., c], percentile[1]) + if invert: + x01i = 255. - x99 + x99i = 255. - x01 + x01, x99 = x01i, x99i + self.saturation[-1].append([x01, x99]) + else: + for n in range(self.NZ): + self.saturation[-1].append([0, 255.]) + # if only 2 restore channels, add blue + if len(self.saturation) < 3: + for i in range(3 - len(self.saturation)): + self.saturation.append([]) + for n in range(self.NZ): + self.saturation[-1].append([0, 255.]) + print(self.saturation[2][self.currentZ]) + + if invert: + img_norm = 255. 
- img_norm + self.stack_filtered = img_norm + self.ViewDropDown.model().item(self.ViewDropDown.count() - + 1).setEnabled(True) + self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1) + + if img_norm.shape[-1] == 1: + self.saturation.append(self.saturation[0]) + self.saturation.append(self.saturation[0]) + + self.autobtn.setChecked(True) + self.update_plot() + + def chanchoose(self, image): + if image.ndim > 2 and self.nchan > 1: + if self.ChannelChoose[0].currentIndex() == 0: + return image.mean(axis=-1, keepdims=True) + else: + chanid = [self.ChannelChoose[0].currentIndex() - 1] + if self.ChannelChoose[1].currentIndex() > 0: + chanid.append(self.ChannelChoose[1].currentIndex() - 1) + return image[:, :, chanid] + else: + return image + + def get_model_path(self, custom=False): + if custom: + self.current_model = self.ModelChooseC.currentText() + self.current_model_path = os.fspath( + models.MODEL_DIR.joinpath(self.current_model)) + else: + self.current_model = self.net_names[max( + 0, + self.ModelChooseB.currentIndex() - 1)] + self.current_model_path = models.model_path(self.current_model) + + def initialize_model(self, model_name=None, custom=False): + if model_name == "dataset-specific models": + raise ValueError("need to specify model (use dropdown)") + elif model_name is None or custom: + self.get_model_path(custom=custom) + if not os.path.exists(self.current_model_path): + raise ValueError("need to specify model (use dropdown)") + + if model_name is None or not isinstance(model_name, str): + self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), + pretrained_model=self.current_model_path) + else: + self.current_model = model_name + if self.current_model == "cyto" or self.current_model == "nuclei": + self.current_model_path = models.model_path(self.current_model, 0) + else: + self.current_model_path = os.fspath( + models.MODEL_DIR.joinpath(self.current_model)) + + if self.current_model != "cyto3": + diam_mean = 17. if self.current_model == "nuclei" else 30. 
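+                # 17 px / 30 px are the mean object diameters the built-in nuclei /
+                # cytoplasm models were trained at; passed as diam_mean to CellposeModel below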
+ self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), + diam_mean=diam_mean, + model_type=self.current_model) + else: + self.model = models.Cellpose(gpu=self.useGPU.isChecked(), + model_type=self.current_model) + + def add_model(self): + io._add_model(self) + return + + def remove_model(self): + io._remove_model(self) + return + + def new_model(self): + if self.NZ != 1: + print("ERROR: cannot train model on 3D data") + return + + # train model + image_names = self.get_files()[0] + self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set( + image_names) + TW = guiparts.TrainWindow(self, models.MODEL_NAMES) + train = TW.exec_() + if train: + self.logger.info( + f"training with {[os.path.split(f)[1] for f in self.train_files]}") + self.train_model(restore=restore, normalize_params=normalize_params) + else: + print("GUI_INFO: training cancelled") + + def train_model(self, restore=None, normalize_params=None): + from cellpose.models import normalize_default + if normalize_params is None: + normalize_params = copy.deepcopy(normalize_default) + if self.training_params["model_index"] < len(models.MODEL_NAMES): + model_type = models.MODEL_NAMES[self.training_params["model_index"]] + self.logger.info(f"training new model starting at model {model_type}") + else: + model_type = None + self.logger.info(f"training new model starting from scratch") + self.current_model = model_type + self.channels = self.training_params["channels"] + + self.logger.info( + f"training with chan = {self.ChannelChoose[0].currentText()}, chan2 = {self.ChannelChoose[1].currentText()}" + ) + + self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), + model_type=model_type) + self.SizeButton.setEnabled(False) + save_path = os.path.dirname(self.filename) + + print("GUI_INFO: name of new model: " + self.training_params["model_name"]) + print(f"GUI_INFO: SGD activated: {self.training_params['SGD']}") + self.new_model_path, train_losses = train.train_seg( + self.model.net, train_data=self.train_data, train_labels=self.train_labels, + channels=self.channels, normalize=normalize_params, min_train_masks=0, + save_path=save_path, nimg_per_epoch=max(8, len(self.train_data)), + learning_rate=self.training_params["learning_rate"], + weight_decay=self.training_params["weight_decay"], + n_epochs=self.training_params["n_epochs"], + SGD=self.training_params["SGD"], + model_name=self.training_params["model_name"])[:2] + # save train losses + np.save(str(self.new_model_path) + "_train_losses.npy", train_losses) + # run model on next image + io._add_model(self, self.new_model_path) + diam_labels = self.model.net.diam_labels.item() #.copy() + self.new_model_ind = len(self.model_strings) + self.autorun = True + channels = self.channels.copy() + self.clear_all() + # keep same channels + self.ChannelChoose[0].setCurrentIndex(channels[0]) + self.ChannelChoose[1].setCurrentIndex(channels[1]) + self.diameter = diam_labels + self.Diameter.setText("%0.2f" % self.diameter) + self.logger.info(f">>>> diameter set to diam_labels ( = {diam_labels: 0.3f} )") + self.restore = restore + self.set_normalize_params(normalize_params) + self.get_next_image(load_seg=False) + + self.compute_segmentation(custom=True) + self.logger.info( + f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!" 
+ ) + + def compute_restore(self): + if self.restore: + self.logger.info(f"running image restoration {self.restore}") + if self.restore != "filter": + rstr = self.restore.split("_") + model_type = rstr[0] + if len(rstr) > 1: + dset = rstr[1] + if dset == "cyto3": + self.DenoiseChoose.setCurrentIndex(0) + else: + self.DenoiseChoose.setCurrentIndex(1) + if "upsample" in self.restore: + i = self.DenoiseChoose.currentIndex() + diam_up = 30. if i==0 or i==1 else 17. + print(diam_up, self.ratio) + self.Diameter.setText(str(diam_up / self.ratio)) + self.compute_denoise_model(model_type=model_type) + else: + self.compute_saturation() + + def get_thresholds(self): + try: + flow_threshold = float(self.flow_threshold.text()) + cellprob_threshold = float(self.cellprob_threshold.text()) + if flow_threshold == 0.0 or self.NZ > 1: + flow_threshold = None + return flow_threshold, cellprob_threshold + except Exception as e: + print( + "flow threshold or cellprob threshold not a valid number, setting to defaults" + ) + self.flow_threshold.setText("0.4") + self.cellprob_threshold.setText("0.0") + return 0.4, 0.0 + + def compute_cprob(self): + if self.recompute_masks: + flow_threshold, cellprob_threshold = self.get_thresholds() + if flow_threshold is None: + self.logger.info( + "computing masks with cell prob=%0.3f, no flow error threshold" % + (cellprob_threshold)) + else: + self.logger.info( + "computing masks with cell prob=%0.3f, flow error threshold=%0.3f" % + (cellprob_threshold, flow_threshold)) + maski = dynamics.resize_and_compute_masks( + self.flows[4][:-1], self.flows[4][-1], p=self.flows[3].copy(), + cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, + resize=self.cellpix.shape[-2:])[0] + + self.masksOn = True + if not self.OCheckBox.isChecked(): + self.MCheckBox.setChecked(True) + if maski.ndim < 3: + maski = maski[np.newaxis, ...] + self.logger.info("%d cells found" % (len(np.unique(maski)[1:]))) + io._masks_to_gui(self, maski, outlines=None) + self.show() + + def compute_denoise_model(self, model_type=None): + self.progress.setValue(0) + try: + tic = time.time() + nstr = self.DenoiseChoose.currentText() + nstr.replace("-", "") + self.clear_restore() + model_name = model_type + "_" + nstr + print(model_name) + # denoising model + self.denoise_model = denoise.DenoiseModel(gpu=self.useGPU.isChecked(), + model_type=model_name) + self.progress.setValue(10) + diam_up = 30. if "cyto" in model_name else 17. 
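+            # diam_up is the object diameter associated with the chosen restoration model
+            # (30 px for "cyto"-style models, 17 px for nuclei); for "upsample" models it is
+            # the target diameter the image is rescaled to (ratio = diam_up / user diameter)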
+ + # params + channels = self.get_channels() + self.diameter = float(self.Diameter.text()) + normalize_params = self.get_normalize_params() + print("GUI_INFO: channels: ", channels) + print("GUI_INFO: normalize_params: ", normalize_params) + print("GUI_INFO: diameter (before upsampling): ", self.diameter) + + data = self.stack.copy() + print(data.shape) + self.Ly, self.Lx = data.shape[-3:-1] + if "upsample" in model_name: + # get upsampling factor + if self.diameter >= diam_up: + print( + f"GUI_ERROR: cannot upsample, already set to pixel diameter >= {diam_up}" + ) + self.progress.setValue(0) + return + self.ratio = diam_up / self.diameter + print( + "GUI_WARNING: upsampling image, this will also duplicate mask layer and resize it, will use more RAM" + ) + print( + f"GUI_INFO: upsampling image to {diam_up} pixel diameter ({self.ratio:0.2f} times)" + ) + self.Lyr, self.Lxr = int(self.Ly * self.ratio), int(self.Lx * + self.ratio) + self.Ly0, self.Lx0 = self.Ly, self.Lx + # moved resize into eval + #data = resize_image(data, Ly=self.Lyr, Lx=self.Lxr) + #self.diameter = diam_up + #self.Diameter.setText(str(diam_up)) + else: + self.Lyr, self.Lxr = self.Ly, self.Lx + self.Ly0, self.Lx0 = self.Ly, self.Lx + diam_up = self.diameter + + img_norm = self.denoise_model.eval(data, channels=channels, z_axis=0, + channel_axis=3, diameter=self.diameter, + normalize=normalize_params) + print(img_norm.shape) + self.diameter = diam_up + self.Diameter.setText(str(diam_up)) + + if img_norm.ndim == 2: + img_norm = img_norm[:, :, np.newaxis] + if img_norm.ndim == 3: + img_norm = img_norm[np.newaxis, ...] + + self.progress.setValue(100) + self.logger.info(f"{model_name} finished in %0.3f sec" % + (time.time() - tic)) + + # compute saturation + percentile = normalize_params["percentile"] + img_norm_min = img_norm.min() + img_norm_max = img_norm.max() + chan = [0] if channels[0] == 0 else [channels[0] - 1, channels[1] - 1] + self.saturation = [[], [], []] + for c in range(img_norm.shape[-1]): + if np.ptp(img_norm[..., c]) > 1e-3: + img_norm[..., c] -= img_norm_min + img_norm[..., c] /= (img_norm_max - img_norm_min) + for z in range(self.NZ): + x01 = np.percentile(img_norm[z, :, :, c], percentile[0]) * 255. + x99 = np.percentile(img_norm[z, :, :, c], percentile[1]) * 255. + self.saturation[chan[c]].append([x01, x99]) + notchan = np.ones(3, "bool") + notchan[np.array(chan)] = False + notchan = np.nonzero(notchan)[0] + for c in notchan: + for z in range(self.NZ): + self.saturation[c].append([0, 255.]) + + img_norm *= 255. 
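+            # img_norm is now in the 0-255 display range, matching the per-channel
+            # saturation levels computed above (percentiles of the normalized output * 255)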
+ self.autobtn.setChecked(True) + + # assign to denoised channels + self.stack_filtered = np.zeros( + (self.NZ, self.Lyr, self.Lxr, self.stack.shape[-1]), "float32") + for i, c in enumerate(chan[:img_norm.shape[-1]]): + for z in range(self.NZ): + self.stack_filtered[z, :, :, c] = img_norm[z, :, :, i] + + # make upsampled masks + if model_type == "upsample": + self.cellpix_orig = self.cellpix.copy() + self.outpix_orig = self.outpix.copy() + self.cellpix_resize = cv2.resize( + self.cellpix_orig[0], (self.Lxr, self.Lyr), + interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :] + outlines = masks_to_outlines(self.cellpix_resize[0])[np.newaxis, :, :] + self.outpix_resize = outlines * self.cellpix_resize + + self.restore = model_name + + # draw plot + if model_type == "upsample": + self.resize = True + else: + self.resize = False + self.draw_layer() + self.update_layer() + self.update_scale() + # if denoised in grayscale, show in grayscale + if channels[0] == 0: + self.RGBDropDown.setCurrentIndex(4) + + self.ViewDropDown.model().item(self.ViewDropDown.count() - + 1).setEnabled(True) + self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1) + + self.update_plot() + + except Exception as e: + print("ERROR: %s" % e) + + def compute_segmentation(self, custom=False, model_name=None, load_model=True): + self.progress.setValue(0) + try: + tic = time.time() + self.clear_all() + self.flows = [[], [], []] + if load_model: + self.initialize_model(model_name=model_name, custom=custom) + self.progress.setValue(10) + do_3D = self.load_3D + stitch_threshold = float(self.stitch_threshold.text()) if not isinstance( + self.stitch_threshold, float) else self.stitch_threshold + anisotropy = float(self.anisotropy.text()) if not isinstance( + self.anisotropy, float) else self.anisotropy + flow3D_smooth = float(self.flow3D_smooth.text()) if not isinstance( + self.flow3D_smooth, float) else self.flow3D_smooth + min_size = int(self.min_size.text()) if not isinstance( + self.min_size, int) else self.min_size + resample = self.resample.isChecked() if not isinstance( + self.resample, bool) else self.resample + + do_3D = False if stitch_threshold > 0. 
else do_3D + + channels = self.get_channels() + if self.restore is not None and self.restore != "filter": + data = self.stack_filtered.copy().squeeze() + else: + data = self.stack.copy().squeeze() + flow_threshold, cellprob_threshold = self.get_thresholds() + self.diameter = float(self.Diameter.text()) + niter = max(0, int(self.niter.text())) + niter = None if niter == 0 else niter + normalize_params = self.get_normalize_params() + print(normalize_params) + try: + masks, flows = self.model.eval( + data, channels=channels, diameter=self.diameter, + cellprob_threshold=cellprob_threshold, + flow_threshold=flow_threshold, do_3D=do_3D, niter=niter, + normalize=normalize_params, stitch_threshold=stitch_threshold, + anisotropy=anisotropy, resample=resample, flow3D_smooth=flow3D_smooth, + min_size=min_size, + progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2] + except Exception as e: + print("NET ERROR: %s" % e) + self.progress.setValue(0) + return + + self.progress.setValue(75) + + # convert flows to uint8 and resize to original image size + flows_new = [] + flows_new.append(flows[0].copy()) # RGB flow + flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) * + 255).astype("uint8")) # cellprob + if self.load_3D: + if stitch_threshold == 0.: + flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8")) + else: + flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8")) + + if not self.load_3D: + if self.restore and "upsample" in self.restore: + self.Ly, self.Lx = self.Lyr, self.Lxr + + if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx): + self.flows = [] + for j in range(len(flows_new)): + self.flows.append( + resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx, + interpolation=cv2.INTER_NEAREST)) + else: + self.flows = flows_new + else: + if not resample: + self.flows = [] + Lz, Ly, Lx = self.NZ, self.Ly, self.Lx + Lz0, Ly0, Lx0 = flows_new[0].shape[:3] + print("GUI_INFO: resizing flows to original image size") + for j in range(len(flows_new)): + flow0 = flows_new[j] + if Ly0 != Ly: + flow0 = resize_image(flow0, Ly=Ly, Lx=Lx, + no_channels=flow0.ndim==3, + interpolation=cv2.INTER_NEAREST) + if Lz0 != Lz: + flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1), + Ly=Lz, Lx=Lx, + no_channels=flow0.ndim==3, + interpolation=cv2.INTER_NEAREST), 0, 1) + self.flows.append(flow0) + else: + self.flows = flows_new + + # add first axis + if self.NZ == 1: + masks = masks[np.newaxis, ...] + self.flows = [ + self.flows[n][np.newaxis, ...] for n in range(len(self.flows)) + ] + + self.logger.info("%d cells found with model in %0.3f sec" % + (len(np.unique(masks)[1:]), time.time() - tic)) + self.progress.setValue(80) + z = 0 + + io._masks_to_gui(self, masks, outlines=None) + self.masksOn = True + self.MCheckBox.setChecked(True) + self.progress.setValue(100) + if self.restore != "filter" and self.restore is not None: + self.compute_saturation() + if not do_3D and not stitch_threshold > 0: + self.recompute_masks = True + else: + self.recompute_masks = False + except Exception as e: + print("ERROR: %s" % e) diff --git a/cellpose/gui/gui3d.py b/cellpose/gui/gui3d.py new file mode 100644 index 0000000000000000000000000000000000000000..956d896595636806e5759598a3337e21558914ec --- /dev/null +++ b/cellpose/gui/gui3d.py @@ -0,0 +1,692 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 
+""" + +import sys, os, pathlib, warnings, datetime, time + +from qtpy import QtGui, QtCore +from superqt import QRangeSlider +from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox +import pyqtgraph as pg + +import numpy as np +from scipy.stats import mode +import cv2 + +from . import guiparts, menus, io +from .. import models, core, dynamics, version +from ..utils import download_url_to_file, masks_to_outlines, diameters +from ..io import get_image_files, imsave, imread +from ..transforms import resize_image, normalize99 #fixed import +from ..plot import disk +from ..transforms import normalize99_tile, smooth_sharpen_img +from .gui import MainW + +try: + import matplotlib.pyplot as plt + MATPLOTLIB = True +except: + MATPLOTLIB = False + + +def avg3d(C): + """ smooth value of c across nearby points + (c is center of grid directly below point) + b -- a -- b + a -- c -- a + b -- a -- b + """ + Ly, Lx = C.shape + # pad T by 2 + T = np.zeros((Ly + 2, Lx + 2), "float32") + M = np.zeros((Ly, Lx), "float32") + T[1:-1, 1:-1] = C.copy() + y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int), + indexing="ij") + y += 1 + x += 1 + a = 1. / 2 #/(z**2 + 1)**0.5 + b = 1. / (1 + 2**0.5) #(z**2 + 2)**0.5 + c = 1. + M = (b * T[y - 1, x - 1] + a * T[y - 1, x] + b * T[y - 1, x + 1] + a * T[y, x - 1] + + c * T[y, x] + a * T[y, x + 1] + b * T[y + 1, x - 1] + a * T[y + 1, x] + + b * T[y + 1, x + 1]) + M /= 4 * a + 4 * b + c + return M + + +def interpZ(mask, zdraw): + """ find nearby planes and average their values using grid of points + zfill is in ascending order + """ + ifill = np.ones(mask.shape[0], "bool") + zall = np.arange(0, mask.shape[0], 1, int) + ifill[zdraw] = False + zfill = zall[ifill] + zlower = zdraw[np.searchsorted(zdraw, zfill, side="left") - 1] + zupper = zdraw[np.searchsorted(zdraw, zfill, side="right")] + for k, z in enumerate(zfill): + Z = zupper[k] - zlower[k] + zl = (z - zlower[k]) / Z + plower = avg3d(mask[zlower[k]]) * (1 - zl) + pupper = avg3d(mask[zupper[k]]) * zl + mask[z] = (plower + pupper) > 0.33 + #Ml, norml = avg3d(mask[zlower[k]], zl) + #Mu, normu = avg3d(mask[zupper[k]], 1-zl) + #mask[z] = (Ml + Mu) / (norml + normu) > 0.5 + return mask, zfill + + +def run(image=None): + from ..io import logger_setup + logger, log_file = logger_setup() + # Always start by initializing Qt (only once per application) + warnings.filterwarnings("ignore") + app = QApplication(sys.argv) + icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png") + guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png") + style_path = pathlib.Path.home().joinpath(".cellpose", "style_choice.npy") + if not icon_path.is_file(): + cp_dir = pathlib.Path.home().joinpath(".cellpose") + cp_dir.mkdir(exist_ok=True) + print("downloading logo") + download_url_to_file( + "https://www.cellpose.org/static/images/cellpose_transparent.png", + icon_path, progress=True) + if not guip_path.is_file(): + print("downloading help window image") + download_url_to_file("https://www.cellpose.org/static/images/cellpose_gui.png", + guip_path, progress=True) + icon_path = str(icon_path.resolve()) + app_icon = QtGui.QIcon() + app_icon.addFile(icon_path, QtCore.QSize(16, 16)) + app_icon.addFile(icon_path, QtCore.QSize(24, 24)) + app_icon.addFile(icon_path, QtCore.QSize(32, 32)) + app_icon.addFile(icon_path, QtCore.QSize(48, 48)) + app_icon.addFile(icon_path, 
QtCore.QSize(64, 64)) + app_icon.addFile(icon_path, QtCore.QSize(256, 256)) + app.setWindowIcon(app_icon) + app.setStyle("Fusion") + app.setPalette(guiparts.DarkPalette()) + #app.setStyleSheet("QLineEdit { color: yellow }") + + # models.download_model_weights() # does not exist + MainW_3d(image=image, logger=logger) + ret = app.exec_() + sys.exit(ret) + + +class MainW_3d(MainW): + + def __init__(self, image=None, logger=None): + # MainW init + MainW.__init__(self, image=image, logger=logger) + + # add gradZ view + self.ViewDropDown.insertItem(3, "gradZ") + + # turn off single stroke + self.SCheckBox.setChecked(False) + + ### add orthoviews and z-bar + # ortho crosshair lines + self.vLine = pg.InfiniteLine(angle=90, movable=False) + self.hLine = pg.InfiniteLine(angle=0, movable=False) + self.vLineOrtho = [ + pg.InfiniteLine(angle=90, movable=False), + pg.InfiniteLine(angle=90, movable=False) + ] + self.hLineOrtho = [ + pg.InfiniteLine(angle=0, movable=False), + pg.InfiniteLine(angle=0, movable=False) + ] + self.make_orthoviews() + + # z scrollbar underneath + self.scroll = QScrollBar(QtCore.Qt.Horizontal) + self.scroll.setMaximum(10) + self.scroll.valueChanged.connect(self.move_in_Z) + self.lmain.addWidget(self.scroll, 40, 9, 1, 30) + + b = 22 + + label = QLabel("stitch threshold:") + label.setToolTip( + "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)" + ) + label.setFont(self.medfont) + self.segBoxG.addWidget(label, b, 0, 1, 4) + self.stitch_threshold = QLineEdit() + self.stitch_threshold.setText("0.0") + self.stitch_threshold.setFixedWidth(30) + self.stitch_threshold.setFont(self.medfont) + self.stitch_threshold.setToolTip( + "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)" + ) + self.segBoxG.addWidget(self.stitch_threshold, b, 4, 1, 1) + + label = QLabel("flow3D_smooth:") + label.setToolTip( + "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)" + ) + label.setFont(self.medfont) + self.segBoxG.addWidget(label, b, 5, 1, 3) + self.flow3D_smooth = QLineEdit() + self.flow3D_smooth.setText("0.0") + self.flow3D_smooth.setFixedWidth(30) + self.flow3D_smooth.setFont(self.medfont) + self.flow3D_smooth.setToolTip( + "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)" + ) + self.segBoxG.addWidget(self.flow3D_smooth, b, 8, 1, 1) + + b+=1 + label = QLabel("anisotropy:") + label.setToolTip( + "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)" + ) + label.setFont(self.medfont) + self.segBoxG.addWidget(label, b, 0, 1, 4) + self.anisotropy = QLineEdit() + self.anisotropy.setText("1.0") + self.anisotropy.setFixedWidth(30) + self.anisotropy.setFont(self.medfont) + self.anisotropy.setToolTip( + "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. 
set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)" + ) + self.segBoxG.addWidget(self.anisotropy, b, 4, 1, 1) + + self.resample = QCheckBox("resample") + self.resample.setToolTip("reample before creating masks; if diameter > 30 resample will use more CPU+GPU memory (see docs for more details)") + self.resample.setFont(self.medfont) + self.resample.setChecked(True) + self.segBoxG.addWidget(self.resample, b, 5, 1, 4) + + b+=1 + label = QLabel("min_size:") + label.setToolTip( + "all masks less than this size in pixels (volume) will be removed" + ) + label.setFont(self.medfont) + self.segBoxG.addWidget(label, b, 0, 1, 4) + self.min_size = QLineEdit() + self.min_size.setText("15") + self.min_size.setFixedWidth(50) + self.min_size.setFont(self.medfont) + self.min_size.setToolTip( + "all masks less than this size in pixels (volume) will be removed" + ) + self.segBoxG.addWidget(self.min_size, b, 4, 1, 3) + + b += 1 + self.orthobtn = QCheckBox("ortho") + self.orthobtn.setToolTip("activate orthoviews with 3D image") + self.orthobtn.setFont(self.medfont) + self.orthobtn.setChecked(False) + self.l0.addWidget(self.orthobtn, b, 0, 1, 2) + self.orthobtn.toggled.connect(self.toggle_ortho) + + label = QLabel("dz:") + label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + label.setFont(self.medfont) + self.l0.addWidget(label, b, 2, 1, 1) + self.dz = 10 + self.dzedit = QLineEdit() + self.dzedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.dzedit.setText(str(self.dz)) + self.dzedit.returnPressed.connect(self.update_ortho) + self.dzedit.setFixedWidth(40) + self.dzedit.setFont(self.medfont) + self.l0.addWidget(self.dzedit, b, 3, 1, 2) + + label = QLabel("z-aspect:") + label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + label.setFont(self.medfont) + self.l0.addWidget(label, b, 5, 1, 2) + self.zaspect = 1.0 + self.zaspectedit = QLineEdit() + self.zaspectedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.zaspectedit.setText(str(self.zaspect)) + self.zaspectedit.returnPressed.connect(self.update_ortho) + self.zaspectedit.setFixedWidth(40) + self.zaspectedit.setFont(self.medfont) + self.l0.addWidget(self.zaspectedit, b, 7, 1, 2) + + b += 1 + # add z position underneath + self.currentZ = 0 + label = QLabel("Z:") + label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.l0.addWidget(label, b, 5, 1, 2) + self.zpos = QLineEdit() + self.zpos.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.zpos.setText(str(self.currentZ)) + self.zpos.returnPressed.connect(self.update_ztext) + self.zpos.setFixedWidth(40) + self.zpos.setFont(self.medfont) + self.l0.addWidget(self.zpos, b, 7, 1, 2) + + # if called with image, load it + if image is not None: + self.filename = image + io._load_image(self, self.filename, load_3D=True) + + self.load_3D = True + + def add_mask(self, points=None, color=(100, 200, 50), dense=True): + # points is list of strokes + + points_all = np.concatenate(points, axis=0) + + # loop over z values + median = [] + zdraw = np.unique(points_all[:, 0]) + zrange = np.arange(zdraw.min(), zdraw.max() + 1, 1, int) + zmin = zdraw.min() + pix = np.zeros((2, 0), "uint16") + mall = np.zeros((len(zrange), self.Ly, self.Lx), "bool") + k = 0 + for z in zdraw: + ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros( + 0, "int"), np.zeros(0, "int") + for stroke in points: + stroke = np.concatenate(stroke, axis=0).reshape(-1, 4) + iz = stroke[:, 0] == z + vr = stroke[iz, 1] + vc = 
stroke[iz, 2] + if iz.sum() > 0: + # get points inside drawn points + mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8") + pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2), + axis=-1)[:, np.newaxis, :] + mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) + ar, ac = np.nonzero(mask) + ar, ac = ar + vr.min() - 2, ac + vc.min() - 2 + # get dense outline + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0].squeeze().T + vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2 + # concatenate all points + ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac)))) + # if these pixels are overlapping with another cell, reassign them + ioverlap = self.cellpix[z][ar, ac] > 0 + if (~ioverlap).sum() < 8: + print("ERROR: cell too small without overlaps, not drawn") + return None + elif ioverlap.sum() > 0: + ar, ac = ar[~ioverlap], ac[~ioverlap] + # compute outline of new mask + mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8") + mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1 + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0].squeeze().T + vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2 + ars = np.concatenate((ars, ar), axis=0) + acs = np.concatenate((acs, ac), axis=0) + vrs = np.concatenate((vrs, vr), axis=0) + vcs = np.concatenate((vcs, vc), axis=0) + self.draw_mask(z, ars, acs, vrs, vcs, color) + + median.append(np.array([np.median(ars), np.median(acs)])) + mall[z - zmin, ars, acs] = True + pix = np.append(pix, np.vstack((ars, acs)), axis=-1) + + mall = mall[:, pix[0].min():pix[0].max() + 1, + pix[1].min():pix[1].max() + 1].astype("float32") + ymin, xmin = pix[0].min(), pix[1].min() + if len(zdraw) > 1: + mall, zfill = interpZ(mall, zdraw - zmin) + for z in zfill: + mask = mall[z].copy() + ar, ac = np.nonzero(mask) + ioverlap = self.cellpix[z + zmin][ar + ymin, ac + xmin] > 0 + if (~ioverlap).sum() < 5: + print("WARNING: stroke on plane %d not included due to overlaps" % + z) + elif ioverlap.sum() > 0: + mask[ar[ioverlap], ac[ioverlap]] = 0 + ar, ac = ar[~ioverlap], ac[~ioverlap] + # compute outline of mask + outlines = masks_to_outlines(mask) + vr, vc = np.nonzero(outlines) + vr, vc = vr + ymin, vc + xmin + ar, ac = ar + ymin, ac + xmin + self.draw_mask(z + zmin, ar, ac, vr, vc, color) + + self.zdraw.append(zdraw) + + return median + + def move_in_Z(self): + if self.loaded: + self.currentZ = min(self.NZ, max(0, int(self.scroll.value()))) + self.zpos.setText(str(self.currentZ)) + self.update_plot() + self.draw_layer() + self.update_layer() + + def make_orthoviews(self): + self.pOrtho, self.imgOrtho, self.layerOrtho = [], [], [] + for j in range(2): + self.pOrtho.append( + pg.ViewBox(lockAspect=True, name=f"plotOrtho{j}", + border=[100, 100, 100], invertY=True, enableMouse=False)) + self.pOrtho[j].setMenuEnabled(False) + + self.imgOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self)) + self.imgOrtho[j].autoDownsample = False + + self.layerOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self)) + self.layerOrtho[j].setLevels([0., 255.]) + + #self.pOrtho[j].scene().contextMenuItem = self.pOrtho[j] + self.pOrtho[j].addItem(self.imgOrtho[j]) + self.pOrtho[j].addItem(self.layerOrtho[j]) + self.pOrtho[j].addItem(self.vLineOrtho[j], ignoreBounds=False) + self.pOrtho[j].addItem(self.hLineOrtho[j], ignoreBounds=False) + + self.pOrtho[0].linkView(self.pOrtho[0].YAxis, self.p0) + self.pOrtho[1].linkView(self.pOrtho[1].XAxis, self.p0) + + def add_orthoviews(self): + 
self.yortho = self.Ly // 2 + self.xortho = self.Lx // 2 + if self.NZ > 1: + self.update_ortho() + + self.win.addItem(self.pOrtho[0], 0, 1, rowspan=1, colspan=1) + self.win.addItem(self.pOrtho[1], 1, 0, rowspan=1, colspan=1) + + qGraphicsGridLayout = self.win.ci.layout + qGraphicsGridLayout.setColumnStretchFactor(0, 2) + qGraphicsGridLayout.setColumnStretchFactor(1, 1) + qGraphicsGridLayout.setRowStretchFactor(0, 2) + qGraphicsGridLayout.setRowStretchFactor(1, 1) + + #self.p0.linkView(self.p0.YAxis, self.pOrtho[0]) + #self.p0.linkView(self.p0.XAxis, self.pOrtho[1]) + + self.pOrtho[0].setYRange(0, self.Lx) + self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3) + self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3) + self.pOrtho[1].setXRange(0, self.Ly) + #self.pOrtho[0].setLimits(minXRange=self.dz*2+self.dz/3*2) + #self.pOrtho[1].setLimits(minYRange=self.dz*2+self.dz/3*2) + + self.p0.addItem(self.vLine, ignoreBounds=False) + self.p0.addItem(self.hLine, ignoreBounds=False) + self.p0.setYRange(0, self.Lx) + self.p0.setXRange(0, self.Ly) + + self.win.show() + self.show() + + #self.p0.linkView(self.p0.XAxis, self.pOrtho[1]) + + def remove_orthoviews(self): + self.win.removeItem(self.pOrtho[0]) + self.win.removeItem(self.pOrtho[1]) + self.p0.removeItem(self.vLine) + self.p0.removeItem(self.hLine) + self.win.show() + self.show() + + def update_crosshairs(self): + self.yortho = min(self.Ly - 1, max(0, int(self.yortho))) + self.xortho = min(self.Lx - 1, max(0, int(self.xortho))) + self.vLine.setPos(self.xortho) + self.hLine.setPos(self.yortho) + self.vLineOrtho[1].setPos(self.xortho) + self.hLineOrtho[1].setPos(self.zc) + self.vLineOrtho[0].setPos(self.zc) + self.hLineOrtho[0].setPos(self.yortho) + + def update_ortho(self): + if self.NZ > 1 and self.orthobtn.isChecked(): + dzcurrent = self.dz + self.dz = min(100, max(3, int(self.dzedit.text()))) + self.zaspect = max(0.01, min(100., float(self.zaspectedit.text()))) + self.dzedit.setText(str(self.dz)) + self.zaspectedit.setText(str(self.zaspect)) + if self.dz != dzcurrent: + self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3) + self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3) + dztot = min(self.NZ, self.dz * 2) + y = self.yortho + x = self.xortho + z = self.currentZ + if dztot == self.NZ: + zmin, zmax = 0, self.NZ + else: + if z - self.dz < 0: + zmin = 0 + zmax = zmin + self.dz * 2 + elif z + self.dz >= self.NZ: + zmax = self.NZ + zmin = zmax - self.dz * 2 + else: + zmin, zmax = z - self.dz, z + self.dz + self.zc = z - zmin + self.update_crosshairs() + if self.view == 0 or self.view == 4: + for j in range(2): + if j == 0: + if self.view == 0: + image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy() + else: + image = self.stack_filtered[zmin:zmax, :, + x].transpose(1, 0, 2).copy() + else: + image = self.stack[ + zmin:zmax, + y, :].copy() if self.view == 0 else self.stack_filtered[zmin:zmax, + y, :].copy() + if self.nchan == 1: + # show single channel + image = image[..., 0] + if self.color == 0: + self.imgOrtho[j].setImage(image, autoLevels=False, lut=None) + if self.nchan > 1: + levels = np.array([ + self.saturation[0][self.currentZ], + self.saturation[1][self.currentZ], + self.saturation[2][self.currentZ] + ]) + self.imgOrtho[j].setLevels(levels) + else: + self.imgOrtho[j].setLevels( + self.saturation[0][self.currentZ]) + elif self.color > 0 and self.color < 4: + if self.nchan > 1: + image = image[..., self.color - 1] + self.imgOrtho[j].setImage(image, autoLevels=False, + 
lut=self.cmap[self.color]) + if self.nchan > 1: + self.imgOrtho[j].setLevels( + self.saturation[self.color - 1][self.currentZ]) + else: + self.imgOrtho[j].setLevels( + self.saturation[0][self.currentZ]) + elif self.color == 4: + if image.ndim > 2: + image = image.astype("float32").mean(axis=2).astype("uint8") + self.imgOrtho[j].setImage(image, autoLevels=False, lut=None) + self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ]) + elif self.color == 5: + if image.ndim > 2: + image = image.astype("float32").mean(axis=2).astype("uint8") + self.imgOrtho[j].setImage(image, autoLevels=False, + lut=self.cmap[0]) + self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ]) + self.pOrtho[0].setAspectLocked(lock=True, ratio=self.zaspect) + self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect) + + else: + image = np.zeros((10, 10), "uint8") + self.imgOrtho[0].setImage(image, autoLevels=False, lut=None) + self.imgOrtho[0].setLevels([0.0, 255.0]) + self.imgOrtho[1].setImage(image, autoLevels=False, lut=None) + self.imgOrtho[1].setLevels([0.0, 255.0]) + + zrange = zmax - zmin + self.layer_ortho = [ + np.zeros((self.Ly, zrange, 4), "uint8"), + np.zeros((zrange, self.Lx, 4), "uint8") + ] + if self.masksOn: + for j in range(2): + if j == 0: + cp = self.cellpix[zmin:zmax, :, x].T + else: + cp = self.cellpix[zmin:zmax, y] + self.layer_ortho[j][..., :3] = self.cellcolors[cp, :] + self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8") + if self.selected > 0: + self.layer_ortho[j][cp == self.selected] = np.array( + [255, 255, 255, self.opacity]) + + if self.outlinesOn: + for j in range(2): + if j == 0: + op = self.outpix[zmin:zmax, :, x].T + else: + op = self.outpix[zmin:zmax, y] + self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8") + + for j in range(2): + self.layerOrtho[j].setImage(self.layer_ortho[j]) + self.win.show() + self.show() + + def toggle_ortho(self): + if self.orthobtn.isChecked(): + self.add_orthoviews() + else: + self.remove_orthoviews() + + def plot_clicked(self, event): + if event.button()==QtCore.Qt.LeftButton \ + and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\ + and not self.removing_region: + if event.double(): + try: + self.p0.setYRange(0, self.Ly + self.pr) + except: + self.p0.setYRange(0, self.Ly) + self.p0.setXRange(0, self.Lx) + elif self.loaded and not self.in_stroke: + if self.orthobtn.isChecked(): + items = self.win.scene().items(event.scenePos()) + for x in items: + if x == self.p0: + pos = self.p0.mapSceneToView(event.scenePos()) + x = int(pos.x()) + y = int(pos.y()) + if y >= 0 and y < self.Ly and x >= 0 and x < self.Lx: + self.yortho = y + self.xortho = x + self.update_ortho() + + def update_plot(self): + super().update_plot() + if self.NZ > 1 and self.orthobtn.isChecked(): + self.update_ortho() + self.win.show() + self.show() + + def keyPressEvent(self, event): + if self.loaded: + if not (event.modifiers() & + (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier | + QtCore.Qt.AltModifier) or self.in_stroke): + updated = False + if len(self.current_point_set) > 0: + if event.key() == QtCore.Qt.Key_Return: + self.add_set() + if self.NZ > 1: + if event.key() == QtCore.Qt.Key_Left: + self.currentZ = max(0, self.currentZ - 1) + self.scroll.setValue(self.currentZ) + updated = True + elif event.key() == QtCore.Qt.Key_Right: + self.currentZ = min(self.NZ - 1, self.currentZ + 1) + self.scroll.setValue(self.currentZ) + updated = True + else: + nviews = self.ViewDropDown.count() - 1 + nviews += int( + 
self.ViewDropDown.model().item(self.ViewDropDown.count() - + 1).isEnabled()) + if event.key() == QtCore.Qt.Key_X: + self.MCheckBox.toggle() + if event.key() == QtCore.Qt.Key_Z: + self.OCheckBox.toggle() + if event.key() == QtCore.Qt.Key_Left or event.key( + ) == QtCore.Qt.Key_A: + self.currentZ = max(0, self.currentZ - 1) + self.scroll.setValue(self.currentZ) + updated = True + elif event.key() == QtCore.Qt.Key_Right or event.key( + ) == QtCore.Qt.Key_D: + self.currentZ = min(self.NZ - 1, self.currentZ + 1) + self.scroll.setValue(self.currentZ) + updated = True + elif event.key() == QtCore.Qt.Key_PageDown: + self.view = (self.view + 1) % (nviews) + self.ViewDropDown.setCurrentIndex(self.view) + elif event.key() == QtCore.Qt.Key_PageUp: + self.view = (self.view - 1) % (nviews) + self.ViewDropDown.setCurrentIndex(self.view) + + # can change background or stroke size if cell not finished + if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W: + self.color = (self.color - 1) % (6) + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_Down or event.key( + ) == QtCore.Qt.Key_S: + self.color = (self.color + 1) % (6) + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_R: + if self.color != 1: + self.color = 1 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_G: + if self.color != 2: + self.color = 2 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_B: + if self.color != 3: + self.color = 3 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif (event.key() == QtCore.Qt.Key_Comma or + event.key() == QtCore.Qt.Key_Period): + count = self.BrushChoose.count() + gci = self.BrushChoose.currentIndex() + if event.key() == QtCore.Qt.Key_Comma: + gci = max(0, gci - 1) + else: + gci = min(count - 1, gci + 1) + self.BrushChoose.setCurrentIndex(gci) + self.brush_choose() + if not updated: + self.update_plot() + if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal: + self.p0.keyPressEvent(event) + + def update_ztext(self): + zpos = self.currentZ + try: + zpos = int(self.zpos.text()) + except: + print("ERROR: zposition is not a number") + self.currentZ = max(0, min(self.NZ - 1, zpos)) + self.zpos.setText(str(self.currentZ)) + self.scroll.setValue(self.currentZ) diff --git a/cellpose/gui/guihelpwindowtext.html b/cellpose/gui/guihelpwindowtext.html new file mode 100644 index 0000000000000000000000000000000000000000..4ca5d209f1f39905ec60b049e288a3cf9fcff37b --- /dev/null +++ b/cellpose/gui/guihelpwindowtext.html @@ -0,0 +1,151 @@ + +
+<p><b>Main GUI mouse controls:</b></p>
+<p>Overlaps in masks are NOT allowed. If you draw a mask on top of another mask, it is
+    cropped so that it does not overlap with the old mask. Masks in 2D should be single
+    strokes (with "single stroke" checked). If you want to draw masks in 3D (experimental),
+    you can turn this option off, draw a stroke on each plane containing the cell, and then
+    press ENTER. 3D labelling will fill in the planes that you have not labelled, so you do
+    not have to label every plane densely.</p>
+<p><b>!NOTE!</b>: The GUI automatically saves after you draw a mask in 2D, but NOT after 3D
+    mask drawing and NOT after segmentation. Save from the file menu or with Ctrl+S. The
+    output file is written to the same folder as the loaded image with _seg.npy appended.</p>
+<p><b>Bulk Mask Deletion</b></p>
+<p>Clicking the 'delete multiple' button allows you to select and delete multiple masks at
+    once. Masks can be deselected by clicking on them again. Once you have selected all the
+    masks you want to delete, click the 'done' button to delete them.</p>
+<p>Alternatively, you can create a rectangular region to delete a region of masks: click the
+    'delete multiple' button, then move and/or resize the region to select the masks you want
+    to delete. Once you have selected the masks you want to delete, click the 'done' button
+    to delete them.</p>
+<p>At any point in the process, you can click the 'cancel' button to cancel the bulk deletion.</p>
+<p>FYI: there are tooltips throughout the GUI (hover over text to see them).</p>
+<table>
+    <tr><th>Keyboard shortcuts</th><th>Description</th></tr>
+    <tr><td>=/+ button // - button</td><td>zoom in // zoom out</td></tr>
+    <tr><td>CTRL+Z</td><td>undo previously drawn mask/stroke</td></tr>
+    <tr><td>CTRL+Y</td><td>undo remove mask</td></tr>
+    <tr><td>CTRL+0</td><td>clear all masks</td></tr>
+    <tr><td>CTRL+L</td><td>load image (can alternatively drag and drop image)</td></tr>
+    <tr><td>CTRL+S</td><td>SAVE MASKS IN IMAGE to _seg.npy file</td></tr>
+    <tr><td>CTRL+T</td><td>train model using _seg.npy files in folder</td></tr>
+    <tr><td>CTRL+P</td><td>load _seg.npy file (note: it will load automatically with image if it exists)</td></tr>
+    <tr><td>CTRL+M</td><td>load masks file (must be same size as image with 0 for NO mask, and 1,2,3… for masks)</td></tr>
+    <tr><td>CTRL+N</td><td>save masks as PNG</td></tr>
+    <tr><td>CTRL+R</td><td>save ROIs to native ImageJ ROI format</td></tr>
+    <tr><td>CTRL+F</td><td>save flows to image file</td></tr>
+    <tr><td>A/D or LEFT/RIGHT</td><td>cycle through images in current directory</td></tr>
+    <tr><td>W/S or UP/DOWN</td><td>change color (RGB/gray/red/green/blue)</td></tr>
+    <tr><td>R / G / B</td><td>toggle between RGB and Red or Green or Blue</td></tr>
+    <tr><td>PAGE-UP / PAGE-DOWN</td><td>change to flows and cell prob views (if segmentation computed)</td></tr>
+    <tr><td>X</td><td>turn masks ON or OFF</td></tr>
+    <tr><td>Z</td><td>toggle outlines ON or OFF</td></tr>
+    <tr><td>, / .</td><td>increase / decrease brush size for drawing masks</td></tr>
+</table>
+<p><b>Segmentation options (2D only)</b></p>
+<p>SIZE: you can manually enter the approximate diameter for your cells, or press "calibrate"
+    to let the model estimate it. The size is represented by a disk at the bottom of the view
+    window (you can turn this disk off by unchecking "scale disk on").</p>
+<p>use GPU: if you have specially installed the cuda version of mxnet, you can activate this,
+    but it won't give huge speedups when running single 2D images in the GUI.</p>
+<p>MODEL: there is a cytoplasm model and a nuclei model; choose what you want to segment.</p>
+<p>CHAN TO SEG: this is the channel in which the cytoplasm or nuclei exist.</p>
+<p>CHAN2 (OPT): if the cytoplasm model is chosen, then choose the nuclear channel for this option.</p>
+
diff --git a/cellpose/gui/guiparts.py b/cellpose/gui/guiparts.py new file mode 100644 index 0000000000000000000000000000000000000000..f84896828e4b913f412b4715c80d00c55476b67c --- /dev/null +++ b/cellpose/gui/guiparts.py @@ -0,0 +1,591 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +from qtpy import QtGui, QtCore, QtWidgets +from qtpy.QtGui import QPainter, QPixmap +from qtpy.QtWidgets import QApplication, QRadioButton, QWidget, QDialog, QButtonGroup, QSlider, QStyle, QStyleOptionSlider, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox +import pyqtgraph as pg +from pyqtgraph import functions as fn +from pyqtgraph import Point +import numpy as np +import pathlib, os + + +def stylesheet(): + return """ + QToolTip { + background-color: black; + color: white; + border: black solid 1px + } + QComboBox {color: white; + background-color: rgb(40,40,40);} + QComboBox::item:enabled { color: white; + background-color: rgb(40,40,40); + selection-color: white; + selection-background-color: rgb(50,100,50);} + QComboBox::item:!enabled { + background-color: rgb(40,40,40); + color: rgb(100,100,100); + } + QScrollArea > QWidget > QWidget + { + background: transparent; + border: none; + margin: 0px 0px 0px 0px; + } + + QGroupBox + { border: 1px solid white; color: rgb(255,255,255); + border-radius: 6px; + margin-top: 8px; + padding: 0px 0px;} + + QPushButton:pressed {Text-align: center; + background-color: rgb(150,50,150); + border-color: white; + color:white;} + QToolTip { + background-color: black; + color: white; + border: black solid 1px + } + QPushButton:!pressed {Text-align: center; + background-color: rgb(50,50,50); + border-color: white; + color:white;} + QToolTip { + background-color: black; + color: white; + border: black solid 1px + } + QPushButton:disabled {Text-align: center; + background-color: rgb(30,30,30); + border-color: white; + color:rgb(80,80,80);} + QToolTip { + background-color: black; + color: white; + border: black solid 1px + } + + """ + + +class DarkPalette(QtGui.QPalette): + """Class that inherits from pyqtgraph.QtGui.QPalette and renders dark colours for the application. 
+ (from pykilosort/kilosort4) + """ + + def __init__(self): + QtGui.QPalette.__init__(self) + self.setup() + + def setup(self): + self.setColor(QtGui.QPalette.Window, QtGui.QColor(40, 40, 40)) + self.setColor(QtGui.QPalette.WindowText, QtGui.QColor(255, 255, 255)) + self.setColor(QtGui.QPalette.Base, QtGui.QColor(34, 27, 24)) + self.setColor(QtGui.QPalette.AlternateBase, QtGui.QColor(53, 50, 47)) + self.setColor(QtGui.QPalette.ToolTipBase, QtGui.QColor(255, 255, 255)) + self.setColor(QtGui.QPalette.ToolTipText, QtGui.QColor(255, 255, 255)) + self.setColor(QtGui.QPalette.Text, QtGui.QColor(255, 255, 255)) + self.setColor(QtGui.QPalette.Button, QtGui.QColor(53, 50, 47)) + self.setColor(QtGui.QPalette.ButtonText, QtGui.QColor(255, 255, 255)) + self.setColor(QtGui.QPalette.BrightText, QtGui.QColor(255, 0, 0)) + self.setColor(QtGui.QPalette.Link, QtGui.QColor(42, 130, 218)) + self.setColor(QtGui.QPalette.Highlight, QtGui.QColor(42, 130, 218)) + self.setColor(QtGui.QPalette.HighlightedText, QtGui.QColor(0, 0, 0)) + self.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.Text, + QtGui.QColor(128, 128, 128)) + self.setColor( + QtGui.QPalette.Disabled, + QtGui.QPalette.ButtonText, + QtGui.QColor(128, 128, 128), + ) + self.setColor( + QtGui.QPalette.Disabled, + QtGui.QPalette.WindowText, + QtGui.QColor(128, 128, 128), + ) + + +def create_channel_choose(): + # choose channel + ChannelChoose = [QComboBox(), QComboBox()] + ChannelLabels = [] + ChannelChoose[0].addItems(["gray", "red", "green", "blue"]) + ChannelChoose[1].addItems(["none", "red", "green", "blue"]) + cstr = ["chan to segment:", "chan2 (optional): "] + for i in range(2): + ChannelLabels.append(QLabel(cstr[i])) + if i == 0: + ChannelLabels[i].setToolTip( + "this is the channel in which the cytoplasm or nuclei exist \ + that you want to segment") + ChannelChoose[i].setToolTip( + "this is the channel in which the cytoplasm or nuclei exist \ + that you want to segment") + else: + ChannelLabels[i].setToolTip( + "if cytoplasm model is chosen, and you also have a \ + nuclear channel, then choose the nuclear channel for this option") + ChannelChoose[i].setToolTip( + "if cytoplasm model is chosen, and you also have a \ + nuclear channel, then choose the nuclear channel for this option") + + return ChannelChoose, ChannelLabels + + +class ModelButton(QPushButton): + + def __init__(self, parent, model_name, text): + super().__init__() + self.setEnabled(False) + self.setText(text) + self.setFont(parent.boldfont) + self.clicked.connect(lambda: self.press(parent)) + self.model_name = model_name if "cyto3" not in model_name else "cyto3" + + def press(self, parent): + parent.compute_segmentation(model_name=self.model_name) + + +class DenoiseButton(QPushButton): + + def __init__(self, parent, text): + super().__init__() + self.setEnabled(False) + self.model_type = text + self.setText(text) + self.setFont(parent.medfont) + self.clicked.connect(lambda: self.press(parent)) + + def press(self, parent): + if self.model_type == "filter": + parent.restore = "filter" + normalize_params = parent.get_normalize_params() + if (normalize_params["sharpen_radius"] == 0 and + normalize_params["smooth_radius"] == 0 and + normalize_params["tile_norm_blocksize"] == 0): + print( + "GUI_ERROR: no filtering settings on (use custom filter settings)") + parent.restore = None + return + parent.restore = self.model_type + parent.compute_saturation() + elif self.model_type != "none": + parent.compute_denoise_model(model_type=self.model_type) + else: + parent.clear_restore() + 
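+        # re-style the restore buttons so the currently active mode appears pressed
+        # (set_restore_button applies stylePressed/styleUnpressed based on parent.restore)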
parent.set_restore_button() + + +class TrainWindow(QDialog): + + def __init__(self, parent, model_strings): + super().__init__(parent) + self.setGeometry(100, 100, 900, 550) + self.setWindowTitle("train settings") + self.win = QWidget(self) + self.l0 = QGridLayout() + self.win.setLayout(self.l0) + + yoff = 0 + qlabel = QLabel("train model w/ images + _seg.npy in current folder >>") + qlabel.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold)) + + qlabel.setAlignment(QtCore.Qt.AlignVCenter) + self.l0.addWidget(qlabel, yoff, 0, 1, 2) + + # choose initial model + yoff += 1 + self.ModelChoose = QComboBox() + self.ModelChoose.addItems(model_strings) + self.ModelChoose.addItems(["scratch"]) + self.ModelChoose.setFixedWidth(150) + self.ModelChoose.setCurrentIndex(parent.training_params["model_index"]) + self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1) + qlabel = QLabel("initial model: ") + qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.l0.addWidget(qlabel, yoff, 0, 1, 1) + + # choose channels + self.ChannelChoose, self.ChannelLabels = create_channel_choose() + for i in range(2): + yoff += 1 + self.ChannelChoose[i].setFixedWidth(150) + self.ChannelChoose[i].setCurrentIndex( + parent.ChannelChoose[i].currentIndex()) + self.l0.addWidget(self.ChannelLabels[i], yoff, 0, 1, 1) + self.l0.addWidget(self.ChannelChoose[i], yoff, 1, 1, 1) + + # choose parameters + labels = ["learning_rate", "weight_decay", "n_epochs", "model_name"] + self.edits = [] + yoff += 1 + for i, label in enumerate(labels): + qlabel = QLabel(label) + qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.l0.addWidget(qlabel, i + yoff, 0, 1, 1) + self.edits.append(QLineEdit()) + self.edits[-1].setText(str(parent.training_params[label])) + self.edits[-1].setFixedWidth(200) + self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1) + + yoff += 1 + use_SGD = "SGD" + self.useSGD = QCheckBox(f"{use_SGD}") + self.useSGD.setToolTip("use SGD, if unchecked uses AdamW (recommended learning_rate then 0.001)") + self.useSGD.setChecked(True) + self.l0.addWidget(self.useSGD, i+yoff, 1, 1, 1) + + yoff += len(labels) + + yoff += 1 + self.use_norm = QCheckBox(f"use restored/filtered image") + self.use_norm.setChecked(True) + #self.l0.addWidget(self.use_norm, yoff, 0, 2, 4) + + yoff += 2 + qlabel = QLabel( + "(to remove files, click cancel then remove \nfrom folder and reopen train window)" + ) + self.l0.addWidget(qlabel, yoff, 0, 2, 4) + + # click button + yoff += 3 + QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel + self.buttonBox = QDialogButtonBox(QBtn) + self.buttonBox.accepted.connect(lambda: self.accept(parent)) + self.buttonBox.rejected.connect(self.reject) + self.l0.addWidget(self.buttonBox, yoff, 0, 1, 4) + + # list files in folder + qlabel = QLabel("filenames") + qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold)) + self.l0.addWidget(qlabel, 0, 4, 1, 1) + qlabel = QLabel("# of masks") + qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold)) + self.l0.addWidget(qlabel, 0, 5, 1, 1) + + for i in range(10): + if i > len(parent.train_files) - 1: + break + elif i == 9 and len(parent.train_files) > 10: + label = "..." + nmasks = "..." 
+ else: + label = os.path.split(parent.train_files[i])[-1] + nmasks = str(parent.train_labels[i].max()) + qlabel = QLabel(label) + self.l0.addWidget(qlabel, i + 1, 4, 1, 1) + qlabel = QLabel(nmasks) + qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.l0.addWidget(qlabel, i + 1, 5, 1, 1) + + def accept(self, parent): + # set training params + parent.training_params = { + "model_index": self.ModelChoose.currentIndex(), + "learning_rate": float(self.edits[0].text()), + "weight_decay": float(self.edits[1].text()), + "n_epochs": int(self.edits[2].text()), + "model_name": self.edits[3].text(), + "SGD": True if self.useSGD.isChecked() else False, + "channels": [self.ChannelChoose[0].currentIndex(), + self.ChannelChoose[1].currentIndex()], + #"use_norm": True if self.use_norm.isChecked() else False, + } + self.done(1) + + +class ExampleGUI(QDialog): + + def __init__(self, parent=None): + super(ExampleGUI, self).__init__(parent) + self.setGeometry(100, 100, 1300, 900) + self.setWindowTitle("GUI layout") + self.win = QWidget(self) + layout = QGridLayout() + self.win.setLayout(layout) + guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png") + guip_path = str(guip_path.resolve()) + pixmap = QPixmap(guip_path) + label = QLabel(self) + label.setPixmap(pixmap) + pixmap.scaled + layout.addWidget(label, 0, 0, 1, 1) + + +class HelpWindow(QDialog): + + def __init__(self, parent=None): + super(HelpWindow, self).__init__(parent) + self.setGeometry(100, 50, 700, 1000) + self.setWindowTitle("cellpose help") + self.win = QWidget(self) + layout = QGridLayout() + self.win.setLayout(layout) + + text_file = pathlib.Path(__file__).parent.joinpath("guihelpwindowtext.html") + with open(str(text_file.resolve()), "r") as f: + text = f.read() + + label = QLabel(text) + label.setFont(QtGui.QFont("Arial", 8)) + label.setWordWrap(True) + layout.addWidget(label, 0, 0, 1, 1) + self.show() + + +class TrainHelpWindow(QDialog): + + def __init__(self, parent=None): + super(TrainHelpWindow, self).__init__(parent) + self.setGeometry(100, 50, 700, 300) + self.setWindowTitle("training instructions") + self.win = QWidget(self) + layout = QGridLayout() + self.win.setLayout(layout) + + text_file = pathlib.Path(__file__).parent.joinpath( + "guitrainhelpwindowtext.html") + with open(str(text_file.resolve()), "r") as f: + text = f.read() + + label = QLabel(text) + label.setFont(QtGui.QFont("Arial", 8)) + label.setWordWrap(True) + layout.addWidget(label, 0, 0, 1, 1) + self.show() + + +class ViewBoxNoRightDrag(pg.ViewBox): + + def __init__(self, parent=None, border=None, lockAspect=False, enableMouse=True, + invertY=False, enableMenu=True, name=None, invertX=False): + pg.ViewBox.__init__(self, None, border, lockAspect, enableMouse, invertY, + enableMenu, name, invertX) + self.parent = parent + self.axHistoryPointer = -1 + + def keyPressEvent(self, ev): + """ + This routine should capture key presses in the current view box. + The following events are implemented: + +/= : moves forward in the zooming stack (if it exists) + - : moves backward in the zooming stack (if it exists) + + """ + ev.accept() + if ev.text() == "-": + self.scaleBy([1.1, 1.1]) + elif ev.text() in ["+", "="]: + self.scaleBy([0.9, 0.9]) + else: + ev.ignore() + + +class ImageDraw(pg.ImageItem): + """ + **Bases:** :class:`GraphicsObject ` + GraphicsObject displaying an image. Optimized for rapid update (ie video display). + This item displays either a 2D numpy array (height, width) or + a 3D array (height, width, RGBa). 
This array is optionally scaled (see + :func:`setLevels `) and/or colored + with a lookup table (see :func:`setLookupTable `) + before being displayed. + ImageItem is frequently used in conjunction with + :class:`HistogramLUTItem ` or + :class:`HistogramLUTWidget ` to provide a GUI + for controlling the levels and lookup table used to display the image. + """ + + sigImageChanged = QtCore.Signal() + + def __init__(self, image=None, viewbox=None, parent=None, **kargs): + super(ImageDraw, self).__init__() + #self.image=None + #self.viewbox=viewbox + self.levels = np.array([0, 255]) + self.lut = None + self.autoDownsample = False + self.axisOrder = "row-major" + self.removable = False + + self.parent = parent + #kernel[1,1] = 1 + self.setDrawKernel(kernel_size=self.parent.brush_size) + self.parent.current_stroke = [] + self.parent.in_stroke = False + + def mouseClickEvent(self, ev): + if (self.parent.masksOn or + self.parent.outlinesOn) and not self.parent.removing_region: + is_right_click = ev.button() == QtCore.Qt.RightButton + if self.parent.loaded \ + and (is_right_click or ev.modifiers() & QtCore.Qt.ShiftModifier and not ev.double())\ + and not self.parent.deleting_multiple: + if not self.parent.in_stroke: + ev.accept() + self.create_start(ev.pos()) + self.parent.stroke_appended = False + self.parent.in_stroke = True + self.drawAt(ev.pos(), ev) + else: + ev.accept() + self.end_stroke() + self.parent.in_stroke = False + elif not self.parent.in_stroke: + y, x = int(ev.pos().y()), int(ev.pos().x()) + if y >= 0 and y < self.parent.Ly and x >= 0 and x < self.parent.Lx: + if ev.button() == QtCore.Qt.LeftButton and not ev.double(): + idx = self.parent.cellpix[self.parent.currentZ][y, x] + if idx > 0: + if ev.modifiers() & QtCore.Qt.ControlModifier: + # delete mask selected + self.parent.remove_cell(idx) + elif ev.modifiers() & QtCore.Qt.AltModifier: + self.parent.merge_cells(idx) + elif self.parent.masksOn and not self.parent.deleting_multiple: + self.parent.unselect_cell() + self.parent.select_cell(idx) + elif self.parent.deleting_multiple: + if idx in self.parent.removing_cells_list: + self.parent.unselect_cell_multi(idx) + self.parent.removing_cells_list.remove(idx) + else: + self.parent.select_cell_multi(idx) + self.parent.removing_cells_list.append(idx) + + elif self.parent.masksOn and not self.parent.deleting_multiple: + self.parent.unselect_cell() + + def mouseDragEvent(self, ev): + ev.ignore() + return + + def hoverEvent(self, ev): + #QtWidgets.QApplication.setOverrideCursor(QtCore.Qt.CrossCursor) + if self.parent.in_stroke: + if self.parent.in_stroke: + # continue stroke if not at start + self.drawAt(ev.pos()) + if self.is_at_start(ev.pos()): + #self.parent.in_stroke = False + self.end_stroke() + else: + ev.acceptClicks(QtCore.Qt.RightButton) + #ev.acceptClicks(QtCore.Qt.LeftButton) + + def create_start(self, pos): + self.scatter = pg.ScatterPlotItem([pos.x()], [pos.y()], pxMode=False, + pen=pg.mkPen(color=(255, 0, 0), + width=self.parent.brush_size), + size=max(3 * 2, + self.parent.brush_size * 1.8 * 2), + brush=None) + self.parent.p0.addItem(self.scatter) + + def is_at_start(self, pos): + thresh_out = max(6, self.parent.brush_size * 3) + thresh_in = max(3, self.parent.brush_size * 1.8) + # first check if you ever left the start + if len(self.parent.current_stroke) > 3: + stroke = np.array(self.parent.current_stroke) + dist = (((stroke[1:, 1:] - + stroke[:1, 1:][np.newaxis, :, :])**2).sum(axis=-1))**0.5 + dist = dist.flatten() + #print(dist) + has_left = (dist > 
thresh_out).nonzero()[0] + if len(has_left) > 0: + first_left = np.sort(has_left)[0] + has_returned = (dist[max(4, first_left + 1):] < thresh_in).sum() + if has_returned > 0: + return True + else: + return False + else: + return False + + def end_stroke(self): + self.parent.p0.removeItem(self.scatter) + if not self.parent.stroke_appended: + self.parent.strokes.append(self.parent.current_stroke) + self.parent.stroke_appended = True + self.parent.current_stroke = np.array(self.parent.current_stroke) + ioutline = self.parent.current_stroke[:, 3] == 1 + self.parent.current_point_set.append( + list(self.parent.current_stroke[ioutline])) + self.parent.current_stroke = [] + if self.parent.autosave: + self.parent.add_set() + if len(self.parent.current_point_set) and len( + self.parent.current_point_set[0]) > 0 and self.parent.autosave: + self.parent.add_set() + self.parent.in_stroke = False + + def tabletEvent(self, ev): + pass + #print(ev.device()) + #print(ev.pointerType()) + #print(ev.pressure()) + + def drawAt(self, pos, ev=None): + mask = self.strokemask + stroke = self.parent.current_stroke + pos = [int(pos.y()), int(pos.x())] + dk = self.drawKernel + kc = self.drawKernelCenter + sx = [0, dk.shape[0]] + sy = [0, dk.shape[1]] + tx = [pos[0] - kc[0], pos[0] - kc[0] + dk.shape[0]] + ty = [pos[1] - kc[1], pos[1] - kc[1] + dk.shape[1]] + kcent = kc.copy() + if tx[0] <= 0: + sx[0] = 0 + sx[1] = kc[0] + 1 + tx = sx + kcent[0] = 0 + if ty[0] <= 0: + sy[0] = 0 + sy[1] = kc[1] + 1 + ty = sy + kcent[1] = 0 + if tx[1] >= self.parent.Ly - 1: + sx[0] = dk.shape[0] - kc[0] - 1 + sx[1] = dk.shape[0] + tx[0] = self.parent.Ly - kc[0] - 1 + tx[1] = self.parent.Ly + kcent[0] = tx[1] - tx[0] - 1 + if ty[1] >= self.parent.Lx - 1: + sy[0] = dk.shape[1] - kc[1] - 1 + sy[1] = dk.shape[1] + ty[0] = self.parent.Lx - kc[1] - 1 + ty[1] = self.parent.Lx + kcent[1] = ty[1] - ty[0] - 1 + + ts = (slice(tx[0], tx[1]), slice(ty[0], ty[1])) + ss = (slice(sx[0], sx[1]), slice(sy[0], sy[1])) + self.image[ts] = mask[ss] + + for ky, y in enumerate(np.arange(ty[0], ty[1], 1, int)): + for kx, x in enumerate(np.arange(tx[0], tx[1], 1, int)): + iscent = np.logical_and(kx == kcent[0], ky == kcent[1]) + stroke.append([self.parent.currentZ, x, y, iscent]) + self.updateImage() + + def setDrawKernel(self, kernel_size=3): + bs = kernel_size + kernel = np.ones((bs, bs), np.uint8) + self.drawKernel = kernel + self.drawKernelCenter = [ + int(np.floor(kernel.shape[0] / 2)), + int(np.floor(kernel.shape[1] / 2)) + ] + onmask = 255 * kernel[:, :, np.newaxis] + offmask = np.zeros((bs, bs, 1)) + opamask = 100 * kernel[:, :, np.newaxis] + self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1) + self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1) diff --git a/cellpose/gui/guitrainhelpwindowtext.html b/cellpose/gui/guitrainhelpwindowtext.html new file mode 100644 index 0000000000000000000000000000000000000000..f198359113ba0e549f5174e564e549034fbacfb0 --- /dev/null +++ b/cellpose/gui/guitrainhelpwindowtext.html @@ -0,0 +1,25 @@ + + Check out this video to learn the process. +
    +
  1. Drag and drop an image from a folder of images with a similar style (like similar cell types).
  2. Run the built-in models on one of the images using the "model zoo" and find the one that works best for your
  data. Make sure that if you have a nuclear channel you have selected it for CHAN2.
  3. Fix the labelling by drawing new ROIs (right-click) and deleting incorrect ones (CTRL+click). The GUI
  autosaves any manual changes (but does not autosave after running the model, for that click CTRL+S). The
  segmentation is saved in a "_seg.npy" file.
  4. Go to the "Models" menu in the menu bar at the top and click "Train new model..." or use shortcut CTRL+T.
  5. Choose the pretrained model to start the training from (the model you used in #2), and type in the model
  name that you want to use. The other parameters should work well in general for most data types. Then click
  OK.
  6. The model will train (much faster if you have a GPU) and then auto-run on the next image in the folder.
  Next you can repeat #3-#5 as many times as is necessary.
  7. The trained model is available to use in the future in the GUI in the "custom model" section and is saved
  in your image folder.
+
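  If you prefer to script the same training outside the GUI, the sketch below shows one way to do it. It assumes
  the cellpose v3-style Python API (models.CellposeModel and train.train_seg); keyword names can differ between
  releases, and the folder path, channels, and hyperparameters are placeholders to adapt to your own data.

    # minimal sketch, assuming the cellpose v3 API; adjust paths and settings to your data
    from cellpose import io, models, train

    # reuse the *_seg.npy labels drawn in the GUI
    images, labels, image_names = io.load_train_test_data(
        "/path/to/folder", mask_filter="_seg.npy")[:3]

    model = models.CellposeModel(gpu=True, model_type="cyto3")   # initial model (step 5 above)
    train.train_seg(model.net,
                    train_data=images, train_labels=labels,
                    channels=[0, 0],             # [chan to segment, chan2], as in the GUI dropdowns
                    n_epochs=100, learning_rate=0.1, weight_decay=0.0001,
                    SGD=True,                    # matches the GUI's SGD checkbox default
                    model_name="my_new_model")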
\ No newline at end of file diff --git a/cellpose/gui/io.py b/cellpose/gui/io.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc2e0cd2bc2ba02634c134638735042f411a71d --- /dev/null +++ b/cellpose/gui/io.py @@ -0,0 +1,711 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import os, datetime, gc, warnings, glob, shutil, copy +from natsort import natsorted +import numpy as np +import cv2 +import tifffile +import logging +import fastremap + +from ..io import imread, imsave, outlines_to_text, add_model, remove_model, save_rois +from ..models import normalize_default, MODEL_DIR, MODEL_LIST_PATH, get_user_models +from ..utils import masks_to_outlines, outlines_list + +try: + import qtpy + from qtpy.QtWidgets import QFileDialog + GUI = True +except: + GUI = False + +try: + import matplotlib.pyplot as plt + MATPLOTLIB = True +except: + MATPLOTLIB = False + + +def _init_model_list(parent): + MODEL_DIR.mkdir(parents=True, exist_ok=True) + parent.model_list_path = MODEL_LIST_PATH + parent.model_strings = get_user_models() + + +def _add_model(parent, filename=None, load_model=True): + if filename is None: + name = QFileDialog.getOpenFileName(parent, "Add model to GUI") + filename = name[0] + add_model(filename) + fname = os.path.split(filename)[-1] + parent.ModelChooseC.addItems([fname]) + parent.model_strings.append(fname) + + for ind, model_string in enumerate(parent.model_strings[:-1]): + if model_string == fname: + _remove_model(parent, ind=ind + 1, verbose=False) + + parent.ModelChooseC.setCurrentIndex(len(parent.model_strings)) + if load_model: + parent.model_choose(custom=True) + + +def _remove_model(parent, ind=None, verbose=True): + if ind is None: + ind = parent.ModelChooseC.currentIndex() + if ind > 0: + ind -= 1 + parent.ModelChooseC.removeItem(ind + 1) + del parent.model_strings[ind] + # remove model from txt path + modelstr = parent.ModelChooseC.currentText() + remove_model(modelstr) + if len(parent.model_strings) > 0: + parent.ModelChooseC.setCurrentIndex(len(parent.model_strings)) + else: + parent.ModelChooseC.setCurrentIndex(0) + else: + print("ERROR: no model selected to delete") + + +def _get_train_set(image_names): + """ get training data and labels for images in current folder image_names""" + train_data, train_labels, train_files = [], [], [] + restore = None + normalize_params = normalize_default + for image_name_full in image_names: + image_name = os.path.splitext(image_name_full)[0] + label_name = None + if os.path.exists(image_name + "_seg.npy"): + dat = np.load(image_name + "_seg.npy", allow_pickle=True).item() + masks = dat["masks"].squeeze() + if masks.ndim == 2: + fastremap.renumber(masks, in_place=True) + label_name = image_name + "_seg.npy" + else: + print(f"GUI_INFO: _seg.npy found for {image_name} but masks.ndim!=2") + if "img_restore" in dat: + data = dat["img_restore"].squeeze() + restore = dat["restore"] + else: + data = imread(image_name_full) + normalize_params = dat[ + "normalize_params"] if "normalize_params" in dat else normalize_default + if label_name is not None: + train_files.append(image_name_full) + train_data.append(data) + train_labels.append(masks) + if restore: + print(f"GUI_INFO: using {restore} images (dat['img_restore'])") + return train_data, train_labels, train_files, restore, normalize_params + + +def _load_image(parent, filename=None, load_seg=True, load_3D=False): + """ load image with filename; if None, open QFileDialog """ + if filename is None: + 
name = QFileDialog.getOpenFileName(parent, "Load image") + filename = name[0] + if filename == "": + return + manual_file = os.path.splitext(filename)[0] + "_seg.npy" + load_mask = False + if load_seg: + if os.path.isfile(manual_file) and not parent.autoloadMasks.isChecked(): + _load_seg(parent, manual_file, image=imread(filename), image_file=filename, + load_3D=load_3D) + return + elif parent.autoloadMasks.isChecked(): + mask_file = os.path.splitext(filename)[0] + "_masks" + os.path.splitext( + filename)[-1] + mask_file = os.path.splitext(filename)[ + 0] + "_masks.tif" if not os.path.isfile(mask_file) else mask_file + load_mask = True if os.path.isfile(mask_file) else False + try: + print(f"GUI_INFO: loading image: {filename}") + image = imread(filename) + parent.loaded = True + except Exception as e: + print("ERROR: images not compatible") + print(f"ERROR: {e}") + + if parent.loaded: + parent.reset() + parent.filename = filename + filename = os.path.split(parent.filename)[-1] + _initialize_images(parent, image, load_3D=load_3D) + parent.loaded = True + parent.enable_buttons() + if load_mask: + _load_masks(parent, filename=mask_file) + + +def _initialize_images(parent, image, load_3D=False): + """ format image for GUI + + assumes image is Z x channels x W x H + + """ + load_3D = parent.load_3D if load_3D is False else load_3D + parent.nchan = 3 + if image.ndim > 4: + image = image.squeeze() + if image.ndim > 4: + raise ValueError("cannot load 4D stack, reduce dimensions") + elif image.ndim == 1: + raise ValueError("cannot load 1D stack, increase dimensions") + if image.ndim == 4: + if not load_3D: + raise ValueError( + "cannot load 3D stack, run 'python -m cellpose --Zstack' for 3D GUI") + else: + # check if tiff is channels first + if image.shape[0] < 4 and image.shape[0] == min(image.shape) and image.shape[0] < image.shape[1]: + # tiff is channels x Z x W x H => Z x channels x W x H + image = image.transpose((1, 0, 2, 3)) + image = np.transpose(image, (0, 2, 3, 1)) + elif image.ndim == 3: + if not load_3D: + # assume smallest dimension is channels and put last + c = np.array(image.shape).argmin() + image = image.transpose(((c + 1) % 3, (c + 2) % 3, c)) + elif load_3D: + # assume smallest dimension is Z and put first if <3x max dim + shape = np.array(image.shape) + z = shape.argmin() + if shape[z] < shape.max() / 3: + image = image.transpose((z, (z + 1) % 3, (z + 2) % 3)) + image = image[..., np.newaxis] + elif image.ndim == 2: + if not load_3D: + image = image[..., np.newaxis] + else: + raise ValueError( + "cannot load 2D stack in 3D mode, run 'python -m cellpose' for 2D GUI") + if image.shape[-1] > 3: + print("WARNING: image has more than 3 channels, keeping only first 3") + image = image[..., :3] + elif image.shape[-1] == 2: + # fill in with blank channels to make 3 channels + shape = image.shape + image = np.concatenate( + (image, np.zeros((*shape[:-1], 3 - shape[-1]), dtype=np.uint8)), axis=-1) + parent.nchan = 2 + elif image.shape[-1] == 1: + parent.nchan = 1 + + parent.stack = image + print(f"GUI_INFO: image shape: {image.shape}") + if load_3D: + parent.NZ = len(parent.stack) + parent.scroll.setMaximum(parent.NZ - 1) + else: + parent.NZ = 1 + parent.stack = parent.stack[np.newaxis, ...] 
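    # convert the stack to float32 and min-max rescale intensities to the 0-255 range used for display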
+ + img_min = image.min() + img_max = image.max() + parent.stack = parent.stack.astype(np.float32) + parent.stack -= img_min + if img_max > img_min + 1e-3: + parent.stack /= (img_max - img_min) + parent.stack *= 255 + + if load_3D: + print("GUI_INFO: converted to float and normalized values to 0.0->255.0") + + del image + gc.collect() + + parent.imask = 0 + parent.Ly, parent.Lx = parent.stack.shape[-3:-1] + parent.Ly0, parent.Lx0 = parent.stack.shape[-3:-1] + parent.layerz = 255 * np.ones((parent.Ly, parent.Lx, 4), "uint8") + if hasattr(parent, "stack_filtered"): + parent.Lyr, parent.Lxr = parent.stack_filtered.shape[-3:-1] + elif parent.restore and "upsample" in parent.restore: + parent.Lyr, parent.Lxr = int(parent.Ly * parent.ratio), int(parent.Lx * + parent.ratio) + else: + parent.Lyr, parent.Lxr = parent.Ly, parent.Lx + parent.clear_all() + + if not hasattr(parent, "stack_filtered") and parent.restore: + print("GUI_INFO: no 'img_restore' found, applying current settings") + parent.compute_restore() + + if parent.autobtn.isChecked(): + if parent.restore is None or parent.restore != "filter": + print( + "GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)" + ) + parent.compute_saturation() + elif len(parent.saturation) != parent.NZ: + parent.saturation = [] + for r in range(3): + parent.saturation.append([]) + for n in range(parent.NZ): + parent.saturation[-1].append([0, 255]) + parent.sliders[r].setValue([0, 255]) + parent.compute_scale() + parent.track_changes = [] + + if load_3D: + parent.currentZ = int(np.floor(parent.NZ / 2)) + parent.scroll.setValue(parent.currentZ) + parent.zpos.setText(str(parent.currentZ)) + else: + parent.currentZ = 0 + + +def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False): + """ load *_seg.npy with filename; if None, open QFileDialog """ + if filename is None: + name = QFileDialog.getOpenFileName(parent, "Load labelled data", filter="*.npy") + filename = name[0] + try: + dat = np.load(filename, allow_pickle=True).item() + # check if there are keys in filename + dat["outlines"] + parent.loaded = True + except: + parent.loaded = False + print("ERROR: not NPY") + return + + parent.reset() + if image is None: + found_image = False + if "filename" in dat: + parent.filename = dat["filename"] + if os.path.isfile(parent.filename): + parent.filename = dat["filename"] + found_image = True + else: + imgname = os.path.split(parent.filename)[1] + root = os.path.split(filename)[0] + parent.filename = root + "/" + imgname + if os.path.isfile(parent.filename): + found_image = True + if found_image: + try: + print(parent.filename) + image = imread(parent.filename) + except: + parent.loaded = False + found_image = False + print("ERROR: cannot find image file, loading from npy") + if not found_image: + parent.filename = filename[:-8] + print(parent.filename) + if "img" in dat: + image = dat["img"] + else: + print("ERROR: no image file found and no image in npy") + return + else: + parent.filename = image_file + + parent.restore = None + parent.ratio = 1. 
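    # reload any saved normalization settings and the restored/filtered image ("img_restore") from the _seg.npy file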
+ + if "normalize_params" in dat: + parent.restore = None if "restore" not in dat else dat["restore"] + print(f"GUI_INFO: restore: {parent.restore}") + parent.set_normalize_params(dat["normalize_params"]) + parent.set_restore_button() + + if "img_restore" in dat: + img = dat["img_restore"] + img_min = img.min() + img_max = img.max() + parent.stack_filtered = img.astype("float32") + parent.stack_filtered -= img_min + if img_max > img_min + 1e-3: + parent.stack_filtered /= (img_max - img_min) + parent.stack_filtered *= 255 + if parent.stack_filtered.ndim < 4: + parent.stack_filtered = parent.stack_filtered[np.newaxis, ...] + if parent.stack_filtered.ndim < 4: + parent.stack_filtered = parent.stack_filtered[..., np.newaxis] + shape = parent.stack_filtered.shape + if shape[-1] == 2: + if "chan_choose" in dat: + channels = np.array(dat["chan_choose"]) - 1 + img = np.zeros((*shape[:-1], 3), dtype="float32") + img[..., channels] = parent.stack_filtered + parent.stack_filtered = img + else: + parent.stack_filtered = np.concatenate( + (parent.stack_filtered, np.zeros( + (*shape[:-1], 1), dtype="float32")), axis=-1) + elif shape[-1] > 3: + parent.stack_filtered = parent.stack_filtered[..., :3] + + parent.restore = dat["restore"] + parent.ViewDropDown.model().item(parent.ViewDropDown.count() - + 1).setEnabled(True) + parent.view = parent.ViewDropDown.count() - 1 + if parent.restore and "upsample" in parent.restore: + print(parent.stack_filtered.shape, image.shape) + parent.ratio = dat["ratio"] + + parent.set_restore_button() + + _initialize_images(parent, image, load_3D=load_3D) + print(parent.stack.shape) + if "chan_choose" in dat: + parent.ChannelChoose[0].setCurrentIndex(dat["chan_choose"][0]) + parent.ChannelChoose[1].setCurrentIndex(dat["chan_choose"][1]) + + if "outlines" in dat: + if isinstance(dat["outlines"], list): + # old way of saving files + dat["outlines"] = dat["outlines"][::-1] + for k, outline in enumerate(dat["outlines"]): + if "colors" in dat: + color = dat["colors"][k] + else: + col_rand = np.random.randint(1000) + color = parent.colormap[col_rand, :3] + median = parent.add_mask(points=outline, color=color) + if median is not None: + parent.cellcolors = np.append(parent.cellcolors, + color[np.newaxis, :], axis=0) + parent.ncells += 1 + else: + if dat["masks"].min() == -1: + dat["masks"] += 1 + dat["outlines"] += 1 + parent.ncells = dat["masks"].max() + if "colors" in dat and len(dat["colors"]) == dat["masks"].max(): + colors = dat["colors"] + else: + colors = parent.colormap[:parent.ncells, :3] + + _masks_to_gui(parent, dat["masks"], outlines=dat["outlines"], colors=colors) + + parent.draw_layer() + if "est_diam" in dat: + parent.Diameter.setText("%0.1f" % dat["est_diam"]) + parent.diameter = dat["est_diam"] + parent.compute_scale() + + if "manual_changes" in dat: + parent.track_changes = dat["manual_changes"] + print("GUI_INFO: loaded in previous changes") + if "zdraw" in dat: + parent.zdraw = dat["zdraw"] + else: + parent.zdraw = [None for n in range(parent.ncells)] + parent.loaded = True + #print(f"GUI_INFO: {parent.ncells} masks found in {filename}") + else: + parent.clear_all() + + parent.ismanual = np.zeros(parent.ncells, bool) + if "ismanual" in dat: + if len(dat["ismanual"]) == parent.ncells: + parent.ismanual = dat["ismanual"] + + if "current_channel" in dat: + parent.color = (dat["current_channel"] + 2) % 5 + parent.RGBDropDown.setCurrentIndex(parent.color) + + if "flows" in dat: + parent.flows = dat["flows"] + try: + if parent.flows[0].shape[-3] != 
dat["masks"].shape[-2]: + Ly, Lx = dat["masks"].shape[-2:] + for i in range(len(parent.flows)): + parent.flows[i] = cv2.resize( + parent.flows[i].squeeze(), (Lx, Ly), + interpolation=cv2.INTER_NEAREST)[np.newaxis, ...] + if parent.NZ == 1: + parent.recompute_masks = True + else: + parent.recompute_masks = False + + except: + try: + if len(parent.flows[0]) > 0: + parent.flows = parent.flows[0] + except: + parent.flows = [[], [], [], [], [[]]] + parent.recompute_masks = False + + parent.enable_buttons() + parent.update_layer() + del dat + gc.collect() + + +def _load_masks(parent, filename=None): + """ load zeros-based masks (0=no cell, 1=cell 1, ...) """ + if filename is None: + name = QFileDialog.getOpenFileName(parent, "Load masks (PNG or TIFF)") + filename = name[0] + print(f"GUI_INFO: loading masks: {filename}") + masks = imread(filename) + outlines = None + if masks.ndim > 3: + # Z x nchannels x Ly x Lx + if masks.shape[-1] > 5: + parent.flows = list(np.transpose(masks[:, :, :, 2:], (3, 0, 1, 2))) + outlines = masks[..., 1] + masks = masks[..., 0] + else: + parent.flows = list(np.transpose(masks[:, :, :, 1:], (3, 0, 1, 2))) + masks = masks[..., 0] + elif masks.ndim == 3: + if masks.shape[-1] < 5: + masks = masks[np.newaxis, :, :, 0] + elif masks.ndim < 3: + masks = masks[np.newaxis, :, :] + # masks should be Z x Ly x Lx + if masks.shape[0] != parent.NZ: + print("ERROR: masks are not same depth (number of planes) as image stack") + return + + _masks_to_gui(parent, masks, outlines) + if parent.ncells > 0: + parent.draw_layer() + parent.toggle_mask_ops() + del masks + gc.collect() + parent.update_layer() + parent.update_plot() + + +def _masks_to_gui(parent, masks, outlines=None, colors=None): + """ masks loaded into GUI """ + # get unique values + shape = masks.shape + if len(fastremap.unique(masks)) != masks.max() + 1: + print("GUI_INFO: renumbering masks") + fastremap.renumber(masks, in_place=True) + outlines = None + masks = masks.reshape(shape) + if masks.ndim == 2: + outlines = None + masks = masks.astype(np.uint16) if masks.max() < 2**16 - 1 else masks.astype( + np.uint32) + if parent.restore and "upsample" in parent.restore: + parent.cellpix_resize = masks.copy() + parent.cellpix = parent.cellpix_resize.copy() + parent.cellpix_orig = cv2.resize( + masks.squeeze(), (parent.Lx0, parent.Ly0), + interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :] + parent.resize = True + else: + parent.cellpix = masks + if parent.cellpix.ndim == 2: + parent.cellpix = parent.cellpix[np.newaxis, :, :] + if parent.restore and "upsample" in parent.restore: + if parent.cellpix_resize.ndim == 2: + parent.cellpix_resize = parent.cellpix_resize[np.newaxis, :, :] + if parent.cellpix_orig.ndim == 2: + parent.cellpix_orig = parent.cellpix_orig[np.newaxis, :, :] + + print(f"GUI_INFO: {masks.max()} masks found") + + # get outlines + if outlines is None: # parent.outlinesOn + parent.outpix = np.zeros_like(parent.cellpix) + if parent.restore and "upsample" in parent.restore: + parent.outpix_orig = np.zeros_like(parent.cellpix_orig) + for z in range(parent.NZ): + outlines = masks_to_outlines(parent.cellpix[z]) + parent.outpix[z] = outlines * parent.cellpix[z] + if parent.restore and "upsample" in parent.restore: + outlines = masks_to_outlines(parent.cellpix_orig[z]) + parent.outpix_orig[z] = outlines * parent.cellpix_orig[z] + if z % 50 == 0 and parent.NZ > 1: + print("GUI_INFO: plane %d outlines processed" % z) + if parent.restore and "upsample" in parent.restore: + parent.outpix_resize = parent.outpix.copy() + else: + 
parent.outpix = outlines + if parent.restore and "upsample" in parent.restore: + parent.outpix_resize = parent.outpix.copy() + parent.outpix_orig = np.zeros_like(parent.cellpix_orig) + for z in range(parent.NZ): + outlines = masks_to_outlines(parent.cellpix_orig[z]) + parent.outpix_orig[z] = outlines * parent.cellpix_orig[z] + if z % 50 == 0 and parent.NZ > 1: + print("GUI_INFO: plane %d outlines processed" % z) + + if parent.outpix.ndim == 2: + parent.outpix = parent.outpix[np.newaxis, :, :] + if parent.restore and "upsample" in parent.restore: + if parent.outpix_resize.ndim == 2: + parent.outpix_resize = parent.outpix_resize[np.newaxis, :, :] + if parent.outpix_orig.ndim == 2: + parent.outpix_orig = parent.outpix_orig[np.newaxis, :, :] + + parent.ncells = parent.cellpix.max() + colors = parent.colormap[:parent.ncells, :3] if colors is None else colors + print("GUI_INFO: creating cellcolors and drawing masks") + parent.cellcolors = np.concatenate((np.array([[255, 255, 255]]), colors), + axis=0).astype(np.uint8) + if parent.ncells > 0: + parent.draw_layer() + parent.toggle_mask_ops() + parent.ismanual = np.zeros(parent.ncells, bool) + parent.zdraw = list(-1 * np.ones(parent.ncells, np.int16)) + + if hasattr(parent, "stack_filtered"): + parent.ViewDropDown.setCurrentIndex(parent.ViewDropDown.count() - 1) + print("set denoised/filtered view") + else: + parent.ViewDropDown.setCurrentIndex(0) + + +def _save_png(parent): + """ save masks to png or tiff (if 3D) """ + filename = parent.filename + base = os.path.splitext(filename)[0] + if parent.NZ == 1: + if parent.cellpix[0].max() > 65534: + print("GUI_INFO: saving 2D masks to tif (too many masks for PNG)") + imsave(base + "_cp_masks.tif", parent.cellpix[0]) + else: + print("GUI_INFO: saving 2D masks to png") + imsave(base + "_cp_masks.png", parent.cellpix[0].astype(np.uint16)) + else: + print("GUI_INFO: saving 3D masks to tiff") + imsave(base + "_cp_masks.tif", parent.cellpix) + + +def _save_flows(parent): + """ save flows and cellprob to tiff """ + filename = parent.filename + base = os.path.splitext(filename)[0] + print("GUI_INFO: saving flows and cellprob to tiff") + if len(parent.flows) > 0: + imsave(base + "_cp_cellprob.tif", parent.flows[1]) + for i in range(3): + imsave(base + f"_cp_flows_{i}.tif", parent.flows[0][..., i]) + if len(parent.flows) > 2: + imsave(base + "_cp_flows.tif", parent.flows[2]) + print("GUI_INFO: saved flows and cellprob") + else: + print("ERROR: no flows or cellprob found") + +def _save_rois(parent): + """ save masks as rois in .zip file for ImageJ """ + filename = parent.filename + if parent.NZ == 1: + print( + f"GUI_INFO: saving {parent.cellpix[0].max()} ImageJ ROIs to .zip archive.") + save_rois(parent.cellpix[0], parent.filename) + else: + print("ERROR: cannot save 3D outlines") + + +def _save_outlines(parent): + filename = parent.filename + base = os.path.splitext(filename)[0] + if parent.NZ == 1: + print( + "GUI_INFO: saving 2D outlines to text file, see docs for info to load into ImageJ" + ) + outlines = outlines_list(parent.cellpix[0]) + outlines_to_text(base, outlines) + else: + print("ERROR: cannot save 3D outlines") + + +def _save_sets_with_check(parent): + """ Save masks and update *_seg.npy file. Use this function when saving should be optional + based on the disableAutosave checkbox. Otherwise, use _save_sets """ + if not parent.disableAutosave.isChecked(): + _save_sets(parent) + + +def _save_sets(parent): + """ save masks to *_seg.npy. This function should be used when saving + is forced, e.g. 
when clicking the save button. Otherwise, use _save_sets_with_check + """ + filename = parent.filename + base = os.path.splitext(filename)[0] + flow_threshold, cellprob_threshold = parent.get_thresholds() + if parent.NZ > 1: + dat = { + "outlines": + parent.outpix, + "colors": + parent.cellcolors[1:], + "masks": + parent.cellpix, + "current_channel": (parent.color - 2) % 5, + "filename": + parent.filename, + "flows": + parent.flows, + "zdraw": + parent.zdraw, + "model_path": + parent.current_model_path + if hasattr(parent, "current_model_path") else 0, + "flow_threshold": + flow_threshold, + "cellprob_threshold": + cellprob_threshold, + "normalize_params": + parent.get_normalize_params(), + "restore": + parent.restore, + "ratio": + parent.ratio, + "diameter": + parent.diameter + } + if parent.restore is not None: + dat["img_restore"] = parent.stack_filtered + np.save(base + "_seg.npy", dat) + else: + dat = { + "outlines": + parent.outpix.squeeze() if parent.restore is None or + not "upsample" in parent.restore else parent.outpix_resize.squeeze(), + "colors": + parent.cellcolors[1:], + "masks": + parent.cellpix.squeeze() if parent.restore is None or + not "upsample" in parent.restore else parent.cellpix_resize.squeeze(), + "chan_choose": [ + parent.ChannelChoose[0].currentIndex(), + parent.ChannelChoose[1].currentIndex() + ], + "filename": + parent.filename, + "flows": + parent.flows, + "ismanual": + parent.ismanual, + "manual_changes": + parent.track_changes, + "model_path": + parent.current_model_path + if hasattr(parent, "current_model_path") else 0, + "flow_threshold": + flow_threshold, + "cellprob_threshold": + cellprob_threshold, + "normalize_params": + parent.get_normalize_params(), + "restore": + parent.restore, + "ratio": + parent.ratio, + "diameter": + parent.diameter + } + if parent.restore is not None: + dat["img_restore"] = parent.stack_filtered + np.save(base + "_seg.npy", dat) + del dat + #print(parent.point_sets) + print("GUI_INFO: %d ROIs saved to %s" % (parent.ncells, base + "_seg.npy")) diff --git a/cellpose/gui/make_train.py b/cellpose/gui/make_train.py new file mode 100644 index 0000000000000000000000000000000000000000..20cb9fc36381c1a133e860ca50db6d9c0cb54d78 --- /dev/null +++ b/cellpose/gui/make_train.py @@ -0,0 +1,104 @@ +import sys, os, argparse, glob, pathlib, time +import numpy as np +from natsort import natsorted +from tqdm import tqdm +from cellpose import utils, models, io, core, version_str, transforms + + +def main(): + parser = argparse.ArgumentParser(description='cellpose parameters') + + input_img_args = parser.add_argument_group("input image arguments") + input_img_args.add_argument('--dir', default=[], type=str, + help='folder containing data to run or train on.') + input_img_args.add_argument( + '--image_path', default=[], type=str, help= + 'if given and --dir not given, run on single image instead of folder (cannot train with this option)' + ) + input_img_args.add_argument( + '--look_one_level_down', action='store_true', + help='run processing on all subdirectories of current folder') + input_img_args.add_argument('--img_filter', default=[], type=str, + help='end string for images to run on') + input_img_args.add_argument( + '--channel_axis', default=None, type=int, + help='axis of image which corresponds to image channels') + input_img_args.add_argument('--z_axis', default=None, type=int, + help='axis of image which corresponds to Z dimension') + input_img_args.add_argument( + '--chan', default=0, type=int, help= + 'channel to segment; 0: GRAY, 1: 
RED, 2: GREEN, 3: BLUE. Default: %(default)s') + input_img_args.add_argument( + '--chan2', default=0, type=int, help= + 'nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s' + ) + input_img_args.add_argument('--invert', action='store_true', + help='invert grayscale channel') + input_img_args.add_argument( + '--all_channels', action='store_true', help= + 'use all channels in image if using own model and images with special channels') + input_img_args.add_argument("--anisotropy", required=False, default=1.0, type=float, + help="anisotropy of volume in 3D") + + + # algorithm settings + algorithm_args = parser.add_argument_group("algorithm arguments") + algorithm_args.add_argument('--sharpen_radius', required=False, default=0.0, + type=float, help='high-pass filtering radius. Default: %(default)s') + algorithm_args.add_argument('--tile_norm', required=False, default=0, type=int, + help='tile normalization block size. Default: %(default)s') + algorithm_args.add_argument('--nimg_per_tif', required=False, default=10, type=int, + help='number of crops in XY to save per tiff. Default: %(default)s') + algorithm_args.add_argument('--crop_size', required=False, default=512, type=int, + help='size of random crop to save. Default: %(default)s') + + args = parser.parse_args() + + # find images + if len(args.img_filter) > 0: + imf = args.img_filter + else: + imf = None + + if len(args.dir) > 0: + image_names = io.get_image_files(args.dir, "_masks", imf=imf, + look_one_level_down=args.look_one_level_down) + dirname = args.dir + else: + if os.path.exists(args.image_path): + image_names = [args.image_path] + dirname = os.path.split(args.image_path)[0] + else: + raise ValueError(f"ERROR: no file found at {args.image_path}") + + np.random.seed(0) + nimg_per_tif = args.nimg_per_tif + crop_size = args.crop_size + os.makedirs(os.path.join(dirname, 'train/'), exist_ok=True) + pm = [(0, 1, 2, 3), (2, 0, 1, 3), (1, 0, 2, 3)] + npm = ["YX", "ZY", "ZX"] + for name in image_names: + name0 = os.path.splitext(os.path.split(name)[-1])[0] + img0 = io.imread(name) + img0 = transforms.convert_image(img0, channels=[args.chan, args.chan2], channel_axis=args.channel_axis, z_axis=args.z_axis) + for p in range(3): + img = img0.transpose(pm[p]).copy() + print(npm[p], img[0].shape) + Ly, Lx = img.shape[1:3] + imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]] + if args.anisotropy > 1.0 and p > 0: + imgs = transforms.resize_image(imgs, Ly=int(args.anisotropy * Ly), Lx=Lx) + for k, img in enumerate(imgs): + if args.tile_norm: + img = transforms.normalize99_tile(img, blocksize=args.tile_norm) + if args.sharpen_radius: + img = transforms.smooth_sharpen_img(img, + sharpen_radius=args.sharpen_radius) + ly = 0 if Ly - crop_size <= 0 else np.random.randint(0, Ly - crop_size) + lx = 0 if Lx - crop_size <= 0 else np.random.randint(0, Lx - crop_size) + io.imsave(os.path.join(dirname, f'train/{name0}_{npm[p]}_{k}.tif'), + img[ly:ly + args.crop_size, lx:lx + args.crop_size].squeeze()) + + +if __name__ == '__main__': + main() diff --git a/cellpose/gui/menus.py b/cellpose/gui/menus.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7994fcbd35e21596c7f59cea5ceeabfd7ffe43 --- /dev/null +++ b/cellpose/gui/menus.py @@ -0,0 +1,148 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import qtpy +from qtpy.QtWidgets import QAction +from . import io +from .. 
import models + + +def mainmenu(parent): + main_menu = parent.menuBar() + file_menu = main_menu.addMenu("&File") + # load processed data + loadImg = QAction("&Load image (*.tif, *.png, *.jpg)", parent) + loadImg.setShortcut("Ctrl+L") + loadImg.triggered.connect(lambda: io._load_image(parent)) + file_menu.addAction(loadImg) + + parent.autoloadMasks = QAction("Autoload masks from _masks.tif file", parent, + checkable=True) + parent.autoloadMasks.setChecked(False) + file_menu.addAction(parent.autoloadMasks) + + parent.disableAutosave = QAction("Disable autosave _seg.npy file", parent, + checkable=True) + parent.disableAutosave.setChecked(False) + file_menu.addAction(parent.disableAutosave) + + parent.loadMasks = QAction("Load &masks (*.tif, *.png, *.jpg)", parent) + parent.loadMasks.setShortcut("Ctrl+M") + parent.loadMasks.triggered.connect(lambda: io._load_masks(parent)) + file_menu.addAction(parent.loadMasks) + parent.loadMasks.setEnabled(False) + + loadManual = QAction("Load &processed/labelled image (*_seg.npy)", parent) + loadManual.setShortcut("Ctrl+P") + loadManual.triggered.connect(lambda: io._load_seg(parent)) + file_menu.addAction(loadManual) + + parent.saveSet = QAction("&Save masks and image (as *_seg.npy)", parent) + parent.saveSet.setShortcut("Ctrl+S") + parent.saveSet.triggered.connect(lambda: io._save_sets(parent)) + file_menu.addAction(parent.saveSet) + parent.saveSet.setEnabled(False) + + parent.savePNG = QAction("Save masks as P&NG/tif", parent) + parent.savePNG.setShortcut("Ctrl+N") + parent.savePNG.triggered.connect(lambda: io._save_png(parent)) + file_menu.addAction(parent.savePNG) + parent.savePNG.setEnabled(False) + + parent.saveOutlines = QAction("Save &Outlines as text for imageJ", parent) + parent.saveOutlines.setShortcut("Ctrl+O") + parent.saveOutlines.triggered.connect(lambda: io._save_outlines(parent)) + file_menu.addAction(parent.saveOutlines) + parent.saveOutlines.setEnabled(False) + + parent.saveROIs = QAction("Save outlines as .zip archive of &ROI files for ImageJ", + parent) + parent.saveROIs.setShortcut("Ctrl+R") + parent.saveROIs.triggered.connect(lambda: io._save_rois(parent)) + file_menu.addAction(parent.saveROIs) + parent.saveROIs.setEnabled(False) + + parent.saveFlows = QAction("Save &Flows and cellprob as tif", parent) + parent.saveFlows.setShortcut("Ctrl+F") + parent.saveFlows.triggered.connect(lambda: io._save_flows(parent)) + file_menu.addAction(parent.saveFlows) + parent.saveFlows.setEnabled(False) + + +def editmenu(parent): + main_menu = parent.menuBar() + edit_menu = main_menu.addMenu("&Edit") + parent.undo = QAction("Undo previous mask/trace", parent) + parent.undo.setShortcut("Ctrl+Z") + parent.undo.triggered.connect(parent.undo_action) + parent.undo.setEnabled(False) + edit_menu.addAction(parent.undo) + + parent.redo = QAction("Undo remove mask", parent) + parent.redo.setShortcut("Ctrl+Y") + parent.redo.triggered.connect(parent.undo_remove_action) + parent.redo.setEnabled(False) + edit_menu.addAction(parent.redo) + + parent.ClearButton = QAction("Clear all masks", parent) + parent.ClearButton.setShortcut("Ctrl+0") + parent.ClearButton.triggered.connect(parent.clear_all) + parent.ClearButton.setEnabled(False) + edit_menu.addAction(parent.ClearButton) + + parent.remcell = QAction("Remove selected cell (Ctrl+CLICK)", parent) + parent.remcell.setShortcut("Ctrl+Click") + parent.remcell.triggered.connect(parent.remove_action) + parent.remcell.setEnabled(False) + edit_menu.addAction(parent.remcell) + + parent.mergecell = QAction("FYI: Merge cells by 
Alt+Click", parent) + parent.mergecell.setEnabled(False) + edit_menu.addAction(parent.mergecell) + + +def modelmenu(parent): + main_menu = parent.menuBar() + io._init_model_list(parent) + model_menu = main_menu.addMenu("&Models") + parent.addmodel = QAction("Add custom torch model to GUI", parent) + #parent.addmodel.setShortcut("Ctrl+A") + parent.addmodel.triggered.connect(parent.add_model) + parent.addmodel.setEnabled(True) + model_menu.addAction(parent.addmodel) + + parent.removemodel = QAction("Remove selected custom model from GUI", parent) + #parent.removemodel.setShortcut("Ctrl+R") + parent.removemodel.triggered.connect(parent.remove_model) + parent.removemodel.setEnabled(True) + model_menu.addAction(parent.removemodel) + + parent.newmodel = QAction("&Train new model with image+masks in folder", parent) + parent.newmodel.setShortcut("Ctrl+T") + parent.newmodel.triggered.connect(parent.new_model) + parent.newmodel.setEnabled(False) + model_menu.addAction(parent.newmodel) + + openTrainHelp = QAction("Training instructions", parent) + openTrainHelp.triggered.connect(parent.train_help_window) + model_menu.addAction(openTrainHelp) + + +def helpmenu(parent): + main_menu = parent.menuBar() + help_menu = main_menu.addMenu("&Help") + + openHelp = QAction("&Help with GUI", parent) + openHelp.setShortcut("Ctrl+H") + openHelp.triggered.connect(parent.help_window) + help_menu.addAction(openHelp) + + openGUI = QAction("&GUI layout", parent) + openGUI.setShortcut("Ctrl+G") + openGUI.triggered.connect(parent.gui_window) + help_menu.addAction(openGUI) + + openTrainHelp = QAction("Training instructions", parent) + openTrainHelp.triggered.connect(parent.train_help_window) + help_menu.addAction(openTrainHelp) diff --git a/cellpose/io.py b/cellpose/io.py new file mode 100644 index 0000000000000000000000000000000000000000..e5f30f5cee3a34cded07a2a14be66ae7813829d3 --- /dev/null +++ b/cellpose/io.py @@ -0,0 +1,756 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import os, datetime, gc, warnings, glob, shutil +from natsort import natsorted +import numpy as np +import cv2 +import tifffile +import logging, pathlib, sys +from tqdm import tqdm +from pathlib import Path +import re +from . 
import version_str +from roifile import ImagejRoi, roiwrite + +try: + from qtpy import QtGui, QtCore, Qt, QtWidgets + from qtpy.QtWidgets import QMessageBox + GUI = True +except: + GUI = False + +try: + import matplotlib.pyplot as plt + MATPLOTLIB = True +except: + MATPLOTLIB = False + +try: + import nd2 + ND2 = True +except: + ND2 = False + +try: + import nrrd + NRRD = True +except: + NRRD = False + +try: + from google.cloud import storage + SERVER_UPLOAD = True +except: + SERVER_UPLOAD = False + +io_logger = logging.getLogger(__name__) + +def logger_setup(cp_path=".cellpose", logfile_name="run.log", stdout_file_replacement=None): + cp_dir = pathlib.Path.home().joinpath(cp_path) + cp_dir.mkdir(exist_ok=True) + log_file = cp_dir.joinpath(logfile_name) + try: + log_file.unlink() + except: + print('creating new log file') + handlers = [logging.FileHandler(log_file),] + if stdout_file_replacement is not None: + handlers.append(logging.FileHandler(stdout_file_replacement)) + else: + handlers.append(logging.StreamHandler(sys.stdout)) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=handlers, + ) + logger = logging.getLogger(__name__) + logger.info(f"WRITING LOG OUTPUT TO {log_file}") + logger.info(version_str) + #logger.handlers[1].stream = sys.stdout + + return logger, log_file + + +from . import utils, plot, transforms + +# helper function to check for a path; if it doesn't exist, make it +def check_dir(path): + if not os.path.isdir(path): + os.mkdir(path) + + +def outlines_to_text(base, outlines): + with open(base + "_cp_outlines.txt", "w") as f: + for o in outlines: + xy = list(o.flatten()) + xy_str = ",".join(map(str, xy)) + f.write(xy_str) + f.write("\n") + + +def load_dax(filename): + ### modified from ZhuangLab github: + ### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156 + + inf_filename = os.path.splitext(filename)[0] + ".inf" + if not os.path.exists(inf_filename): + io_logger.critical( + f"ERROR: no inf file found for dax file {filename}, cannot load dax without it" + ) + return None + + ### get metadata + image_height, image_width = None, None + # extract the movie information from the associated inf file + size_re = re.compile(r"frame dimensions = ([\d]+) x ([\d]+)") + length_re = re.compile(r"number of frames = ([\d]+)") + endian_re = re.compile(r" (big|little) endian") + + with open(inf_filename, "r") as inf_file: + lines = inf_file.read().split("\n") + for line in lines: + m = size_re.match(line) + if m: + image_height = int(m.group(2)) + image_width = int(m.group(1)) + m = length_re.match(line) + if m: + number_frames = int(m.group(1)) + m = endian_re.search(line) + if m: + if m.group(1) == "big": + bigendian = 1 + else: + bigendian = 0 + # set defaults, warn the user that they couldn"t be determined from the inf file. + if not image_height: + io_logger.warning("could not determine dax image size, assuming 256x256") + image_height = 256 + image_width = 256 + + ### load image + img = np.memmap(filename, dtype="uint16", + shape=(number_frames, image_height, image_width)) + if bigendian: + img = img.byteswap() + img = np.array(img) + + return img + + +def imread(filename): + """ + Read in an image file with tif or image file type supported by cv2. + + Args: + filename (str): The path to the image file. + + Returns: + numpy.ndarray: The image data as a NumPy array. 
+ + Raises: + None + + Raises an error if the image file format is not supported. + + Examples: + >>> img = imread("image.tif") + """ + # ensure that extension check is not case sensitive + ext = os.path.splitext(filename)[-1].lower() + if ext == ".tif" or ext == ".tiff" or ext == ".flex": + with tifffile.TiffFile(filename) as tif: + ltif = len(tif.pages) + try: + full_shape = tif.shaped_metadata[0]["shape"] + except: + try: + page = tif.series[0][0] + full_shape = tif.series[0].shape + except: + ltif = 0 + if ltif < 10: + img = tif.asarray() + else: + page = tif.series[0][0] + shape, dtype = page.shape, page.dtype + ltif = int(np.prod(full_shape) / np.prod(shape)) + io_logger.info(f"reading tiff with {ltif} planes") + img = np.zeros((ltif, *shape), dtype=dtype) + for i, page in enumerate(tqdm(tif.series[0])): + img[i] = page.asarray() + img = img.reshape(full_shape) + return img + elif ext == ".dax": + img = load_dax(filename) + return img + elif ext == ".nd2": + if not ND2: + io_logger.critical("ERROR: need to 'pip install nd2' to load in .nd2 file") + return None + elif ext == ".nrrd": + if not NRRD: + io_logger.critical( + "ERROR: need to 'pip install pynrrd' to load in .nrrd file") + return None + else: + img, metadata = nrrd.read(filename) + if img.ndim == 3: + img = img.transpose(2, 0, 1) + return img + elif ext != ".npy": + try: + img = cv2.imread(filename, -1) #cv2.LOAD_IMAGE_ANYDEPTH) + if img.ndim > 2: + img = img[..., [2, 1, 0]] + return img + except Exception as e: + io_logger.critical("ERROR: could not read file, %s" % e) + return None + else: + try: + dat = np.load(filename, allow_pickle=True).item() + masks = dat["masks"] + return masks + except Exception as e: + io_logger.critical("ERROR: could not read masks from file, %s" % e) + return None + + +def remove_model(filename, delete=False): + """ remove model from .cellpose custom model list """ + filename = os.path.split(filename)[-1] + from . import models + model_strings = models.get_user_models() + if len(model_strings) > 0: + with open(models.MODEL_LIST_PATH, "w") as textfile: + for fname in model_strings: + textfile.write(fname + "\n") + else: + # write empty file + textfile = open(models.MODEL_LIST_PATH, "w") + textfile.close() + print(f"{filename} removed from custom model list") + if delete: + os.remove(os.fspath(models.MODEL_DIR.joinpath(fname))) + print("model deleted") + + +def add_model(filename): + """ add model to .cellpose models folder to use with GUI or CLI """ + from . import models + fname = os.path.split(filename)[-1] + try: + shutil.copyfile(filename, os.fspath(models.MODEL_DIR.joinpath(fname))) + except shutil.SameFileError: + pass + print(f"{filename} copied to models folder {os.fspath(models.MODEL_DIR)}") + if fname not in models.get_user_models(): + with open(models.MODEL_LIST_PATH, "a") as textfile: + textfile.write(fname + "\n") + + +def imsave(filename, arr): + """ + Saves an image array to a file. + + Args: + filename (str): The name of the file to save the image to. + arr (numpy.ndarray): The image array to be saved. + + Returns: + None + """ + ext = os.path.splitext(filename)[-1].lower() + if ext == ".tif" or ext == ".tiff": + tifffile.imwrite(filename, arr) + else: + if len(arr.shape) > 2: + arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB) + cv2.imwrite(filename, arr) + + +def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False): + """ + Finds all images in a folder and its subfolders (if specified) with the given file extensions. 
+ + Args: + folder (str): The path to the folder to search for images. + mask_filter (str): The filter for mask files. + imf (str, optional): The additional filter for image files. Defaults to None. + look_one_level_down (bool, optional): Whether to search for images in subfolders. Defaults to False. + + Returns: + list: A list of image file paths. + + Raises: + ValueError: If no files are found in the specified folder. + ValueError: If no images are found in the specified folder with the supported file extensions. + ValueError: If no images are found in the specified folder without the mask or flow file endings. + """ + mask_filters = ["_cp_output", "_flows", "_flows_0", "_flows_1", + "_flows_2", "_cellprob", "_masks", mask_filter] + image_names = [] + if imf is None: + imf = "" + + folders = [] + if look_one_level_down: + folders = natsorted(glob.glob(os.path.join(folder, "*/"))) + folders.append(folder) + exts = [".png", ".jpg", ".jpeg", ".tif", ".tiff", ".flex", ".dax", ".nd2", ".nrrd"] + l0 = 0 + al = 0 + for folder in folders: + all_files = glob.glob(folder + "/*") + al += len(all_files) + for ext in exts: + image_names.extend(glob.glob(folder + f"/*{imf}{ext}")) + image_names.extend(glob.glob(folder + f"/*{imf}{ext.upper()}")) + l0 += len(image_names) + + # return error if no files found + if al == 0: + raise ValueError("ERROR: no files in --dir folder ") + elif l0 == 0: + raise ValueError( + "ERROR: no images in --dir folder with extensions .png, .jpg, .jpeg, .tif, .tiff, .flex" + ) + + image_names = natsorted(image_names) + imn = [] + for im in image_names: + imfile = os.path.splitext(im)[0] + igood = all([(len(imfile) > len(mask_filter) and + imfile[-len(mask_filter):] != mask_filter) or + len(imfile) <= len(mask_filter) for mask_filter in mask_filters]) + if len(imf) > 0: + igood &= imfile[-len(imf):] == imf + if igood: + imn.append(im) + + image_names = imn + + # remove duplicates + image_names = [*set(image_names)] + image_names = natsorted(image_names) + + if len(image_names) == 0: + raise ValueError( + "ERROR: no images in --dir folder without _masks or _flows or _cellprob ending") + + return image_names + +def get_label_files(image_names, mask_filter, imf=None): + """ + Get the label files corresponding to the given image names and mask filter. + + Args: + image_names (list): List of image names. + mask_filter (str): Mask filter to be applied. + imf (str, optional): Image file extension. Defaults to None. + + Returns: + tuple: A tuple containing the label file names and flow file names (if present). 
+ """ + nimg = len(image_names) + label_names0 = [os.path.splitext(image_names[n])[0] for n in range(nimg)] + + if imf is not None and len(imf) > 0: + label_names = [label_names0[n][:-len(imf)] for n in range(nimg)] + else: + label_names = label_names0 + + # check for flows + if os.path.exists(label_names0[0] + "_flows.tif"): + flow_names = [label_names0[n] + "_flows.tif" for n in range(nimg)] + else: + flow_names = [label_names[n] + "_flows.tif" for n in range(nimg)] + if not all([os.path.exists(flow) for flow in flow_names]): + io_logger.info( + "not all flows are present, running flow generation for all images") + flow_names = None + + # check for masks + if mask_filter == "_seg.npy": + label_names = [label_names[n] + mask_filter for n in range(nimg)] + return label_names, None + + if os.path.exists(label_names[0] + mask_filter + ".tif"): + label_names = [label_names[n] + mask_filter + ".tif" for n in range(nimg)] + elif os.path.exists(label_names[0] + mask_filter + ".tiff"): + label_names = [label_names[n] + mask_filter + ".tiff" for n in range(nimg)] + elif os.path.exists(label_names[0] + mask_filter + ".png"): + label_names = [label_names[n] + mask_filter + ".png" for n in range(nimg)] + # todo, allow _seg.npy + #elif os.path.exists(label_names[0] + "_seg.npy"): + # io_logger.info("labels found as _seg.npy files, converting to tif") + else: + if not flow_names: + raise ValueError("labels not provided with correct --mask_filter") + else: + label_names = None + if not all([os.path.exists(label) for label in label_names]): + if not flow_names: + raise ValueError( + "labels not provided for all images in train and/or test set") + else: + label_names = None + + return label_names, flow_names + + +def load_images_labels(tdir, mask_filter="_masks", image_filter=None, + look_one_level_down=False): + """ + Loads images and corresponding labels from a directory. + + Args: + tdir (str): The directory path. + mask_filter (str, optional): The filter for mask files. Defaults to "_masks". + image_filter (str, optional): The filter for image files. Defaults to None. + look_one_level_down (bool, optional): Whether to look for files one level down. Defaults to False. + + Returns: + tuple: A tuple containing a list of images, a list of labels, and a list of image names. + """ + image_names = get_image_files(tdir, mask_filter, image_filter, look_one_level_down) + nimg = len(image_names) + + # training data + label_names, flow_names = get_label_files(image_names, mask_filter, + imf=image_filter) + + images = [] + labels = [] + k = 0 + for n in range(nimg): + if (os.path.isfile(label_names[n]) or + (flow_names is not None and os.path.isfile(flow_names[0]))): + image = imread(image_names[n]) + if label_names is not None: + label = imread(label_names[n]) + if flow_names is not None: + flow = imread(flow_names[n]) + if flow.shape[0] < 4: + label = np.concatenate((label[np.newaxis, :, :], flow), axis=0) + else: + label = flow + images.append(image) + labels.append(label) + k += 1 + io_logger.info(f"{k} / {nimg} images in {tdir} folder have labels") + return images, labels, image_names + +def load_train_test_data(train_dir, test_dir=None, image_filter=None, + mask_filter="_masks", look_one_level_down=False): + """ + Loads training and testing data for a Cellpose model. + + Args: + train_dir (str): The directory path containing the training data. + test_dir (str, optional): The directory path containing the testing data. Defaults to None. + image_filter (str, optional): The filter for selecting image files. 
Defaults to None. + mask_filter (str, optional): The filter for selecting mask files. Defaults to "_masks". + look_one_level_down (bool, optional): Whether to look for data in subdirectories of train_dir and test_dir. Defaults to False. + + Returns: + images, labels, image_names, test_images, test_labels, test_image_names + + """ + images, labels, image_names = load_images_labels(train_dir, mask_filter, + image_filter, look_one_level_down) + # testing data + test_images, test_labels, test_image_names = None, None, None + if test_dir is not None: + test_images, test_labels, test_image_names = load_images_labels( + test_dir, mask_filter, image_filter, look_one_level_down) + + return images, labels, image_names, test_images, test_labels, test_image_names + + +def masks_flows_to_seg(images, masks, flows, file_names, diams=30., channels=None, + imgs_restore=None, restore_type=None, ratio=1.): + """Save output of model eval to be loaded in GUI. + + Can be list output (run on multiple images) or single output (run on single image). + + Saved to file_names[k]+"_seg.npy". + + Args: + images (list): Images input into cellpose. + masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels. + flows (list): Flows output from Cellpose.eval. + file_names (list, str): Names of files of images. + diams (float array): Diameters used to run Cellpose. Defaults to 30. + channels (list, int, optional): Channels used to run Cellpose. Defaults to None. + + Returns: + None + """ + + if channels is None: + channels = [0, 0] + + if isinstance(masks, list): + if not isinstance(diams, (list, np.ndarray)): + diams = diams * np.ones(len(masks), np.float32) + if imgs_restore is None: + imgs_restore = [None] * len(masks) + if isinstance(file_names, str): + file_names = [file_names] * len(masks) + for k, [image, mask, flow, diam, file_name, img_restore + ] in enumerate(zip(images, masks, flows, diams, file_names, + imgs_restore)): + channels_img = channels + if channels_img is not None and len(channels) > 2: + channels_img = channels[k] + masks_flows_to_seg(image, mask, flow, file_name, diams=diam, + channels=channels_img, imgs_restore=img_restore, + restore_type=restore_type, ratio=ratio) + return + + if len(channels) == 1: + channels = channels[0] + + flowi = [] + if flows[0].ndim == 3: + Ly, Lx = masks.shape[-2:] + flowi.append( + cv2.resize(flows[0], (Lx, Ly), interpolation=cv2.INTER_NEAREST)[np.newaxis, + ...]) + else: + flowi.append(flows[0]) + + if flows[0].ndim == 3: + cellprob = (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype( + np.uint8) + cellprob = cv2.resize(cellprob, (Lx, Ly), interpolation=cv2.INTER_NEAREST) + flowi.append(cellprob[np.newaxis, ...]) + flowi.append(np.zeros(flows[0].shape, dtype=np.uint8)) + flowi[-1] = flowi[-1][np.newaxis, ...] 
+ else: + flowi.append( + (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(np.uint8)) + flowi.append((flows[1][0] / 10 * 127 + 127).astype(np.uint8)) + if len(flows) > 2: + if len(flows) > 3: + flowi.append(flows[3]) + else: + flowi.append([]) + flowi.append(np.concatenate((flows[1], flows[2][np.newaxis, ...]), axis=0)) + outlines = masks * utils.masks_to_outlines(masks) + base = os.path.splitext(file_names)[0] + + dat = { + "outlines": + outlines.astype(np.uint16) if outlines.max() < 2**16 - + 1 else outlines.astype(np.uint32), + "masks": + masks.astype(np.uint16) if outlines.max() < 2**16 - + 1 else masks.astype(np.uint32), + "chan_choose": + channels, + "ismanual": + np.zeros(masks.max(), bool), + "filename": + file_names, + "flows": + flowi, + "diameter": + diams + } + if restore_type is not None and imgs_restore is not None: + dat["restore"] = restore_type + dat["ratio"] = ratio + dat["img_restore"] = imgs_restore + + np.save(base + "_seg.npy", dat) + +def save_to_png(images, masks, flows, file_names): + """ deprecated (runs io.save_masks with png=True) + + does not work for 3D images + + """ + save_masks(images, masks, flows, file_names, png=True) + + +def save_rois(masks, file_name, multiprocessing=None): + """ save masks to .roi files in .zip archive for ImageJ/Fiji + + Args: + masks (np.ndarray): masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels + file_name (str): name to save the .zip file to + + Returns: + None + """ + outlines = utils.outlines_list(masks, multiprocessing=multiprocessing) + nonempty_outlines = [outline for outline in outlines if len(outline)!=0] + if len(outlines)!=len(nonempty_outlines): + print(f"empty outlines found, saving {len(nonempty_outlines)} ImageJ ROIs to .zip archive.") + rois = [ImagejRoi.frompoints(outline) for outline in nonempty_outlines] + file_name = os.path.splitext(file_name)[0] + '_rois.zip' + + + # Delete file if it exists; the roifile lib appends to existing zip files. + # If the user removed a mask it will still be in the zip file + if os.path.exists(file_name): + os.remove(file_name) + + roiwrite(file_name, rois) + + +def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[0, 0], + suffix="_cp_masks", save_flows=False, save_outlines=False, dir_above=False, + in_folders=False, savedir=None, save_txt=False, save_mpl=False): + """ Save masks + nicely plotted segmentation image to png and/or tiff. + + Can save masks, flows to different directories, if in_folders is True. + + If png, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.png". + + If tif, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.tif". + + If png and matplotlib installed, full segmentation figure is saved to file_names[k]+"_cp.png". + + Only tif option works for 3D data, and only tif option works for empty masks. + + Args: + images (list): Images input into cellpose. + masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels. + flows (list): Flows output from Cellpose.eval. + file_names (list, str): Names of files of images. + png (bool, optional): Save masks to PNG. Defaults to True. + tif (bool, optional): Save masks to TIF. Defaults to False. + channels (list, int, optional): Channels used to run Cellpose. Defaults to [0,0]. + suffix (str, optional): Add name to saved masks. Defaults to "_cp_masks". + save_flows (bool, optional): Save flows output from Cellpose.eval. Defaults to False. + save_outlines (bool, optional): Save outlines of masks. Defaults to False. 
+ dir_above (bool, optional): Save masks/flows in directory above. Defaults to False. + in_folders (bool, optional): Save masks/flows in separate folders. Defaults to False. + savedir (str, optional): Absolute path where images will be saved. If None, saves to image directory. Defaults to None. + save_txt (bool, optional): Save masks as list of outlines for ImageJ. Defaults to False. + save_mpl (bool, optional): If True, saves a matplotlib figure of the original image/segmentation/flows. Does not work for 3D. + This takes a long time for large images. Defaults to False. + + Returns: + None + """ + + if isinstance(masks, list): + for image, mask, flow, file_name in zip(images, masks, flows, file_names): + save_masks(image, mask, flow, file_name, png=png, tif=tif, suffix=suffix, + dir_above=dir_above, save_flows=save_flows, + save_outlines=save_outlines, savedir=savedir, save_txt=save_txt, + in_folders=in_folders, save_mpl=save_mpl) + return + + if masks.ndim > 2 and not tif: + raise ValueError("cannot save 3D outputs as PNG, use tif option instead") + + if masks.max() == 0: + io_logger.warning("no masks found, will not save PNG or outlines") + if not tif: + return + else: + png = False + save_outlines = False + save_flows = False + save_txt = False + + if savedir is None: + if dir_above: + savedir = Path(file_names).parent.parent.absolute( + ) #go up a level to save in its own folder + else: + savedir = Path(file_names).parent.absolute() + + check_dir(savedir) + + basename = os.path.splitext(os.path.basename(file_names))[0] + if in_folders: + maskdir = os.path.join(savedir, "masks") + outlinedir = os.path.join(savedir, "outlines") + txtdir = os.path.join(savedir, "txt_outlines") + flowdir = os.path.join(savedir, "flows") + else: + maskdir = savedir + outlinedir = savedir + txtdir = savedir + flowdir = savedir + + check_dir(maskdir) + + exts = [] + if masks.ndim > 2: + png = False + tif = True + if png: + if masks.max() < 2**16: + masks = masks.astype(np.uint16) + exts.append(".png") + else: + png = False + tif = True + io_logger.warning( + "found more than 65535 masks in each image, cannot save PNG, saving as TIF" + ) + if tif: + exts.append(".tif") + + # save masks + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for ext in exts: + imsave(os.path.join(maskdir, basename + suffix + ext), masks) + + if save_mpl and png and MATPLOTLIB and not min(images.shape) > 3: + # Make and save original/segmentation/flows image + + img = images.copy() + if img.ndim < 3: + img = img[:, :, np.newaxis] + elif img.shape[0] < 8: + np.transpose(img, (1, 2, 0)) + + fig = plt.figure(figsize=(12, 3)) + plot.show_segmentation(fig, img, masks, flows[0]) + fig.savefig(os.path.join(savedir, basename + "_cp_output" + suffix + ".png"), + dpi=300) + plt.close(fig) + + # ImageJ txt outline files + if masks.ndim < 3 and save_txt: + check_dir(txtdir) + outlines = utils.outlines_list(masks) + outlines_to_text(os.path.join(txtdir, basename), outlines) + + # RGB outline images + if masks.ndim < 3 and save_outlines: + check_dir(outlinedir) + outlines = utils.masks_to_outlines(masks) + outX, outY = np.nonzero(outlines) + img0 = transforms.normalize99(images) + if img0.shape[0] < 4: + img0 = np.transpose(img0, (1, 2, 0)) + if img0.shape[-1] < 3 or img0.ndim < 3: + img0 = plot.image_to_rgb(img0, channels=channels) + else: + if img0.max() <= 50.0: + img0 = np.uint8(np.clip(img0 * 255, 0, 1)) + imgout = img0.copy() + imgout[outX, outY] = np.array([255, 0, 0]) #pure red + imsave(os.path.join(outlinedir, basename 
+ "_outlines" + suffix + ".png"), + imgout) + + # save RGB flow picture + if masks.ndim < 3 and save_flows: + check_dir(flowdir) + imsave(os.path.join(flowdir, basename + "_flows" + suffix + ".tif"), + (flows[0] * (2**16 - 1)).astype(np.uint16)) + #save full flow data + imsave(os.path.join(flowdir, basename + '_dP' + suffix + '.tif'), flows[1]) diff --git a/cellpose/logo/cellpose.ico b/cellpose/logo/cellpose.ico new file mode 100644 index 0000000000000000000000000000000000000000..01344e7d286956763ba64ebf5f5032e70657f402 Binary files /dev/null and b/cellpose/logo/cellpose.ico differ diff --git a/cellpose/logo/logo.png b/cellpose/logo/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..3f4a324287ef48d9f6ef8b8f9d942ef37cca64b4 --- /dev/null +++ b/cellpose/logo/logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11fd5caf7ce1f3c0f9f4b283ac5f34981e604e8b68a7ede70457dbe8611aa328 +size 29050 diff --git a/cellpose/metrics.py b/cellpose/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..46e656a16a936a7ea2774dde31f1c784a222b7d4 --- /dev/null +++ b/cellpose/metrics.py @@ -0,0 +1,264 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" +import numpy as np +from . import utils, dynamics +from numba import jit +from scipy.optimize import linear_sum_assignment +from scipy.ndimage import convolve, mean + + +def mask_ious(masks_true, masks_pred): + """Return best-matched masks.""" + iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:] + n_min = min(iou.shape[0], iou.shape[1]) + costs = -(iou >= 0.5).astype(float) - iou / (2 * n_min) + true_ind, pred_ind = linear_sum_assignment(costs) + iout = np.zeros(masks_true.max()) + iout[true_ind] = iou[true_ind, pred_ind] + preds = np.zeros(masks_true.max(), "int") + preds[true_ind] = pred_ind + 1 + return iout, preds + + +def boundary_scores(masks_true, masks_pred, scales): + """ + Calculate boundary precision, recall, and F-score. + + Args: + masks_true (list): List of true masks. + masks_pred (list): List of predicted masks. + scales (list): List of scales. + + Returns: + tuple: A tuple containing precision, recall, and F-score arrays. + """ + diams = [utils.diameters(lbl)[0] for lbl in masks_true] + precision = np.zeros((len(scales), len(masks_true))) + recall = np.zeros((len(scales), len(masks_true))) + fscore = np.zeros((len(scales), len(masks_true))) + for j, scale in enumerate(scales): + for n in range(len(masks_true)): + diam = max(1, scale * diams[n]) + rs, ys, xs = utils.circleMask([int(np.ceil(diam)), int(np.ceil(diam))]) + filt = (rs <= diam).astype(np.float32) + otrue = utils.masks_to_outlines(masks_true[n]) + otrue = convolve(otrue, filt) + opred = utils.masks_to_outlines(masks_pred[n]) + opred = convolve(opred, filt) + tp = np.logical_and(otrue == 1, opred == 1).sum() + fp = np.logical_and(otrue == 0, opred == 1).sum() + fn = np.logical_and(otrue == 1, opred == 0).sum() + precision[j, n] = tp / (tp + fp) + recall[j, n] = tp / (tp + fn) + fscore[j] = 2 * precision[j] * recall[j] / (precision[j] + recall[j]) + return precision, recall, fscore + + +def aggregated_jaccard_index(masks_true, masks_pred): + """ + AJI = intersection of all matched masks / union of all masks + + Args: + masks_true (list of np.ndarrays (int) or np.ndarray (int)): + where 0=NO masks; 1,2... are mask labels + masks_pred (list of np.ndarrays (int) or np.ndarray (int)): + np.ndarray (int) where 0=NO masks; 1,2... 
are mask labels + + Returns: + aji (float): aggregated jaccard index for each set of masks + """ + aji = np.zeros(len(masks_true)) + for n in range(len(masks_true)): + iout, preds = mask_ious(masks_true[n], masks_pred[n]) + inds = np.arange(0, masks_true[n].max(), 1, int) + overlap = _label_overlap(masks_true[n], masks_pred[n]) + union = np.logical_or(masks_true[n] > 0, masks_pred[n] > 0).sum() + overlap = overlap[inds[preds > 0] + 1, preds[preds > 0].astype(int)] + aji[n] = overlap.sum() / union + return aji + + +def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]): + """ + Average precision estimation: AP = TP / (TP + FP + FN) + + This function is based heavily on the *fast* stardist matching functions + (https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py) + + Args: + masks_true (list of np.ndarrays (int) or np.ndarray (int)): + where 0=NO masks; 1,2... are mask labels + masks_pred (list of np.ndarrays (int) or np.ndarray (int)): + np.ndarray (int) where 0=NO masks; 1,2... are mask labels + + Returns: + ap (array [len(masks_true) x len(threshold)]): + average precision at thresholds + tp (array [len(masks_true) x len(threshold)]): + number of true positives at thresholds + fp (array [len(masks_true) x len(threshold)]): + number of false positives at thresholds + fn (array [len(masks_true) x len(threshold)]): + number of false negatives at thresholds + """ + not_list = False + if not isinstance(masks_true, list): + masks_true = [masks_true] + masks_pred = [masks_pred] + not_list = True + if not isinstance(threshold, list) and not isinstance(threshold, np.ndarray): + threshold = [threshold] + + if len(masks_true) != len(masks_pred): + raise ValueError( + "metrics.average_precision requires len(masks_true)==len(masks_pred)") + + ap = np.zeros((len(masks_true), len(threshold)), np.float32) + tp = np.zeros((len(masks_true), len(threshold)), np.float32) + fp = np.zeros((len(masks_true), len(threshold)), np.float32) + fn = np.zeros((len(masks_true), len(threshold)), np.float32) + n_true = np.array(list(map(np.max, masks_true))) + n_pred = np.array(list(map(np.max, masks_pred))) + + for n in range(len(masks_true)): + #_,mt = np.reshape(np.unique(masks_true[n], return_index=True), masks_pred[n].shape) + if n_pred[n] > 0: + iou = _intersection_over_union(masks_true[n], masks_pred[n])[1:, 1:] + for k, th in enumerate(threshold): + tp[n, k] = _true_positive(iou, th) + fp[n] = n_pred[n] - tp[n] + fn[n] = n_true[n] - tp[n] + ap[n] = tp[n] / (tp[n] + fp[n] + fn[n]) + + if not_list: + ap, tp, fp, fn = ap[0], tp[0], fp[0], fn[0] + return ap, tp, fp, fn + + +@jit(nopython=True) +def _label_overlap(x, y): + """Fast function to get pixel overlaps between masks in x and y. + + Args: + x (np.ndarray, int): Where 0=NO masks; 1,2... are mask labels. + y (np.ndarray, int): Where 0=NO masks; 1,2... are mask labels. + + Returns: + overlap (np.ndarray, int): Matrix of pixel overlaps of size [x.max()+1, y.max()+1]. + """ + # put label arrays into standard form then flatten them + # x = (utils.format_labels(x)).ravel() + # y = (utils.format_labels(y)).ravel() + x = x.ravel() + y = y.ravel() + + # preallocate a "contact map" matrix + overlap = np.zeros((1 + x.max(), 1 + y.max()), dtype=np.uint) + + # loop over the labels in x and add to the corresponding + # overlap entry. 
If label A in x and label B in y share P + # pixels, then the resulting overlap is P + # len(x)=len(y), the number of pixels in the whole image + for i in range(len(x)): + overlap[x[i], y[i]] += 1 + return overlap + + +def _intersection_over_union(masks_true, masks_pred): + """Calculate the intersection over union of all mask pairs. + + Parameters: + masks_true (np.ndarray, int): Ground truth masks, where 0=NO masks; 1,2... are mask labels. + masks_pred (np.ndarray, int): Predicted masks, where 0=NO masks; 1,2... are mask labels. + + Returns: + iou (np.ndarray, float): Matrix of IOU pairs of size [x.max()+1, y.max()+1]. + + How it works: + The overlap matrix is a lookup table of the area of intersection + between each set of labels (true and predicted). The true labels + are taken to be along axis 0, and the predicted labels are taken + to be along axis 1. The sum of the overlaps along axis 0 is thus + an array giving the total overlap of the true labels with each of + the predicted labels, and likewise the sum over axis 1 is the + total overlap of the predicted labels with each of the true labels. + Because the label 0 (background) is included, this sum is guaranteed + to reconstruct the total area of each label. Adding this row and + column vectors gives a 2D array with the areas of every label pair + added together. This is equivalent to the union of the label areas + except for the duplicated overlap area, so the overlap matrix is + subtracted to find the union matrix. + """ + overlap = _label_overlap(masks_true, masks_pred) + n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) + n_pixels_true = np.sum(overlap, axis=1, keepdims=True) + iou = overlap / (n_pixels_pred + n_pixels_true - overlap) + iou[np.isnan(iou)] = 0.0 + return iou + + +def _true_positive(iou, th): + """Calculate the true positive at threshold th. + + Args: + iou (float, np.ndarray): Array of IOU pairs. + th (float): Threshold on IOU for positive label. + + Returns: + tp (float): Number of true positives at threshold. + + How it works: + (1) Find minimum number of masks. + (2) Define cost matrix; for a given threshold, each element is negative + the higher the IoU is (perfect IoU is 1, worst is 0). The second term + gets more negative with higher IoU, but less negative with greater + n_min (but that's a constant...). + (3) Solve the linear sum assignment problem. The costs array defines the cost + of matching a true label with a predicted label, so the problem is to + find the set of pairings that minimizes this cost. The scipy.optimize + function gives the ordered lists of corresponding true and predicted labels. + (4) Extract the IoUs from these pairings and then threshold to get a boolean array + whose sum is the number of true positives that is returned. + """ + n_min = min(iou.shape[0], iou.shape[1]) + costs = -(iou >= th).astype(float) - iou / (2 * n_min) + true_ind, pred_ind = linear_sum_assignment(costs) + match_ok = iou[true_ind, pred_ind] >= th + tp = match_ok.sum() + return tp + + +def flow_error(maski, dP_net, device=None): + """Error in flows from predicted masks vs flows predicted by network run on image. + + This function serves to benchmark the quality of masks. It works as follows: + 1. The predicted masks are used to create a flow diagram. + 2. The mask-flows are compared to the flows that the network predicted. + + If there is a discrepancy between the flows, it suggests that the mask is incorrect. + Masks with flow_errors greater than 0.4 are discarded by default. 
This setting can be + changed in Cellpose.eval or CellposeModel.eval. + + Args: + maski (np.ndarray, int): Masks produced from running dynamics on dP_net, where 0=NO masks; 1,2... are mask labels. + dP_net (np.ndarray, float): ND flows where dP_net.shape[1:] = maski.shape. + + Returns: + A tuple containing (flow_errors, dP_masks): flow_errors (np.ndarray, float): Mean squared error between predicted flows and flows from masks; + dP_masks (np.ndarray, float): ND flows produced from the predicted masks. + """ + if dP_net.shape[1:] != maski.shape: + print("ERROR: net flow is not same size as predicted masks") + return + + # flows predicted from estimated masks + dP_masks = dynamics.masks_to_flows(maski, device=device) + # difference between predicted flows vs mask flows + flow_errors = np.zeros(maski.max()) + for i in range(dP_masks.shape[0]): + flow_errors += mean((dP_masks[i] - dP_net[i] / 5.)**2, maski, + index=np.arange(1, + maski.max() + 1)) + + return flow_errors, dP_masks diff --git a/cellpose/models.py b/cellpose/models.py new file mode 100644 index 0000000000000000000000000000000000000000..40447ebe15b0f8ae4b9ac64825a134ee04774f04 --- /dev/null +++ b/cellpose/models.py @@ -0,0 +1,797 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import os, sys, time, shutil, tempfile, datetime, pathlib, subprocess +from pathlib import Path +import numpy as np +from tqdm import trange, tqdm +from urllib.parse import urlparse +import torch +from scipy.ndimage import gaussian_filter +#import cv2 +import gc + +import logging + +models_logger = logging.getLogger(__name__) + +from . import transforms, dynamics, utils, plot +from .resnet_torch import CPnet +from .core import assign_device, check_mkl, run_net, run_3D + +_MODEL_URL = "https://www.cellpose.org/models" +_MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH") +_MODEL_DIR_DEFAULT = pathlib.Path.home().joinpath(".cellpose", "models") +MODEL_DIR = pathlib.Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT + +MODEL_NAMES = [ + "cyto3", "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3", + "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "cyto2", "cyto", "CPx", + "transformer_cp3", "neurips_cellpose_default", "neurips_cellpose_transformer", + "neurips_grayscale_cyto2", + "CP", "CPx", "TN1", "TN2", "TN3", "LC1", "LC2", "LC3", "LC4" +] + +MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt")) + +normalize_default = { + "lowhigh": None, + "percentile": None, + "normalize": True, + "norm3D": True, + "sharpen_radius": 0, + "smooth_radius": 0, + "tile_norm_blocksize": 0, + "tile_norm_smooth3D": 1, + "invert": False +} + + +def model_path(model_type, model_index=0): + torch_str = "torch" + if model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei": + basename = "%s%s_%d" % (model_type, torch_str, model_index) + else: + basename = model_type + return cache_model_path(basename) + + +def size_model_path(model_type): + torch_str = "torch" + if (model_type == "cyto" or model_type == "nuclei" or + model_type == "cyto2" or model_type == "cyto3"): + if model_type == "cyto3": + basename = "size_%s.npy" % model_type + else: + basename = "size_%s%s_0.npy" % (model_type, torch_str) + return cache_model_path(basename) + else: + if os.path.exists(model_type) and os.path.exists(model_type + "_size.npy"): + return model_type + "_size.npy" + else: + raise FileNotFoundError(f"size model not found ({model_type + 
'_size.npy'})") + + +def cache_model_path(basename): + MODEL_DIR.mkdir(parents=True, exist_ok=True) + url = f"{_MODEL_URL}/{basename}" + cached_file = os.fspath(MODEL_DIR.joinpath(basename)) + if not os.path.exists(cached_file): + models_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + utils.download_url_to_file(url, cached_file, progress=True) + return cached_file + + +def get_user_models(): + model_strings = [] + if os.path.exists(MODEL_LIST_PATH): + with open(MODEL_LIST_PATH, "r") as textfile: + lines = [line.rstrip() for line in textfile] + if len(lines) > 0: + model_strings.extend(lines) + return model_strings + + +class Cellpose(): + """Main model which combines SizeModel and CellposeModel. + + Args: + gpu (bool, optional): Whether or not to use GPU, will check if GPU available. Defaults to False. + model_type (str, optional): Model type. "cyto"=cytoplasm model; "nuclei"=nucleus model; + "cyto2"=cytoplasm model with additional user images; + "cyto3"=super-generalist model; Defaults to "cyto3". + device (torch device, optional): Device used for model running / training. Overrides gpu input. Recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")). Defaults to None. + + Attributes: + device (torch device): Device used for model running / training. + gpu (bool): Flag indicating if GPU is used. + diam_mean (float): Mean diameter for cytoplasm model. + cp (CellposeModel): CellposeModel instance. + pretrained_size (str): Pretrained size model path. + sz (SizeModel): SizeModel instance. + + """ + + def __init__(self, gpu=False, model_type="cyto3", nchan=2, device=None, + backbone="default"): + super(Cellpose, self).__init__() + + # assign device (GPU or CPU) + sdevice, gpu = assign_device(use_torch=True, gpu=gpu) + self.device = device if device is not None else sdevice + self.gpu = gpu + self.backbone = backbone + + model_type = "cyto3" if model_type is None else model_type + + self.diam_mean = 30. #default for any cyto model + nuclear = "nuclei" in model_type + if nuclear: + self.diam_mean = 17. + + if model_type in ["cyto", "nuclei", "cyto2", "cyto3"] and nchan != 2: + nchan = 2 + models_logger.warning( + f"cannot set nchan to other value for {model_type} model") + self.nchan = nchan + + self.cp = CellposeModel(device=self.device, gpu=self.gpu, model_type=model_type, + diam_mean=self.diam_mean, nchan=self.nchan, + backbone=self.backbone) + self.cp.model_type = model_type + + # size model not used for bacterial model + self.pretrained_size = size_model_path(model_type) + self.sz = SizeModel(device=self.device, pretrained_size=self.pretrained_size, + cp_model=self.cp) + self.sz.model_type = model_type + + def eval(self, x, batch_size=8, channels=[0, 0], channel_axis=None, invert=False, + normalize=True, diameter=30., do_3D=False, **kwargs): + """Run cellpose size model and mask model and get masks. + + Args: + x (list or array): List or array of images. Can be list of 2D/3D images, or array of 2D/3D images, or 4D image array. + batch_size (int, optional): Number of 224x224 patches to run simultaneously on the GPU. Can make smaller or bigger depending on GPU memory usage. Defaults to 8. + channels (list, optional): List of channels, either of length 2 or of length number of images by 2. First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). For instance, to segment grayscale images, input [0,0]. 
To segment images with cells in green and nuclei in blue, input [2,3]. To segment one grayscale image and one image with cells in green and nuclei in blue, input [[0,0], [2,3]]. Defaults to [0,0]. + channel_axis (int, optional): If None, channels dimension is attempted to be automatically determined. Defaults to None. + invert (bool, optional): Invert image pixel intensity before running network (if True, image is also normalized). Defaults to False. + normalize (bool, optional): If True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel; can also pass dictionary of parameters (see CellposeModel for details). Defaults to True. + diameter (float, optional): If set to None, then diameter is automatically estimated if size model is loaded. Defaults to 30.. + do_3D (bool, optional): Set to True to run 3D segmentation on 4D image input. Defaults to False. + + Returns: + A tuple containing (masks, flows, styles, diams): masks (list of 2D arrays or single 3D array): Labelled image, where 0=no masks; 1,2,...=mask labels; + flows (list of lists 2D arrays or list of 3D arrays): flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY flows at each pixel; + flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); + flows[k][3] = final pixel locations after Euler integration; + styles (list of 1D arrays of length 256 or single 1D array): Style vector summarizing each image, also used to estimate size of objects in image; + diams (list of diameters or float): List of diameters or float (if do_3D=True). + + """ + + tic0 = time.time() + models_logger.info(f"channels set to {channels}") + + diam0 = diameter[0] if isinstance(diameter, (np.ndarray, list)) else diameter + estimate_size = True if (diameter is None or diam0 == 0) else False + + if estimate_size and self.pretrained_size is not None and not do_3D and x[ + 0].ndim < 4: + tic = time.time() + models_logger.info("~~~ ESTIMATING CELL DIAMETER(S) ~~~") + diams, _ = self.sz.eval(x, channels=channels, channel_axis=channel_axis, + batch_size=batch_size, normalize=normalize, + invert=invert) + diameter = None + models_logger.info("estimated cell diameter(s) in %0.2f sec" % + (time.time() - tic)) + models_logger.info(">>> diameter(s) = ") + if isinstance(diams, list) or isinstance(diams, np.ndarray): + diam_string = "[" + "".join(["%0.2f, " % d for d in diams]) + "]" + else: + diam_string = "[ %0.2f ]" % diams + models_logger.info(diam_string) + elif estimate_size: + if self.pretrained_size is None: + reason = "no pretrained size model specified in model Cellpose" + else: + reason = "does not work on non-2D images" + models_logger.warning(f"could not estimate diameter, {reason}") + diams = self.diam_mean + else: + diams = diameter + + models_logger.info("~~~ FINDING MASKS ~~~") + masks, flows, styles = self.cp.eval(x, channels=channels, + channel_axis=channel_axis, + batch_size=batch_size, normalize=normalize, + invert=invert, diameter=diams, do_3D=do_3D, + **kwargs) + models_logger.info(">>>> TOTAL TIME %0.2f sec" % (time.time() - tic0)) + + return masks, flows, styles, diams + +def get_model_params(pretrained_model, model_type, pretrained_model_ortho, default_model="cyto3"): + """ return pretrained_model path, diam_mean and if model is builtin """ + builtin = False + use_default = False + diam_mean = None + model_strings = get_user_models() + all_models = MODEL_NAMES.copy() + all_models.extend(model_strings) + + # check if pretrained_model is builtin or custom user model saved in 
.cellpose/models + # if yes, then set to model_type + if (pretrained_model and not Path(pretrained_model).exists() and + np.any([pretrained_model == s for s in all_models])): + model_type = pretrained_model + + # check if model_type is builtin or custom user model saved in .cellpose/models + if model_type is not None and np.any([model_type == s for s in all_models]): + if np.any([model_type == s for s in MODEL_NAMES]): + builtin = True + models_logger.info(f">> {model_type} << model set to be used") + if model_type == "nuclei": + diam_mean = 17. + pretrained_model = model_path(model_type) + # if model_type is not None and does not exist, use default model + elif model_type is not None: + if Path(model_type).exists(): + pretrained_model = model_type + else: + models_logger.warning("model_type does not exist, using default model") + use_default = True + # if model_type is None... + else: + # if pretrained_model does not exist, use default model + if pretrained_model and not Path(pretrained_model).exists(): + models_logger.warning( + "pretrained_model path does not exist, using default model") + use_default = True + elif pretrained_model: + if pretrained_model[-13:] == "nucleitorch_0": + builtin = True + diam_mean = 17. + + if pretrained_model_ortho: + if pretrained_model_ortho in all_models: + pretrained_model_ortho = model_path(pretrained_model_ortho) + elif Path(pretrained_model_ortho).exists(): + pass + else: + pretrained_model_ortho = None + + pretrained_model = model_path(default_model) if use_default else pretrained_model + builtin = True if use_default else builtin + return pretrained_model, diam_mean, builtin, pretrained_model_ortho + + +class CellposeModel(): + """ + Class representing a Cellpose model. + + Attributes: + diam_mean (float): Mean "diameter" value for the model. + builtin (bool): Whether the model is a built-in model or not. + device (torch device): Device used for model running / training. + mkldnn (None or bool): MKLDNN flag for the model. + nchan (int): Number of channels used as input to the network. + nclasses (int): Number of classes in the model. + nbase (list): List of base values for the model. + net (CPnet): Cellpose network. + pretrained_model (str): Path to pretrained cellpose model. + pretrained_model_ortho (str): Path or model_name for pretrained cellpose model for ortho views in 3D. + backbone (str): Type of network ("default" is the standard res-unet, "transformer" for the segformer). + + Methods: + __init__(self, gpu=False, pretrained_model=False, model_type=None, diam_mean=30., device=None, nchan=2): + Initialize the CellposeModel. + + eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, z_axis=None, normalize=True, invert=False, rescale=None, diameter=None, flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0, min_size=15, niter=None, augment=False, tile_overlap=0.1, bsize=224, interp=True, compute_masks=True, progress=None): + Segment list of images x, or 4D array - Z x nchan x Y x X. + + """ + + def __init__(self, gpu=False, pretrained_model=False, model_type=None, + mkldnn=True, diam_mean=30., device=None, nchan=2, + pretrained_model_ortho=None, backbone="default"): + """ + Initialize the CellposeModel. + + Parameters: + gpu (bool, optional): Whether or not to save model to GPU, will check if GPU available. + pretrained_model (str or list of strings, optional): Full path to pretrained cellpose model(s), if None or False, no model loaded. 
+ model_type (str, optional): Any model that is available in the GUI, use name in GUI e.g. "livecell" (can be user-trained or model zoo). + mkldnn (bool, optional): Use MKLDNN for CPU inference, faster but not always supported. + diam_mean (float, optional): Mean "diameter", 30. is built-in value for "cyto" model; 17. is built-in value for "nuclei" model; if saved in custom model file (cellpose>=2.0) then it will be loaded automatically and overwrite this value. + device (torch device, optional): Device used for model running / training (torch.device("cuda") or torch.device("cpu")), overrides gpu input, recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")). + nchan (int, optional): Number of channels to use as input to network, default is 2 (cyto + nuclei) or (nuclei + zeros). + """ + self.diam_mean = diam_mean + + ### set model path + default_model = "cyto3" if backbone == "default" else "transformer_cp3" + pretrained_model, diam_mean, builtin, pretrained_model_ortho = get_model_params( + pretrained_model, + model_type, + pretrained_model_ortho, + default_model) + self.diam_mean = diam_mean if diam_mean is not None else self.diam_mean + + ### assign model device + self.mkldnn = None + self.device = assign_device(gpu=gpu)[0] if device is None else device + if torch.cuda.is_available(): + device_gpu = self.device.type == "cuda" + elif torch.backends.mps.is_available(): + device_gpu = self.device.type == "mps" + else: + device_gpu = False + self.gpu = device_gpu + if not self.gpu: + self.mkldnn = check_mkl(True) if mkldnn else False + + ### create neural network + self.nchan = nchan + self.nclasses = 3 + nbase = [32, 64, 128, 256] + self.nbase = [nchan, *nbase] + self.pretrained_model = pretrained_model + if backbone == "default": + self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn, + max_pool=True, diam_mean=self.diam_mean).to(self.device) + else: + from .segformer import Transformer + self.net = Transformer( + encoder_weights="imagenet" if not self.pretrained_model else None, + diam_mean=self.diam_mean).to(self.device) + + ### load model weights + if self.pretrained_model: + models_logger.info(f">>>> loading model {pretrained_model}") + self.net.load_model(self.pretrained_model, device=self.device) + if not builtin: + self.diam_mean = self.net.diam_mean.data.cpu().numpy()[0] + self.diam_labels = self.net.diam_labels.data.cpu().numpy()[0] + models_logger.info( + f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)" + ) + if not builtin: + models_logger.info( + f">>>> model diam_labels = {self.diam_labels: .3f} (mean diameter of training ROIs)" + ) + if pretrained_model_ortho is not None: + models_logger.info(f">>>> loading ortho model {pretrained_model_ortho}") + self.net_ortho = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn, + max_pool=True, diam_mean=self.diam_mean).to(self.device) + self.net_ortho.load_model(pretrained_model_ortho, device=self.device) + else: + self.net_ortho = None + else: + models_logger.info(f">>>> no model weights loaded") + self.diam_labels = self.diam_mean + + self.net_type = f"cellpose_{backbone}" + + def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, + z_axis=None, normalize=True, invert=False, rescale=None, diameter=None, + flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, + flow3D_smooth=0, stitch_threshold=0.0, + min_size=15, max_size_fraction=0.4, niter=None, + augment=False, tile_overlap=0.1, bsize=224, + 
interp=True, compute_masks=True, progress=None): + """ segment list of images x, or 4D array - Z x nchan x Y x X + + Args: + x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images + batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU + (can make smaller or bigger depending on GPU memory usage). Defaults to 8. + resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True. + channels (list, optional): list of channels, either of length 2 or of length number of images by 2. + First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). + Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). + For instance, to segment grayscale images, input [0,0]. To segment images with cells + in green and nuclei in blue, input [2,3]. To segment one grayscale image and one + image with cells in green and nuclei in blue, input [[0,0], [2,3]]. + Defaults to None. + channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x. + if None, channels dimension is attempted to be automatically determined. Defaults to None. + z_axis (int, optional): z axis in element of list x, or of np.ndarray x. + if None, z dimension is attempted to be automatically determined. Defaults to None. + normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel; + can also pass dictionary of parameters (all keys are optional, default values shown): + - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored) + - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels + - "normalize"=True ; run normalization (if False, all following parameters ignored) + - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high] + - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100) + - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode. + Defaults to True. + invert (bool, optional): invert image pixel intensity before running network. Defaults to False. + rescale (float, optional): resize factor for each image, if None, set to 1.0; + (only used if diameter is None). Defaults to None. + diameter (float, optional): diameter for each image, + if diameter is None, set to diam_mean or diam_train if available. Defaults to None. + flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4. + cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0. + do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False. + flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0. + anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None. + stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0. 
+ min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15. + max_size_fraction (float, optional): max_size_fraction (float, optional): Masks larger than max_size_fraction of + total image size are removed. Default is 0.4. + niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None. + augment (bool, optional): tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False. + tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1. + bsize (int, optional): block size for tiles, recommended to keep at 224, like in training. Defaults to 224. + interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True. + compute_masks (bool, optional): Whether or not to compute dynamics and return masks. This is set to False when retrieving the styles for the size model. Defaults to True. + progress (QProgressBar, optional): pyqt progress bar. Defaults to None. + + Returns: + A tuple containing (masks, flows, styles, diams): + masks (list of 2D arrays or single 3D array): Labelled image, where 0=no masks; 1,2,...=mask labels; + flows (list of lists 2D arrays or list of 3D arrays): flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY flows at each pixel; + flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); + flows[k][3] = final pixel locations after Euler integration; + styles (list of 1D arrays of length 256 or single 1D array): Style vector summarizing each image, also used to estimate size of objects in image. + + """ + if isinstance(x, list) or x.squeeze().ndim == 5: + self.timing = [] + masks, styles, flows = [], [], [] + tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO) + nimg = len(x) + iterator = trange(nimg, file=tqdm_out, + mininterval=30) if nimg > 1 else range(nimg) + for i in iterator: + tic = time.time() + maski, flowi, stylei = self.eval( + x[i], batch_size=batch_size, + channels=channels[i] if channels is not None and + ((len(channels) == len(x) and + (isinstance(channels[i], list) or + isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2)) + else channels, channel_axis=channel_axis, z_axis=z_axis, + normalize=normalize, invert=invert, + rescale=rescale[i] if isinstance(rescale, list) or + isinstance(rescale, np.ndarray) else rescale, + diameter=diameter[i] if isinstance(diameter, list) or + isinstance(diameter, np.ndarray) else diameter, do_3D=do_3D, + anisotropy=anisotropy, augment=augment, + tile_overlap=tile_overlap, bsize=bsize, resample=resample, + interp=interp, flow_threshold=flow_threshold, + cellprob_threshold=cellprob_threshold, compute_masks=compute_masks, + min_size=min_size, max_size_fraction=max_size_fraction, + stitch_threshold=stitch_threshold, flow3D_smooth=flow3D_smooth, + progress=progress, niter=niter + ) + masks.append(maski) + flows.append(flowi) + styles.append(stylei) + self.timing.append(time.time() - tic) + return masks, flows, styles + + else: + # reshape image + x = transforms.convert_image(x, channels, channel_axis=channel_axis, + z_axis=z_axis, do_3D=(do_3D or + stitch_threshold > 0), + nchan=self.nchan) + if x.ndim < 4: + x = x[np.newaxis, ...] 
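+            # x is now channel-last with a leading image/plane axis:
+            # (nimg, Ly, Lx, nchan) for a single 2D image (nimg=1) or for a
+            # 3D stack (nimg = number of z-planes), so x.shape[0] counts images.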
+ nimg = x.shape[0] + + if diameter is not None and diameter > 0: + rescale = self.diam_mean / diameter + elif rescale is None: + rescale = self.diam_mean / self.diam_labels + + # normalize image + normalize_params = normalize_default + if isinstance(normalize, dict): + normalize_params = {**normalize_params, **normalize} + elif not isinstance(normalize, bool): + raise ValueError("normalize parameter must be a bool or a dict") + else: + normalize_params["normalize"] = normalize + normalize_params["invert"] = invert + + # pre-normalize if 3D stack for stitching or do_3D + do_normalization = True if normalize_params["normalize"] else False + x = np.asarray(x) + if nimg > 1 and do_normalization and (stitch_threshold or do_3D): + normalize_params["norm3D"] = True if do_3D else normalize_params["norm3D"] + x = transforms.normalize_img(x, **normalize_params) + do_normalization = False # do not normalize again + else: + if normalize_params["norm3D"] and nimg > 1: + models_logger.warning( + "normalize_params['norm3D'] is True but do_3D is False and stitch_threshold=0, so setting to False" + ) + normalize_params["norm3D"] = False + if do_normalization: + x = transforms.normalize_img(x, **normalize_params) + + dP, cellprob, styles = self._run_net( + x, rescale=rescale, augment=augment, + batch_size=batch_size, tile_overlap=tile_overlap, bsize=bsize, + resample=resample, do_3D=do_3D, anisotropy=anisotropy) + + if do_3D: + if flow3D_smooth > 0: + models_logger.info(f"smoothing flows with sigma={flow3D_smooth}") + dP = gaussian_filter(dP, (0, flow3D_smooth, flow3D_smooth, flow3D_smooth)) + torch.cuda.empty_cache() + gc.collect() + + if compute_masks: + niter0 = 200 if not resample else (1 / rescale * 200) + niter = niter0 if niter is None or niter == 0 else niter + masks = self._compute_masks(x.shape, dP, cellprob, flow_threshold=flow_threshold, + cellprob_threshold=cellprob_threshold, interp=interp, min_size=min_size, + max_size_fraction=max_size_fraction, niter=niter, + stitch_threshold=stitch_threshold, do_3D=do_3D) + else: + masks = np.zeros(0) #pass back zeros if not compute_masks + + masks, dP, cellprob = masks.squeeze(), dP.squeeze(), cellprob.squeeze() + + return masks, [plot.dx_to_circ(dP), dP, cellprob], styles + + def _run_net(self, x, rescale=1.0, resample=True, augment=False, + batch_size=8, tile_overlap=0.1, + bsize=224, anisotropy=1.0, do_3D=False): + """ run network on image x """ + tic = time.time() + shape = x.shape + nimg = shape[0] + + if do_3D: + Lz, Ly, Lx = shape[:-1] + if rescale != 1.0 or (anisotropy is not None and anisotropy != 1.0): + models_logger.info(f"resizing 3D image with rescale={rescale:.2f} and anisotropy={anisotropy}") + anisotropy = 1.0 if anisotropy is None else anisotropy + if rescale != 1.0: + x = transforms.resize_image(x, Ly=int(Ly*rescale), + Lx=int(Lx*rescale)) + x = transforms.resize_image(x.transpose(1,0,2,3), + Ly=int(Lz*anisotropy*rescale), + Lx=int(Lx*rescale)).transpose(1,0,2,3) + yf, styles = run_3D(self.net, x, + batch_size=batch_size, augment=augment, + tile_overlap=tile_overlap, net_ortho=self.net_ortho) + if resample: + if rescale != 1.0 or Lz != yf.shape[0]: + models_logger.info("resizing 3D flows and cellprob to original image size") + if rescale != 1.0: + yf = transforms.resize_image(yf, Ly=Ly, Lx=Lx) + if Lz != yf.shape[0]: + yf = transforms.resize_image(yf.transpose(1,0,2,3), + Ly=Lz, Lx=Lx).transpose(1,0,2,3) + cellprob = yf[..., -1] + dP = yf[..., :-1].transpose((3, 0, 1, 2)) + else: + yf, styles = run_net(self.net, x, bsize=bsize, 
augment=augment, + batch_size=batch_size, + tile_overlap=tile_overlap, + rsz=rescale if rescale!=1.0 else None) + if resample: + if rescale != 1.0: + yf = transforms.resize_image(yf, shape[1], shape[2]) + cellprob = yf[..., 2] + dP = yf[..., :2].transpose((3, 0, 1, 2)) + + styles = styles.squeeze() + + net_time = time.time() - tic + if nimg > 1: + models_logger.info("network run in %2.2fs" % (net_time)) + + return dP, cellprob, styles + + def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_threshold=0.0, + interp=True, min_size=15, max_size_fraction=0.4, niter=None, + do_3D=False, stitch_threshold=0.0): + """ compute masks from flows and cell probability """ + Lz, Ly, Lx = shape[:3] + tic = time.time() + if do_3D: + masks = dynamics.resize_and_compute_masks( + dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold, + flow_threshold=flow_threshold, interp=interp, do_3D=do_3D, + min_size=min_size, max_size_fraction=max_size_fraction, + resize=shape[:3] if (np.array(dP.shape[-3:])!=np.array(shape[:3])).sum() + else None, + device=self.device) + else: + nimg = shape[0] + Ly0, Lx0 = cellprob[0].shape + resize = None if Ly0==Ly and Lx0==Lx else [Ly, Lx] + tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO) + iterator = trange(nimg, file=tqdm_out, + mininterval=30) if nimg > 1 else range(nimg) + for i in iterator: + # turn off min_size for 3D stitching + min_size0 = min_size if stitch_threshold == 0 or nimg == 1 else -1 + outputs = dynamics.resize_and_compute_masks( + dP[:, i], cellprob[i], + niter=niter, cellprob_threshold=cellprob_threshold, + flow_threshold=flow_threshold, interp=interp, resize=resize, + min_size=min_size0, max_size_fraction=max_size_fraction, + device=self.device) + if i==0 and nimg > 1: + masks = np.zeros((nimg, shape[1], shape[2]), outputs.dtype) + if nimg > 1: + masks[i] = outputs + else: + masks = outputs + + if stitch_threshold > 0 and nimg > 1: + models_logger.info( + f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks" + ) + masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold) + masks = utils.fill_holes_and_remove_small_masks( + masks, min_size=min_size) + elif nimg > 1: + models_logger.warning( + "3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only" + ) + + flow_time = time.time() - tic + if shape[0] > 1: + models_logger.info("masks created in %2.2fs" % (flow_time)) + + return masks + +class SizeModel(): + """ + Linear regression model for determining the size of objects in image + used to rescale before input to cp_model. + Uses styles from cp_model. + + Attributes: + pretrained_size (str): Path to pretrained size model. + cp (UnetModel or CellposeModel): Model from which to get styles. + device (torch device): Device used for model running / training + (torch.device("cuda") or torch.device("cpu")), overrides gpu input, + recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")). + diam_mean (float): Mean diameter of objects. + + Methods: + eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False, + augment=False, batch_size=8, progress=None, interp=True): + Use images x to produce style or use style input to predict size of objects in image. + + Raises: + ValueError: If no pretrained cellpose model is specified, cannot compute size. + """ + + def __init__(self, cp_model, device=None, pretrained_size=None, **kwargs): + super(SizeModel, self).__init__(**kwargs) + """ + Initialize size model. 
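+
+        Illustrative usage (hypothetical; assumes the built-in "cyto3" weights
+        and its size model can be downloaded):
+
+            >>> cp = CellposeModel(gpu=False, model_type="cyto3")
+            >>> sz = SizeModel(cp, pretrained_size=size_model_path("cyto3"))
+            >>> diam, diam_style = sz.eval(img, channels=[0, 0])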
+ + Args: + cp_model (UnetModel or CellposeModel): Model from which to get styles. + device (torch device, optional): Device used for model running / training + (torch.device("cuda") or torch.device("cpu")), overrides gpu input, + recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")). + pretrained_size (str): Path to pretrained size model. + """ + + self.pretrained_size = pretrained_size + self.cp = cp_model + self.device = self.cp.device + self.diam_mean = self.cp.diam_mean + if pretrained_size is not None: + self.params = np.load(self.pretrained_size, allow_pickle=True).item() + self.diam_mean = self.params["diam_mean"] + if not hasattr(self.cp, "pretrained_model"): + error_message = "no pretrained cellpose model specified, cannot compute size" + models_logger.critical(error_message) + raise ValueError(error_message) + + def eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False, + augment=False, batch_size=8, progress=None): + """Use images x to produce style or use style input to predict size of objects in image. + + Object size estimation is done in two steps: + 1. Use a linear regression model to predict size from style in image. + 2. Resize image to predicted size and run CellposeModel to get output masks. + Take the median object size of the predicted masks as the final predicted size. + + Args: + x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images + channels (list, optional): list of channels, either of length 2 or of length number of images by 2. + First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). + Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). + For instance, to segment grayscale images, input [0,0]. To segment images with cells + in green and nuclei in blue, input [2,3]. To segment one grayscale image and one + image with cells in green and nuclei in blue, input [[0,0], [2,3]]. + Defaults to None. + channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x. + if None, channels dimension is attempted to be automatically determined. Defaults to None. + normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel; + can also pass dictionary of parameters (all keys are optional, default values shown): + - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored) + - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels + - "normalize"=True ; run normalization (if False, all following parameters ignored) + - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high] + - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100) + - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode. + Defaults to True. + invert (bool, optional): Invert image pixel intensity before running network (if True, image is also normalized). Defaults to False. + augment (bool, optional): tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False. + batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU + (can make smaller or bigger depending on GPU memory usage). Defaults to 8. 
+ progress (QProgressBar, optional): pyqt progress bar. Defaults to None. + + + Returns: + A tuple containing (diam, diam_style): + diam (np.ndarray): Final estimated diameters from images x or styles style after running both steps; + diam_style (np.ndarray): Estimated diameters from style alone. + """ + if isinstance(x, list): + self.timing = [] + diams, diams_style = [], [] + nimg = len(x) + tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO) + iterator = trange(nimg, file=tqdm_out, + mininterval=30) if nimg > 1 else range(nimg) + for i in iterator: + tic = time.time() + diam, diam_style = self.eval( + x[i], channels=channels[i] if + (channels is not None and len(channels) == len(x) and + (isinstance(channels[i], list) or + isinstance(channels[i], np.ndarray)) and + len(channels[i]) == 2) else channels, channel_axis=channel_axis, + normalize=normalize, invert=invert, augment=augment, + batch_size=batch_size, progress=progress) + diams.append(diam) + diams_style.append(diam_style) + self.timing.append(time.time() - tic) + + return diams, diams_style + + if x.squeeze().ndim > 3: + models_logger.warning("image is not 2D cannot compute diameter") + return self.diam_mean, self.diam_mean + + styles = self.cp.eval(x, channels=channels, channel_axis=channel_axis, + normalize=normalize, invert=invert, augment=augment, + batch_size=batch_size, resample=False, + compute_masks=False)[-1] + + diam_style = self._size_estimation(np.array(styles)) + diam_style = self.diam_mean if (diam_style == 0 or + np.isnan(diam_style)) else diam_style + + masks = self.cp.eval( + x, compute_masks=True, channels=channels, channel_axis=channel_axis, + normalize=normalize, invert=invert, augment=augment, + batch_size=batch_size, resample=False, + rescale=self.diam_mean / diam_style if self.diam_mean > 0 else 1, + diameter=None, interp=False)[0] + + diam = utils.diameters(masks)[0] + diam = self.diam_mean if (diam == 0 or np.isnan(diam)) else diam + return diam, diam_style + + def _size_estimation(self, style): + """ linear regression from style to size + + sizes were estimated using "diameters" from square estimates not circles; + therefore a conversion factor is included (to be removed) + + """ + szest = np.exp(self.params["A"] @ (style - self.params["smean"]).T + + np.log(self.diam_mean) + self.params["ymean"]) + szest = np.maximum(5., szest) + return szest diff --git a/cellpose/plot.py b/cellpose/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..7c89bc1e6a5f745555510db85709770cfa0d4581 --- /dev/null +++ b/cellpose/plot.py @@ -0,0 +1,282 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import os +import numpy as np +import cv2 +from scipy.ndimage import gaussian_filter +from . import utils, io, transforms + +try: + import matplotlib + MATPLOTLIB_ENABLED = True +except: + MATPLOTLIB_ENABLED = False + +try: + from skimage import color + from skimage.segmentation import find_boundaries + SKIMAGE_ENABLED = True +except: + SKIMAGE_ENABLED = False + + +# modified to use sinebow color +def dx_to_circ(dP): + """Converts the optic flow representation to a circular color representation. + + Args: + dP (ndarray): Flow field components [dy, dx]. + + Returns: + ndarray: The circular color representation of the optic flow. + + """ + mag = 255 * np.clip(transforms.normalize99(np.sqrt(np.sum(dP**2, axis=0))), 0, 1.) 
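+    # mag encodes the per-pixel flow magnitude (normalized to 0-255, halved
+    # below); the flow angle computed next selects the hue, and the three
+    # phase-shifted cosines map that hue to the R, G and B channels.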
+ angles = np.arctan2(dP[1], dP[0]) + np.pi + a = 2 + mag /= a + rgb = np.zeros((*dP.shape[1:], 3), "uint8") + rgb[..., 0] = np.clip(mag * (np.cos(angles) + 1), 0, 255).astype("uint8") + rgb[..., 1] = np.clip(mag * (np.cos(angles + 2 * np.pi / 3) + 1), 0, 255).astype("uint8") + rgb[..., 2] = np.clip(mag * (np.cos(angles + 4 * np.pi / 3) + 1), 0, 255).astype("uint8") + + return rgb + + +def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None): + """Plot segmentation results (like on website). + + Can save each panel of figure with file_name option. Use channels option if + img input is not an RGB image with 3 channels. + + Args: + fig (matplotlib.pyplot.figure): Figure in which to make plot. + img (ndarray): 2D or 3D array. Image input into cellpose. + maski (int, ndarray): For image k, masks[k] output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels. + flowi (int, ndarray): For image k, flows[k][0] output from Cellpose.eval (RGB of flows). + channels (list of int, optional): Channels used to run Cellpose, no need to use if image is RGB. Defaults to [0, 0]. + file_name (str, optional): File name of image. If file_name is not None, figure panels are saved. Defaults to None. + seg_norm (bool, optional): Improve cell visibility under labels. Defaults to False. + """ + if not MATPLOTLIB_ENABLED: + raise ImportError( + "matplotlib not installed, install with 'pip install matplotlib'") + ax = fig.add_subplot(1, 4, 1) + img0 = img.copy() + + if img0.shape[0] < 4: + img0 = np.transpose(img0, (1, 2, 0)) + if img0.shape[-1] < 3 or img0.ndim < 3: + img0 = image_to_rgb(img0, channels=channels) + else: + if img0.max() <= 50.0: + img0 = np.uint8(np.clip(img0, 0, 1) * 255) + ax.imshow(img0) + ax.set_title("original image") + ax.axis("off") + + outlines = utils.masks_to_outlines(maski) + + overlay = mask_overlay(img0, maski) + + ax = fig.add_subplot(1, 4, 2) + outX, outY = np.nonzero(outlines) + imgout = img0.copy() + imgout[outX, outY] = np.array([255, 0, 0]) # pure red + + ax.imshow(imgout) + ax.set_title("predicted outlines") + ax.axis("off") + + ax = fig.add_subplot(1, 4, 3) + ax.imshow(overlay) + ax.set_title("predicted masks") + ax.axis("off") + + ax = fig.add_subplot(1, 4, 4) + ax.imshow(flowi) + ax.set_title("predicted cell pose") + ax.axis("off") + + if file_name is not None: + save_path = os.path.splitext(file_name)[0] + io.imsave(save_path + "_overlay.jpg", overlay) + io.imsave(save_path + "_outlines.jpg", imgout) + io.imsave(save_path + "_flows.jpg", flowi) + + +def mask_rgb(masks, colors=None): + """Masks in random RGB colors. + + Args: + masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels. + colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range. + + Returns: + RGB (uint8, 3D array): Array of masks overlaid on grayscale image. + """ + if colors is not None: + if colors.max() > 1: + colors = np.float32(colors) + colors /= 255 + colors = utils.rgb_to_hsv(colors) + + HSV = np.zeros((masks.shape[0], masks.shape[1], 3), np.float32) + HSV[:, :, 2] = 1.0 + for n in range(int(masks.max())): + ipix = (masks == n + 1).nonzero() + if colors is None: + HSV[ipix[0], ipix[1], 0] = np.random.rand() + else: + HSV[ipix[0], ipix[1], 0] = colors[n, 0] + HSV[ipix[0], ipix[1], 1] = np.random.rand() * 0.5 + 0.5 + HSV[ipix[0], ipix[1], 2] = np.random.rand() * 0.5 + 0.5 + RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8) + return RGB + + +def mask_overlay(img, masks, colors=None): + """Overlay masks on image (set image to grayscale). 
+ + Args: + img (int or float, 2D or 3D array): Image of size [Ly x Lx (x nchan)]. + masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels. + colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range. + + Returns: + RGB (uint8, 3D array): Array of masks overlaid on grayscale image. + """ + if colors is not None: + if colors.max() > 1: + colors = np.float32(colors) + colors /= 255 + colors = utils.rgb_to_hsv(colors) + if img.ndim > 2: + img = img.astype(np.float32).mean(axis=-1) + else: + img = img.astype(np.float32) + + HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32) + HSV[:, :, 2] = np.clip((img / 255. if img.max() > 1 else img) * 1.5, 0, 1) + hues = np.linspace(0, 1, masks.max() + 1)[np.random.permutation(masks.max())] + for n in range(int(masks.max())): + ipix = (masks == n + 1).nonzero() + if colors is None: + HSV[ipix[0], ipix[1], 0] = hues[n] + else: + HSV[ipix[0], ipix[1], 0] = colors[n, 0] + HSV[ipix[0], ipix[1], 1] = 1.0 + RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8) + return RGB + + +def image_to_rgb(img0, channels=[0, 0]): + """Converts image from 2 x Ly x Lx or Ly x Lx x 2 to RGB Ly x Lx x 3. + + Args: + img0 (ndarray): Input image of shape 2 x Ly x Lx or Ly x Lx x 2. + + Returns: + ndarray: RGB image of shape Ly x Lx x 3. + + """ + img = img0.copy() + img = img.astype(np.float32) + if img.ndim < 3: + img = img[:, :, np.newaxis] + if img.shape[0] < 5: + img = np.transpose(img, (1, 2, 0)) + if channels[0] == 0: + img = img.mean(axis=-1)[:, :, np.newaxis] + for i in range(img.shape[-1]): + if np.ptp(img[:, :, i]) > 0: + img[:, :, i] = np.clip(transforms.normalize99(img[:, :, i]), 0, 1) + img[:, :, i] = np.clip(img[:, :, i], 0, 1) + img *= 255 + img = np.uint8(img) + RGB = np.zeros((img.shape[0], img.shape[1], 3), np.uint8) + if img.shape[-1] == 1: + RGB = np.tile(img, (1, 1, 3)) + else: + RGB[:, :, channels[0] - 1] = img[:, :, 0] + if channels[1] > 0: + RGB[:, :, channels[1] - 1] = img[:, :, 1] + return RGB + + +def interesting_patch(mask, bsize=130): + """ + Get patch of size bsize x bsize with most masks. + + Args: + mask (ndarray): Input mask. + bsize (int): Size of the patch. + + Returns: + tuple: Patch coordinates (y, x). + + """ + Ly, Lx = mask.shape + m = np.float32(mask > 0) + m = gaussian_filter(m, bsize / 2) + y, x = np.unravel_index(np.argmax(m), m.shape) + ycent = max(bsize // 2, min(y, Ly - bsize // 2)) + xcent = max(bsize // 2, min(x, Lx - bsize // 2)) + patch = [ + np.arange(ycent - bsize // 2, ycent + bsize // 2, 1, int), + np.arange(xcent - bsize // 2, xcent + bsize // 2, 1, int) + ] + return patch + + +def disk(med, r, Ly, Lx): + """Returns the pixels of a disk with a given radius and center. + + Args: + med (tuple): The center coordinates of the disk. + r (float): The radius of the disk. + Ly (int): The height of the image. + Lx (int): The width of the image. + + Returns: + tuple: A tuple containing the y and x coordinates of the pixels within the disk. + + """ + yy, xx = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int), + indexing="ij") + inds = ((yy - med[0])**2 + (xx - med[1])**2)**0.5 <= r + y = yy[inds].flatten() + x = xx[inds].flatten() + return y, x + + +def outline_view(img0, maski, color=[1, 0, 0], mode="inner"): + """ + Generates a red outline overlay onto the image. + + Args: + img0 (numpy.ndarray): The input image. + maski (numpy.ndarray): The mask representing the region of interest. + color (list, optional): The color of the outline overlay. Defaults to [1, 0, 0] (red). 
+ mode (str, optional): The mode for generating the outline. Defaults to "inner". + + Returns: + numpy.ndarray: The image with the red outline overlay. + + """ + if img0.ndim == 2: + img0 = np.stack([img0] * 3, axis=-1) + elif img0.ndim != 3: + raise ValueError("img0 not right size (must have ndim 2 or 3)") + + if SKIMAGE_ENABLED: + outlines = find_boundaries(maski, mode=mode) + else: + outlines = utils.masks_to_outlines(maski, mode=mode) + outY, outX = np.nonzero(outlines) + imgout = img0.copy() + imgout[outY, outX] = np.array(color) + + return imgout diff --git a/cellpose/resnet_torch.py b/cellpose/resnet_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..fd69298e9ef6fca5048ad9d6f96321e6f3c2108a --- /dev/null +++ b/cellpose/resnet_torch.py @@ -0,0 +1,345 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def batchconv(in_channels, out_channels, sz, conv_3D=False): + conv_layer = nn.Conv3d if conv_3D else nn.Conv2d + batch_norm = nn.BatchNorm3d if conv_3D else nn.BatchNorm2d + return nn.Sequential( + batch_norm(in_channels, eps=1e-5, momentum=0.05), + nn.ReLU(inplace=True), + conv_layer(in_channels, out_channels, sz, padding=sz // 2), + ) + + +def batchconv0(in_channels, out_channels, sz, conv_3D=False): + conv_layer = nn.Conv3d if conv_3D else nn.Conv2d + batch_norm = nn.BatchNorm3d if conv_3D else nn.BatchNorm2d + return nn.Sequential( + batch_norm(in_channels, eps=1e-5, momentum=0.05), + conv_layer(in_channels, out_channels, sz, padding=sz // 2), + ) + + +class resdown(nn.Module): + + def __init__(self, in_channels, out_channels, sz, conv_3D=False): + super().__init__() + self.conv = nn.Sequential() + self.proj = batchconv0(in_channels, out_channels, 1, conv_3D) + for t in range(4): + if t == 0: + self.conv.add_module("conv_%d" % t, + batchconv(in_channels, out_channels, sz, conv_3D)) + else: + self.conv.add_module("conv_%d" % t, + batchconv(out_channels, out_channels, sz, conv_3D)) + + def forward(self, x): + x = self.proj(x) + self.conv[1](self.conv[0](x)) + x = x + self.conv[3](self.conv[2](x)) + return x + + +class downsample(nn.Module): + + def __init__(self, nbase, sz, conv_3D=False, max_pool=True): + super().__init__() + self.down = nn.Sequential() + if max_pool: + self.maxpool = nn.MaxPool3d(2, stride=2) if conv_3D else nn.MaxPool2d( + 2, stride=2) + else: + self.maxpool = nn.AvgPool3d(2, stride=2) if conv_3D else nn.AvgPool2d( + 2, stride=2) + for n in range(len(nbase) - 1): + self.down.add_module("res_down_%d" % n, + resdown(nbase[n], nbase[n + 1], sz, conv_3D)) + + def forward(self, x): + xd = [] + for n in range(len(self.down)): + if n > 0: + y = self.maxpool(xd[n - 1]) + else: + y = x + xd.append(self.down[n](y)) + return xd + + +class batchconvstyle(nn.Module): + + def __init__(self, in_channels, out_channels, style_channels, sz, conv_3D=False): + super().__init__() + self.concatenation = False + self.conv = batchconv(in_channels, out_channels, sz, conv_3D) + self.full = nn.Linear(style_channels, out_channels) + + def forward(self, style, x, mkldnn=False, y=None): + if y is not None: + x = x + y + feat = self.full(style) + for k in range(len(x.shape[2:])): + feat = feat.unsqueeze(-1) + if mkldnn: + x = x.to_dense() + y = (x + feat).to_mkldnn() + else: + y = x + feat + y = self.conv(y) + return y + + +class resup(nn.Module): + + def __init__(self, in_channels, out_channels, style_channels, sz, 
conv_3D=False): + super().__init__() + self.concatenation = False + self.conv = nn.Sequential() + self.conv.add_module("conv_0", + batchconv(in_channels, out_channels, sz, conv_3D=conv_3D)) + self.conv.add_module( + "conv_1", + batchconvstyle(out_channels, out_channels, style_channels, sz, + conv_3D=conv_3D)) + self.conv.add_module( + "conv_2", + batchconvstyle(out_channels, out_channels, style_channels, sz, + conv_3D=conv_3D)) + self.conv.add_module( + "conv_3", + batchconvstyle(out_channels, out_channels, style_channels, sz, + conv_3D=conv_3D)) + self.proj = batchconv0(in_channels, out_channels, 1, conv_3D=conv_3D) + + def forward(self, x, y, style, mkldnn=False): + x = self.proj(x) + self.conv[1](style, self.conv[0](x), y=y, mkldnn=mkldnn) + x = x + self.conv[3](style, self.conv[2](style, x, mkldnn=mkldnn), + mkldnn=mkldnn) + return x + + +class make_style(nn.Module): + + def __init__(self, conv_3D=False): + super().__init__() + self.flatten = nn.Flatten() + self.avg_pool = F.avg_pool3d if conv_3D else F.avg_pool2d + + def forward(self, x0): + style = self.avg_pool(x0, kernel_size=x0.shape[2:]) + style = self.flatten(style) + style = style / torch.sum(style**2, axis=1, keepdim=True)**.5 + return style + + +class upsample(nn.Module): + + def __init__(self, nbase, sz, conv_3D=False): + super().__init__() + self.upsampling = nn.Upsample(scale_factor=2, mode="nearest") + self.up = nn.Sequential() + for n in range(1, len(nbase)): + self.up.add_module("res_up_%d" % (n - 1), + resup(nbase[n], nbase[n - 1], nbase[-1], sz, conv_3D)) + + def forward(self, style, xd, mkldnn=False): + x = self.up[-1](xd[-1], xd[-1], style, mkldnn=mkldnn) + for n in range(len(self.up) - 2, -1, -1): + if mkldnn: + x = self.upsampling(x.to_dense()).to_mkldnn() + else: + x = self.upsampling(x) + x = self.up[n](x, xd[n], style, mkldnn=mkldnn) + return x + + +class CPnet(nn.Module): + """ + CPnet is the Cellpose neural network model used for cell segmentation and image restoration. + + Args: + nbase (list): List of integers representing the number of channels in each layer of the downsample path. + nout (int): Number of output channels. + sz (int): Size of the input image. + mkldnn (bool, optional): Whether to use MKL-DNN acceleration. Defaults to False. + conv_3D (bool, optional): Whether to use 3D convolution. Defaults to False. + max_pool (bool, optional): Whether to use max pooling. Defaults to True. + diam_mean (float, optional): Mean diameter of the cells. Defaults to 30.0. + + Attributes: + nbase (list): List of integers representing the number of channels in each layer of the downsample path. + nout (int): Number of output channels. + sz (int): Size of the input image. + residual_on (bool): Whether to use residual connections. + style_on (bool): Whether to use style transfer. + concatenation (bool): Whether to use concatenation. + conv_3D (bool): Whether to use 3D convolution. + mkldnn (bool): Whether to use MKL-DNN acceleration. + downsample (nn.Module): Downsample blocks of the network. + upsample (nn.Module): Upsample blocks of the network. + make_style (nn.Module): Style module, avgpool's over all spatial positions. + output (nn.Module): Output module - batchconv layer. + diam_mean (nn.Parameter): Parameter representing the mean diameter to which the cells are rescaled to during training. + diam_labels (nn.Parameter): Parameter representing the mean diameter of the cells in the training set (before rescaling). 
+ + """ + + def __init__(self, nbase, nout, sz, mkldnn=False, conv_3D=False, max_pool=True, + diam_mean=30.): + super().__init__() + self.nchan = nbase[0] + self.nbase = nbase + self.nout = nout + self.sz = sz + self.residual_on = True + self.style_on = True + self.concatenation = False + self.conv_3D = conv_3D + self.mkldnn = mkldnn if mkldnn is not None else False + self.downsample = downsample(nbase, sz, conv_3D=conv_3D, max_pool=max_pool) + nbaseup = nbase[1:] + nbaseup.append(nbaseup[-1]) + self.upsample = upsample(nbaseup, sz, conv_3D=conv_3D) + self.make_style = make_style(conv_3D=conv_3D) + self.output = batchconv(nbaseup[0], nout, 1, conv_3D=conv_3D) + self.diam_mean = nn.Parameter(data=torch.ones(1) * diam_mean, + requires_grad=False) + self.diam_labels = nn.Parameter(data=torch.ones(1) * diam_mean, + requires_grad=False) + + @property + def device(self): + """ + Get the device of the model. + + Returns: + torch.device: The device of the model. + """ + return next(self.parameters()).device + + def forward(self, data): + """ + Forward pass of the CPnet model. + + Args: + data (torch.Tensor): Input data. + + Returns: + tuple: A tuple containing the output tensor, style tensor, and downsampled tensors. + """ + if self.mkldnn: + data = data.to_mkldnn() + T0 = self.downsample(data) + if self.mkldnn: + style = self.make_style(T0[-1].to_dense()) + else: + style = self.make_style(T0[-1]) + style0 = style + if not self.style_on: + style = style * 0 + T1 = self.upsample(style, T0, self.mkldnn) + T1 = self.output(T1) + if self.mkldnn: + T0 = [t0.to_dense() for t0 in T0] + T1 = T1.to_dense() + return T1, style0, T0 + + def save_model(self, filename): + """ + Save the model to a file. + + Args: + filename (str): The path to the file where the model will be saved. + """ + torch.save(self.state_dict(), filename) + + def load_model(self, filename, device=None): + """ + Load the model from a file. + + Args: + filename (str): The path to the file where the model is saved. + device (torch.device, optional): The device to load the model on. Defaults to None. + """ + if (device is not None) and (device.type != "cpu"): + state_dict = torch.load(filename, map_location=device, weights_only=True) + else: + self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D, + self.diam_mean) + state_dict = torch.load(filename, map_location=torch.device("cpu"), + weights_only=True) + + if state_dict["output.2.weight"].shape[0] != self.nout: + for name in self.state_dict(): + if "output" not in name: + self.state_dict()[name].copy_(state_dict[name]) + else: + self.load_state_dict( + dict([(name, param) for name, param in state_dict.items()]), + strict=False) + +class CPnetBioImageIO(CPnet): + """ + A subclass of the CPnet model compatible with the BioImage.IO Spec. + + This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec, + allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo. + """ + + def forward(self, x): + """ + Perform a forward pass of the CPnet model and return unpacked tensors. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + tuple: A tuple containing the output tensor, style tensor, and downsampled tensors. + """ + output_tensor, style_tensor, downsampled_tensors = super().forward(x) + return output_tensor, style_tensor, *downsampled_tensors + + + def load_model(self, filename, device=None): + """ + Load the model from a file. + + Args: + filename (str): The path to the file where the model is saved. 
+ device (torch.device, optional): The device to load the model on. Defaults to None. + """ + if (device is not None) and (device.type != "cpu"): + state_dict = torch.load(filename, map_location=device, weights_only=True) + else: + self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D, + self.diam_mean) + state_dict = torch.load(filename, map_location=torch.device("cpu"), + weights_only=True) + + self.load_state_dict(state_dict) + + def load_state_dict(self, state_dict): + """ + Load the state dictionary into the model. + + This method overrides the default `load_state_dict` to handle Cellpose's custom + loading mechanism and ensures compatibility with BioImage.IO Core. + + Args: + state_dict (Mapping[str, Any]): A state dictionary to load into the model + """ + if state_dict["output.2.weight"].shape[0] != self.nout: + for name in self.state_dict(): + if "output" not in name: + self.state_dict()[name].copy_(state_dict[name]) + else: + super().load_state_dict( + {name: param for name, param in state_dict.items()}, + strict=False) + diff --git a/cellpose/segformer.py b/cellpose/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..da142f6f5d4d004b20f876d85decfd42479c4f41 --- /dev/null +++ b/cellpose/segformer.py @@ -0,0 +1,79 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" +import torch +from torch import nn + +try: + import segmentation_models_pytorch as smp + + class Transformer(nn.Module): + """ Transformer encoder from segformer paper with MAnet decoder + (configuration from MEDIAR) + """ + + def __init__(self, encoder="mit_b5", encoder_weights=None, decoder="MAnet", + diam_mean=30.): + super().__init__() + net_fcn = smp.MAnet if decoder == "MAnet" else smp.FPN + self.encoder = encoder + self.decoder = decoder + self.net = net_fcn( + encoder_name=encoder, + encoder_weights=encoder_weights, + # (use "imagenet" pre-trained weights for encoder initialization if training) + in_channels=3, + classes=3, + activation=None) + self.nout = 3 + self.mkldnn = False + self.diam_mean = nn.Parameter(data=torch.ones(1) * diam_mean, + requires_grad=False) + self.diam_labels = nn.Parameter(data=torch.ones(1) * diam_mean, + requires_grad=False) + + def forward(self, X): + # have to convert to 3-chan (RGB) + if X.shape[1] < 3: + X = torch.cat( + (X, + torch.zeros((X.shape[0], 3 - X.shape[1], X.shape[2], X.shape[3]), + device=X.device)), dim=1) + y = self.net(X) + return y, torch.zeros((X.shape[0], 256), device=X.device) + + @property + def device(self): + return next(self.parameters()).device + + def save_model(self, filename): + """ + Save the model to a file. + + Args: + filename (str): The path to the file where the model will be saved. + """ + torch.save(self.state_dict(), filename) + + def load_model(self, filename, device=None): + """ + Load the model from a file. + + Args: + filename (str): The path to the file where the model is saved. + device (torch.device, optional): The device to load the model on. Defaults to None. 
+ """ + if (device is not None) and (device.type != "cpu"): + state_dict = torch.load(filename, map_location=device, weights_only=True) + else: + self.__init__(encoder=self.encoder, decoder=self.decoder, + diam_mean=self.diam_mean) + state_dict = torch.load(filename, map_location=torch.device("cpu"), weights_only=True) + + self.load_state_dict( + dict([(name, param) for name, param in state_dict.items()]), + strict=False) + +except Exception as e: + print(e) + print("need to install segmentation_models_pytorch to run transformer") diff --git a/cellpose/test_mkl.py b/cellpose/test_mkl.py new file mode 100644 index 0000000000000000000000000000000000000000..142f4dbefa0b17eeb62f8a6733dd2f63454132e2 --- /dev/null +++ b/cellpose/test_mkl.py @@ -0,0 +1,42 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import os, sys + +os.environ["MKLDNN_VERBOSE"] = "1" +import numpy as np +import time + +try: + import mxnet as mx + x = mx.sym.Variable("x") + MXNET_ENABLED = True +except: + MXNET_ENABLED = False + + +def test_mkl(): + if MXNET_ENABLED: + num_filter = 32 + kernel = (3, 3) + pad = (1, 1) + shape = (32, 32, 256, 256) + + x = mx.sym.Variable("x") + w = mx.sym.Variable("w") + y = mx.sym.Convolution(data=x, weight=w, num_filter=num_filter, kernel=kernel, + no_bias=True, pad=pad) + exe = y.simple_bind(mx.cpu(), x=shape) + + exe.arg_arrays[0][:] = np.random.normal(size=exe.arg_arrays[0].shape) + exe.arg_arrays[1][:] = np.random.normal(size=exe.arg_arrays[1].shape) + + exe.forward(is_train=False) + o = exe.outputs[0] + t = o.asnumpy() + + +if __name__ == "__main__": + if MXNET_ENABLED: + test_mkl() diff --git a/cellpose/train.py b/cellpose/train.py new file mode 100644 index 0000000000000000000000000000000000000000..10f619f1c7b20c50d0b808252d11fb21e172b676 --- /dev/null +++ b/cellpose/train.py @@ -0,0 +1,708 @@ +import time +import os +import numpy as np +from cellpose import io, transforms, utils, models, dynamics, metrics, resnet_torch +from cellpose.transforms import normalize_img +from pathlib import Path +import torch +from torch import nn +from tqdm import trange +from numba import prange + +import logging + +train_logger = logging.getLogger(__name__) + + +def _loss_fn_seg(lbl, y, device): + """ + Calculates the loss function between true labels lbl and prediction y. + + Args: + lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX). + y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob). + device (torch.device): Device on which the tensors are located. + + Returns: + torch.Tensor: Loss value. + + """ + criterion = nn.MSELoss(reduction="mean") + criterion2 = nn.BCEWithLogitsLoss(reduction="mean") + veci = 5. * torch.from_numpy(lbl[:, 1:]).to(device) + loss = criterion(y[:, :2], veci) + loss /= 2. + loss2 = criterion2(y[:, -1], torch.from_numpy(lbl[:, 0] > 0.5).to(device).float()) + loss = loss + loss2 + return loss + + +def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, + channels=None, channel_axis=None, rgb=False, + normalize_params={"normalize": False}): + """ + Get a batch of images and labels. + + Args: + inds (list): List of indices indicating which images and labels to retrieve. + data (list or None): List of image data. If None, images will be loaded from files. + labels (list or None): List of label data. If None, labels will be loaded from files. + files (list or None): List of file paths for images. + labels_files (list or None): List of file paths for labels. 
+ channels (list or None): List of channel indices to extract from images. + channel_axis (int or None): Axis along which the channels are located. + normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize). + + Returns: + tuple: A tuple containing two lists: the batch of images and the batch of labels. + """ + if data is None: + lbls = None + imgs = [io.imread(files[i]) for i in inds] + imgs = _reshape_norm(imgs, channels=channels, channel_axis=channel_axis, + rgb=rgb, normalize_params=normalize_params) + if labels_files is not None: + lbls = [io.imread(labels_files[i])[1:] for i in inds] + else: + imgs = [data[i] for i in inds] + lbls = [labels[i][1:] for i in inds] + return imgs, lbls + + +def pad_to_rgb(img): + if img.ndim == 2 or np.ptp(img[1]) < 1e-3: + if img.ndim == 2: + img = img[np.newaxis, :, :] + img = np.tile(img[:1], (3, 1, 1)) + elif img.shape[0] < 3: + nc, Ly, Lx = img.shape + # randomly flip channels + if np.random.rand() > 0.5: + img = img[::-1] + # randomly insert blank channel + ic = np.random.randint(3) + img = np.insert(img, ic, np.zeros((3 - nc, Ly, Lx), dtype=img.dtype), axis=0) + return img + + +def convert_to_rgb(img): + if img.ndim == 2: + img = img[np.newaxis, :, :] + img = np.tile(img, (3, 1, 1)) + elif img.shape[0] < 3: + img = img.mean(axis=0, keepdims=True) + img = transforms.normalize99(img) + img = np.tile(img, (3, 1, 1)) + return img + + +def _reshape_norm(data, channels=None, channel_axis=None, rgb=False, + normalize_params={"normalize": False}): + """ + Reshapes and normalizes the input data. + + Args: + data (list): List of input data. + channels (int or list, optional): Number of channels or list of channel indices to keep. Defaults to None. + channel_axis (int, optional): Axis along which the channels are located. Defaults to None. + normalize_params (dict, optional): Dictionary of normalization parameters. Defaults to {"normalize": False}. + + Returns: + list: List of reshaped and normalized data. 
+ """ + if channels is not None or channel_axis is not None: + data = [ + transforms.convert_image(td, channels=channels, channel_axis=channel_axis) + for td in data + ] + data = [td.transpose(2, 0, 1) for td in data] + if normalize_params["normalize"]: + data = [ + transforms.normalize_img(td, normalize=normalize_params, axis=0) + for td in data + ] + if rgb: + data = [pad_to_rgb(td) for td in data] + return data + + +def _reshape_norm_save(files, channels=None, channel_axis=None, + normalize_params={"normalize": False}): + """ not currently used -- normalization happening on each batch if not load_files """ + files_new = [] + for f in trange(files): + td = io.imread(f) + if channels is not None: + td = transforms.convert_image(td, channels=channels, + channel_axis=channel_axis) + td = td.transpose(2, 0, 1) + if normalize_params["normalize"]: + td = transforms.normalize_img(td, normalize=normalize_params, axis=0) + fnew = os.path.splitext(str(f))[0] + "_cpnorm.tif" + io.imsave(fnew, td) + files_new.append(fnew) + return files_new + # else: + # train_files = reshape_norm_save(train_files, channels=channels, + # channel_axis=channel_axis, normalize_params=normalize_params) + # elif test_files is not None: + # test_files = reshape_norm_save(test_files, channels=channels, + # channel_axis=channel_axis, normalize_params=normalize_params) + + +def _process_train_test(train_data=None, train_labels=None, train_files=None, + train_labels_files=None, train_probs=None, test_data=None, + test_labels=None, test_files=None, test_labels_files=None, + test_probs=None, load_files=True, min_train_masks=5, + compute_flows=False, channels=None, channel_axis=None, + rgb=False, normalize_params={"normalize": False + }, device=None): + """ + Process train and test data. + + Args: + train_data (list or None): List of training data arrays. + train_labels (list or None): List of training label arrays. + train_files (list or None): List of training file paths. + train_labels_files (list or None): List of training label file paths. + train_probs (ndarray or None): Array of training probabilities. + test_data (list or None): List of test data arrays. + test_labels (list or None): List of test label arrays. + test_files (list or None): List of test file paths. + test_labels_files (list or None): List of test label file paths. + test_probs (ndarray or None): Array of test probabilities. + load_files (bool): Whether to load data from files. + min_train_masks (int): Minimum number of masks required for training images. + compute_flows (bool): Whether to compute flows. + channels (list or None): List of channel indices to use. + channel_axis (int or None): Axis of channel dimension. + rgb (bool): Convert training/testing images to RGB. + normalize_params (dict): Dictionary of normalization parameters. + device (torch.device): Device to use for computation. + + Returns: + tuple: A tuple containing the processed train and test data and sampling probabilities and diameters. 
+ """ + if device == None: + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None + + if train_data is not None and train_labels is not None: + # if data is loaded + nimg = len(train_data) + nimg_test = len(test_data) if test_data is not None else None + else: + # otherwise use files + nimg = len(train_files) + if train_labels_files is None: + train_labels_files = [ + os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files + ] + train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)] + if (test_data is not None or test_files is not None) and test_labels_files is None: + test_labels_files = [ + os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files + ] + test_labels_files = [tf for tf in test_labels_files if os.path.exists(tf)] + if not load_files: + train_logger.info(">>> using files instead of loading dataset") + else: + # load all images + train_logger.info(">>> loading images and labels") + train_data = [io.imread(train_files[i]) for i in trange(nimg)] + train_labels = [io.imread(train_labels_files[i]) for i in trange(nimg)] + nimg_test = len(test_files) if test_files is not None else None + if load_files and nimg_test: + test_data = [io.imread(test_files[i]) for i in trange(nimg_test)] + test_labels = [io.imread(test_labels_files[i]) for i in trange(nimg_test)] + + ### check that arrays are correct size + if ((train_labels is not None and nimg != len(train_labels)) or + (train_labels_files is not None and nimg != len(train_labels_files))): + error_message = "train data and labels not same length" + train_logger.critical(error_message) + raise ValueError(error_message) + if ((test_labels is not None and nimg_test != len(test_labels)) or + (test_labels_files is not None and nimg_test != len(test_labels_files))): + train_logger.warning("test data and labels not same length, not using") + test_data, test_files = None, None + if train_labels is not None: + if train_labels[0].ndim < 2 or train_data[0].ndim < 2: + error_message = "training data or labels are not at least two-dimensional" + train_logger.critical(error_message) + raise ValueError(error_message) + if train_data[0].ndim > 3: + error_message = "training data is more than three-dimensional (should be 2D or 3D array)" + train_logger.critical(error_message) + raise ValueError(error_message) + + ### check that flows are computed + if train_labels is not None: + train_labels = dynamics.labels_to_flows(train_labels, files=train_files, + device=device) + if test_labels is not None: + test_labels = dynamics.labels_to_flows(test_labels, files=test_files, + device=device) + elif compute_flows: + for k in trange(nimg): + tl = dynamics.labels_to_flows(io.imread(train_labels_files), + files=train_files, device=device) + if test_files is not None: + for k in trange(nimg_test): + tl = dynamics.labels_to_flows(io.imread(test_labels_files), + files=test_files, device=device) + + ### compute diameters + nmasks = np.zeros(nimg) + diam_train = np.zeros(nimg) + train_logger.info(">>> computing diameters") + for k in trange(nimg): + tl = (train_labels[k][0] + if train_labels is not None else io.imread(train_labels_files[k])[0]) + diam_train[k], dall = utils.diameters(tl) + nmasks[k] = len(dall) + diam_train[diam_train < 5] = 5. + if test_data is not None: + diam_test = np.array( + [utils.diameters(test_labels[k][0])[0] for k in trange(len(test_labels))]) + diam_test[diam_test < 5] = 5. 
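+    # No preloaded test labels: fall back to the label files on disk to estimate
+    # test diameters (diam_test stays None if neither source is available).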
+ elif test_labels_files is not None: + diam_test = np.array([ + utils.diameters(io.imread(test_labels_files[k])[0])[0] + for k in trange(len(test_labels_files)) + ]) + diam_test[diam_test < 5] = 5. + else: + diam_test = None + + ### check to remove training images with too few masks + if min_train_masks > 0: + nremove = (nmasks < min_train_masks).sum() + if nremove > 0: + train_logger.warning( + f"{nremove} train images with number of masks less than min_train_masks ({min_train_masks}), removing from train set" + ) + ikeep = np.nonzero(nmasks >= min_train_masks)[0] + if train_data is not None: + train_data = [train_data[i] for i in ikeep] + train_labels = [train_labels[i] for i in ikeep] + if train_files is not None: + train_files = [train_files[i] for i in ikeep] + if train_labels_files is not None: + train_labels_files = [train_labels_files[i] for i in ikeep] + if train_probs is not None: + train_probs = train_probs[ikeep] + diam_train = diam_train[ikeep] + nimg = len(train_data) + + ### normalize probabilities + train_probs = 1. / nimg * np.ones(nimg, + "float64") if train_probs is None else train_probs + train_probs /= train_probs.sum() + if test_files is not None or test_data is not None: + test_probs = 1. / nimg_test * np.ones( + nimg_test, "float64") if test_probs is None else test_probs + test_probs /= test_probs.sum() + + ### reshape and normalize train / test data + normed = False + if channels is not None or normalize_params["normalize"]: + if channels: + train_logger.info(f">>> using channels {channels}") + if normalize_params["normalize"]: + train_logger.info(f">>> normalizing {normalize_params}") + if train_data is not None: + train_data = _reshape_norm(train_data, channels=channels, + channel_axis=channel_axis, rgb=rgb, + normalize_params=normalize_params) + normed = True + if test_data is not None: + test_data = _reshape_norm(test_data, channels=channels, + channel_axis=channel_axis, rgb=rgb, + normalize_params=normalize_params) + + return (train_data, train_labels, train_files, train_labels_files, train_probs, + diam_train, test_data, test_labels, test_files, test_labels_files, + test_probs, diam_test, normed) + + +def train_seg(net, train_data=None, train_labels=None, train_files=None, + train_labels_files=None, train_probs=None, test_data=None, + test_labels=None, test_files=None, test_labels_files=None, + test_probs=None, load_files=True, batch_size=8, learning_rate=0.005, + n_epochs=2000, weight_decay=1e-5, momentum=0.9, SGD=False, channels=None, + channel_axis=None, rgb=False, normalize=True, compute_flows=False, + save_path=None, save_every=100, save_each=False, nimg_per_epoch=None, + nimg_test_per_epoch=None, rescale=True, scale_range=None, bsize=224, + min_train_masks=5, model_name=None): + """ + Train the network with images for segmentation. + + Args: + net (object): The network model to train. + train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None. + train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. + train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None. + train_labels_files (list or None): List of training label file paths. Defaults to None. + train_probs (List[float], optional): List of floats - probabilities for each image to be selected during training. Defaults to None. 
+ test_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for testing. Defaults to None. + test_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for test_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. + test_files (List[str], optional): List of strings - file names for images in test_data (to save flows for future runs). Defaults to None. + test_labels_files (list or None): List of test label file paths. Defaults to None. + test_probs (List[float], optional): List of floats - probabilities for each image to be selected during testing. Defaults to None. + load_files (bool, optional): Boolean - whether to load images and labels from files. Defaults to True. + batch_size (int, optional): Integer - number of patches to run simultaneously on the GPU. Defaults to 8. + learning_rate (float or List[float], optional): Float or list/np.ndarray - learning rate for training. Defaults to 0.005. + n_epochs (int, optional): Integer - number of times to go through the whole training set during training. Defaults to 2000. + weight_decay (float, optional): Float - weight decay for the optimizer. Defaults to 1e-5. + momentum (float, optional): Float - momentum for the optimizer. Defaults to 0.9. + SGD (bool, optional): Boolean - whether to use SGD as optimization instead of RAdam. Defaults to False. + channels (List[int], optional): List of ints - channels to use for training. Defaults to None. + channel_axis (int, optional): Integer - axis of the channel dimension in the input data. Defaults to None. + normalize (bool or dict, optional): Boolean or dictionary - whether to normalize the data. Defaults to True. + compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False. + save_path (str, optional): String - where to save the trained model. Defaults to None. + save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100. + save_each (bool, optional): Boolean - save the network to a new filename at every [save_each] epoch. Defaults to False. + nimg_per_epoch (int, optional): Integer - minimum number of images to train on per epoch. Defaults to None. + nimg_test_per_epoch (int, optional): Integer - minimum number of images to test on per epoch. Defaults to None. + rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to True. + min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5. + model_name (str, optional): String - name of the network. Defaults to None. + + Returns: + tuple: A tuple containing the path to the saved model weights, training losses, and test losses. 
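+
+    Example (illustrative sketch; imgs and masks stand in for user-provided
+    lists of 2D image and label arrays):
+        >>> from cellpose import models, train
+        >>> model = models.CellposeModel(gpu=True)
+        >>> weights_path, train_losses, test_losses = train.train_seg(
+        ...     model.net, train_data=imgs, train_labels=masks,
+        ...     channels=[0, 0], n_epochs=100, model_name="my_model")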
+ + """ + device = net.device + + scale_range0 = 0.5 if rescale else 1.0 + scale_range = scale_range if scale_range is not None else scale_range0 + + if isinstance(normalize, dict): + normalize_params = {**models.normalize_default, **normalize} + elif not isinstance(normalize, bool): + raise ValueError("normalize parameter must be a bool or a dict") + else: + normalize_params = models.normalize_default + normalize_params["normalize"] = normalize + + out = _process_train_test(train_data=train_data, train_labels=train_labels, + train_files=train_files, train_labels_files=train_labels_files, + train_probs=train_probs, + test_data=test_data, test_labels=test_labels, + test_files=test_files, test_labels_files=test_labels_files, + test_probs=test_probs, + load_files=load_files, min_train_masks=min_train_masks, + compute_flows=compute_flows, channels=channels, + channel_axis=channel_axis, rgb=rgb, + normalize_params=normalize_params, device=net.device) + (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train, + test_data, test_labels, test_files, test_labels_files, test_probs, diam_test, + normed) = out + # already normalized, do not normalize during training + if normed: + kwargs = {} + else: + kwargs = { + "normalize_params": normalize_params, + "channels": channels, + "channel_axis": channel_axis, + "rgb": rgb + } + + net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device) + + nimg = len(train_data) if train_data is not None else len(train_files) + nimg_test = len(test_data) if test_data is not None else None + nimg_test = len(test_files) if test_files is not None else nimg_test + nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch + nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch + + # learning rate schedule + LR = np.linspace(0, learning_rate, 10) + LR = np.append(LR, learning_rate * np.ones(max(0, n_epochs - 10))) + if n_epochs > 300: + LR = LR[:-100] + for i in range(10): + LR = np.append(LR, LR[-1] / 2 * np.ones(10)) + elif n_epochs > 100: + LR = LR[:-50] + for i in range(10): + LR = np.append(LR, LR[-1] / 2 * np.ones(5)) + + train_logger.info(f">>> n_epochs={n_epochs}, n_train={nimg}, n_test={nimg_test}") + + if not SGD: + train_logger.info( + f">>> AdamW, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}" + ) + optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate, + weight_decay=weight_decay) + else: + train_logger.info( + f">>> SGD, learning_rate={learning_rate:0.5f}, weight_decay={weight_decay:0.5f}, momentum={momentum:0.3f}" + ) + optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, + weight_decay=weight_decay, momentum=momentum) + + t0 = time.time() + model_name = f"cellpose_{t0}" if model_name is None else model_name + save_path = Path.cwd() if save_path is None else Path(save_path) + filename = save_path / "models" / model_name + (save_path / "models").mkdir(exist_ok=True) + + train_logger.info(f">>> saving model to {filename}") + + lavg, nsum = 0, 0 + train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs) + for iepoch in range(n_epochs): + np.random.seed(iepoch) + if nimg != nimg_per_epoch: + # choose random images for epoch with probability train_probs + rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,), + p=train_probs) + else: + # otherwise use all images + rperm = np.random.permutation(np.arange(0, nimg)) + for param_group in optimizer.param_groups: + param_group["lr"] = LR[iepoch] # set learning rate + 
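+        # Mini-batch loop: draw the images chosen for this epoch, rescale each
+        # crop by its object diameter relative to diam_mean, augment with random
+        # rotation and resizing, then take one optimizer step on the flow MSE +
+        # cell-probability BCE loss computed by _loss_fn_seg.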
net.train() + for k in range(0, nimg_per_epoch, batch_size): + kend = min(k + batch_size, nimg_per_epoch) + inds = rperm[k:kend] + imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels, + files=train_files, labels_files=train_labels_files, + **kwargs) + diams = np.array([diam_train[i] for i in inds]) + rsc = diams / net.diam_mean.item() if rescale else np.ones( + len(diams), "float32") + # augmentations + imgi, lbl = transforms.random_rotate_and_resize(imgs, Y=lbls, rescale=rsc, + scale_range=scale_range, + xy=(bsize, bsize))[:2] + # network and loss optimization + X = torch.from_numpy(imgi).to(device) + y = net(X)[0] + loss = _loss_fn_seg(lbl, y, device) + optimizer.zero_grad() + loss.backward() + optimizer.step() + train_loss = loss.item() + train_loss *= len(imgi) + + # keep track of average training loss across epochs + lavg += train_loss + nsum += len(imgi) + # per epoch training loss + train_losses[iepoch] += train_loss + train_losses[iepoch] /= nimg_per_epoch + + if iepoch == 5 or iepoch % 10 == 0: + lavgt = 0. + if test_data is not None or test_files is not None: + np.random.seed(42) + if nimg_test != nimg_test_per_epoch: + rperm = np.random.choice(np.arange(0, nimg_test), + size=(nimg_test_per_epoch,), p=test_probs) + else: + rperm = np.random.permutation(np.arange(0, nimg_test)) + for ibatch in range(0, len(rperm), batch_size): + with torch.no_grad(): + net.eval() + inds = rperm[ibatch:ibatch + batch_size] + imgs, lbls = _get_batch(inds, data=test_data, + labels=test_labels, files=test_files, + labels_files=test_labels_files, + **kwargs) + diams = np.array([diam_test[i] for i in inds]) + rsc = diams / net.diam_mean.item() if rescale else np.ones( + len(diams), "float32") + imgi, lbl = transforms.random_rotate_and_resize( + imgs, Y=lbls, rescale=rsc, scale_range=scale_range, + xy=(bsize, bsize))[:2] + X = torch.from_numpy(imgi).to(device) + y = net(X)[0] + loss = _loss_fn_seg(lbl, y, device) + test_loss = loss.item() + test_loss *= len(imgi) + lavgt += test_loss + lavgt /= len(rperm) + test_losses[iepoch] = lavgt + lavg /= nsum + train_logger.info( + f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s" + ) + lavg, nsum = 0, 0 + + if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0): + if save_each and iepoch != n_epochs - 1: #separate files as model progresses + filename0 = str(filename) + f"_epoch_{iepoch:04d}" + else: + filename0 = filename + train_logger.info(f"saving network parameters to {filename0}") + net.save_model(filename0) + + net.save_model(filename) + + return filename, train_losses, test_losses + + +def train_size(net, pretrained_model, train_data=None, train_labels=None, + train_files=None, train_labels_files=None, train_probs=None, + test_data=None, test_labels=None, test_files=None, + test_labels_files=None, test_probs=None, load_files=True, + min_train_masks=5, channels=None, channel_axis=None, rgb=False, + normalize=True, nimg_per_epoch=None, nimg_test_per_epoch=None, + batch_size=64, scale_range=1.0, bsize=512, l2_regularization=1.0, + n_epochs=10): + """Train the size model. + + Args: + net (object): The neural network model. + pretrained_model (str): The path to the pretrained model. + train_data (numpy.ndarray, optional): The training data. Defaults to None. + train_labels (numpy.ndarray, optional): The training labels. Defaults to None. + train_files (list, optional): The training file paths. Defaults to None. 
+ train_labels_files (list, optional): The training label file paths. Defaults to None. + train_probs (numpy.ndarray, optional): The training probabilities. Defaults to None. + test_data (numpy.ndarray, optional): The test data. Defaults to None. + test_labels (numpy.ndarray, optional): The test labels. Defaults to None. + test_files (list, optional): The test file paths. Defaults to None. + test_labels_files (list, optional): The test label file paths. Defaults to None. + test_probs (numpy.ndarray, optional): The test probabilities. Defaults to None. + load_files (bool, optional): Whether to load files. Defaults to True. + min_train_masks (int, optional): The minimum number of training masks. Defaults to 5. + channels (list, optional): The channels. Defaults to None. + channel_axis (int, optional): The channel axis. Defaults to None. + normalize (bool or dict, optional): Whether to normalize the data. Defaults to True. + nimg_per_epoch (int, optional): The number of images per epoch. Defaults to None. + nimg_test_per_epoch (int, optional): The number of test images per epoch. Defaults to None. + batch_size (int, optional): The batch size. Defaults to 64. + l2_regularization (float, optional): The L2 regularization factor. Defaults to 1.0. + n_epochs (int, optional): The number of epochs. Defaults to 10. + + Returns: + dict: The trained size model parameters. + """ + if isinstance(normalize, dict): + normalize_params = {**models.normalize_default, **normalize} + elif not isinstance(normalize, bool): + raise ValueError("normalize parameter must be a bool or a dict") + else: + normalize_params = models.normalize_default + normalize_params["normalize"] = normalize + + out = _process_train_test( + train_data=train_data, train_labels=train_labels, train_files=train_files, + train_labels_files=train_labels_files, train_probs=train_probs, + test_data=test_data, test_labels=test_labels, test_files=test_files, + test_labels_files=test_labels_files, test_probs=test_probs, + load_files=load_files, min_train_masks=min_train_masks, compute_flows=False, + channels=channels, channel_axis=channel_axis, normalize_params=normalize_params, + device=net.device) + (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train, + test_data, test_labels, test_files, test_labels_files, test_probs, diam_test, + normed) = out + + # already normalized, do not normalize during training + if normed: + kwargs = {} + else: + kwargs = { + "normalize_params": normalize_params, + "channels": channels, + "channel_axis": channel_axis, + "rgb": rgb + } + + nimg = len(train_data) if train_data is not None else len(train_files) + nimg_test = len(test_data) if test_data is not None else None + nimg_test = len(test_files) if test_files is not None else nimg_test + nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch + nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch + + diam_mean = net.diam_mean.item() + device = net.device + net.eval() + + styles = np.zeros((n_epochs * nimg_per_epoch, 256), np.float32) + diams = np.zeros((n_epochs * nimg_per_epoch,), np.float32) + tic = time.time() + for iepoch in range(n_epochs): + np.random.seed(iepoch) + if nimg != nimg_per_epoch: + rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,), + p=train_probs) + else: + rperm = np.random.permutation(np.arange(0, nimg)) + for ibatch in range(0, nimg_per_epoch, batch_size): + inds_batch = np.arange(ibatch, min(nimg_per_epoch, ibatch + batch_size)) + inds = 
rperm[inds_batch] + imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels, + files=train_files, **kwargs) + diami = diam_train[inds].copy() + imgi, lbl, scale = transforms.random_rotate_and_resize( + imgs, scale_range=scale_range, xy=(bsize, bsize)) + imgi = torch.from_numpy(imgi).to(device) + with torch.no_grad(): + feat = net(imgi)[1] + indsi = inds_batch + nimg_per_epoch * iepoch + styles[indsi] = feat.cpu().numpy() + diams[indsi] = np.log(diami) - np.log(diam_mean) + np.log(scale) + del feat + train_logger.info("ran %d epochs in %0.3f sec" % + (iepoch + 1, time.time() - tic)) + + l2_regularization = 1. + + # create model + smean = styles.copy().mean(axis=0) + X = ((styles.copy() - smean).T).copy() + ymean = diams.copy().mean() + y = diams.copy() - ymean + + A = np.linalg.solve(X @ X.T + l2_regularization * np.eye(X.shape[0]), X @ y) + ypred = A @ X + + train_logger.info("train correlation: %0.4f" % np.corrcoef(y, ypred)[0, 1]) + + if nimg_test: + np.random.seed(0) + styles_test = np.zeros((nimg_test_per_epoch, 256), np.float32) + diams_test = np.zeros((nimg_test_per_epoch,), np.float32) + diams_test0 = np.zeros((nimg_test_per_epoch,), np.float32) + if nimg_test != nimg_test_per_epoch: + rperm = np.random.choice(np.arange(0, nimg_test), + size=(nimg_test_per_epoch,), p=test_probs) + else: + rperm = np.random.permutation(np.arange(0, nimg_test)) + for ibatch in range(0, nimg_test_per_epoch, batch_size): + inds_batch = np.arange(ibatch, min(nimg_test_per_epoch, + ibatch + batch_size)) + inds = rperm[inds_batch] + imgs, lbls = _get_batch(inds, data=test_data, labels=test_labels, + files=test_files, labels_files=test_labels_files, + **kwargs) + diami = diam_test[inds].copy() + imgi, lbl, scale = transforms.random_rotate_and_resize( + imgs, Y=lbls, scale_range=scale_range, xy=(bsize, bsize)) + imgi = torch.from_numpy(imgi).to(device) + diamt = np.array([utils.diameters(lbl0[0])[0] for lbl0 in lbl]) + diamt = np.maximum(5., diamt) + with torch.no_grad(): + feat = net(imgi)[1] + styles_test[inds_batch] = feat.cpu().numpy() + diams_test[inds_batch] = np.log(diami) - np.log(diam_mean) + np.log(scale) + diams_test0[inds_batch] = diamt + + diam_test_pred = np.exp(A @ (styles_test - smean).T + np.log(diam_mean) + ymean) + diam_test_pred = np.maximum(5., diam_test_pred) + train_logger.info("test correlation: %0.4f" % + np.corrcoef(diams_test0, diam_test_pred)[0, 1]) + + pretrained_size = str(pretrained_model) + "_size.npy" + params = {"A": A, "smean": smean, "diam_mean": diam_mean, "ymean": ymean} + np.save(pretrained_size, params) + train_logger.info("model saved to " + pretrained_size) + + return params diff --git a/cellpose/transforms.py b/cellpose/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5f5160272848462c94eedcc1db4c0e7b624a99 --- /dev/null +++ b/cellpose/transforms.py @@ -0,0 +1,1034 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +import logging +import warnings + +import cv2 +import numpy as np +import torch +from scipy.ndimage import gaussian_filter1d +from torch.fft import fft2, fftshift, ifft2 + +transforms_logger = logging.getLogger(__name__) + + +def _taper_mask(ly=224, lx=224, sig=7.5): + """ + Generate a taper mask. + + Args: + ly (int): The height of the mask. Default is 224. + lx (int): The width of the mask. Default is 224. + sig (float): The sigma value for the tapering function. Default is 7.5. + + Returns: + numpy.ndarray: The taper mask. 
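+
+    Example (illustrative):
+        >>> m = _taper_mask(ly=224, lx=224)
+        >>> m.shape
+        (224, 224)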
+ + """ + bsize = max(224, max(ly, lx)) + xm = np.arange(bsize) + xm = np.abs(xm - xm.mean()) + mask = 1 / (1 + np.exp((xm - (bsize / 2 - 20)) / sig)) + mask = mask * mask[:, np.newaxis] + mask = mask[bsize // 2 - ly // 2:bsize // 2 + ly // 2 + ly % 2, + bsize // 2 - lx // 2:bsize // 2 + lx // 2 + lx % 2] + return mask + + +def unaugment_tiles(y): + """Reverse test-time augmentations for averaging (includes flipping of flowsY and flowsX). + + Args: + y (float32): Array of shape (ntiles_y, ntiles_x, chan, Ly, Lx) where chan = (flowsY, flowsX, cell prob). + + Returns: + float32: Array of shape (ntiles_y, ntiles_x, chan, Ly, Lx). + + """ + for j in range(y.shape[0]): + for i in range(y.shape[1]): + if j % 2 == 0 and i % 2 == 1: + y[j, i] = y[j, i, :, ::-1, :] + y[j, i, 0] *= -1 + elif j % 2 == 1 and i % 2 == 0: + y[j, i] = y[j, i, :, :, ::-1] + y[j, i, 1] *= -1 + elif j % 2 == 1 and i % 2 == 1: + y[j, i] = y[j, i, :, ::-1, ::-1] + y[j, i, 0] *= -1 + y[j, i, 1] *= -1 + return y + + +def average_tiles(y, ysub, xsub, Ly, Lx): + """ + Average the results of the network over tiles. + + Args: + y (float): Output of cellpose network for each tile. Shape: [ntiles x nclasses x bsize x bsize] + ysub (list): List of arrays with start and end of tiles in Y of length ntiles + xsub (list): List of arrays with start and end of tiles in X of length ntiles + Ly (int): Size of pre-tiled image in Y (may be larger than original image if image size is less than bsize) + Lx (int): Size of pre-tiled image in X (may be larger than original image if image size is less than bsize) + + Returns: + yf (float32): Network output averaged over tiles. Shape: [nclasses x Ly x Lx] + """ + Navg = np.zeros((Ly, Lx)) + yf = np.zeros((y.shape[1], Ly, Lx), np.float32) + # taper edges of tiles + mask = _taper_mask(ly=y.shape[-2], lx=y.shape[-1]) + for j in range(len(ysub)): + yf[:, ysub[j][0]:ysub[j][1], xsub[j][0]:xsub[j][1]] += y[j] * mask + Navg[ysub[j][0]:ysub[j][1], xsub[j][0]:xsub[j][1]] += mask + yf /= Navg + return yf + + +def make_tiles(imgi, bsize=224, augment=False, tile_overlap=0.1): + """Make tiles of image to run at test-time. + + Args: + imgi (np.ndarray): Array of shape (nchan, Ly, Lx) representing the input image. + bsize (int, optional): Size of tiles. Defaults to 224. + augment (bool, optional): Whether to flip tiles and set tile_overlap=2. Defaults to False. + tile_overlap (float, optional): Fraction of overlap of tiles. Defaults to 0.1. + + Returns: + A tuple containing (IMG, ysub, xsub, Ly, Lx): + IMG (np.ndarray): Array of shape (ntiles, nchan, bsize, bsize) representing the tiles. + ysub (list): List of arrays with start and end of tiles in Y of length ntiles. + xsub (list): List of arrays with start and end of tiles in X of length ntiles. + Ly (int): Height of the input image. + Lx (int): Width of the input image. + """ + nchan, Ly, Lx = imgi.shape + if augment: + bsize = np.int32(bsize) + # pad if image smaller than bsize + if Ly < bsize: + imgi = np.concatenate((imgi, np.zeros((nchan, bsize - Ly, Lx))), axis=1) + Ly = bsize + if Lx < bsize: + imgi = np.concatenate((imgi, np.zeros((nchan, Ly, bsize - Lx))), axis=2) + Ly, Lx = imgi.shape[-2:] + + # tiles overlap by half of tile size + ny = max(2, int(np.ceil(2. * Ly / bsize))) + nx = max(2, int(np.ceil(2. 
* Lx / bsize))) + ystart = np.linspace(0, Ly - bsize, ny).astype(int) + xstart = np.linspace(0, Lx - bsize, nx).astype(int) + + ysub = [] + xsub = [] + + # flip tiles so that overlapping segments are processed in rotation + IMG = np.zeros((len(ystart), len(xstart), nchan, bsize, bsize), np.float32) + for j in range(len(ystart)): + for i in range(len(xstart)): + ysub.append([ystart[j], ystart[j] + bsize]) + xsub.append([xstart[i], xstart[i] + bsize]) + IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]] + # flip tiles to allow for augmentation of overlapping segments + if j % 2 == 0 and i % 2 == 1: + IMG[j, i] = IMG[j, i, :, ::-1, :] + elif j % 2 == 1 and i % 2 == 0: + IMG[j, i] = IMG[j, i, :, :, ::-1] + elif j % 2 == 1 and i % 2 == 1: + IMG[j, i] = IMG[j, i, :, ::-1, ::-1] + else: + tile_overlap = min(0.5, max(0.05, tile_overlap)) + bsizeY, bsizeX = min(bsize, Ly), min(bsize, Lx) + bsizeY = np.int32(bsizeY) + bsizeX = np.int32(bsizeX) + # tiles overlap by 10% tile size + ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize)) + nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize)) + ystart = np.linspace(0, Ly - bsizeY, ny).astype(int) + xstart = np.linspace(0, Lx - bsizeX, nx).astype(int) + + ysub = [] + xsub = [] + IMG = np.zeros((len(ystart), len(xstart), nchan, bsizeY, bsizeX), np.float32) + for j in range(len(ystart)): + for i in range(len(xstart)): + ysub.append([ystart[j], ystart[j] + bsizeY]) + xsub.append([xstart[i], xstart[i] + bsizeX]) + IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]] + + return IMG, ysub, xsub, Ly, Lx + + +def normalize99(Y, lower=1, upper=99, copy=True, downsample=False): + """ + Normalize the image so that 0.0 corresponds to the 1st percentile and 1.0 corresponds to the 99th percentile. + + Args: + Y (ndarray): The input image (for downsample, use [Ly x Lx] or [Lz x Ly x Lx]). + lower (int, optional): The lower percentile. Defaults to 1. + upper (int, optional): The upper percentile. Defaults to 99. + copy (bool, optional): Whether to create a copy of the input image. Defaults to True. + downsample (bool, optional): Whether to downsample image to compute percentiles. Defaults to False. + + Returns: + ndarray: The normalized image. + """ + X = Y.copy() if copy else Y + X = X.astype("float32") if X.dtype!="float64" and X.dtype!="float32" else X + if downsample and X.size > 224**3: + nskip = [max(1, X.shape[i] // 224) for i in range(X.ndim)] + nskip[0] = max(1, X.shape[0] // 50) if X.ndim == 3 else nskip[0] + slc = tuple([slice(0, X.shape[i], nskip[i]) for i in range(X.ndim)]) + x01 = np.percentile(X[slc], lower) + x99 = np.percentile(X[slc], upper) + else: + x01 = np.percentile(X, lower) + x99 = np.percentile(X, upper) + if x99 - x01 > 1e-3: + X -= x01 + X /= (x99 - x01) + else: + X[:] = 0 + return X + + +def normalize99_tile(img, blocksize=100, lower=1., upper=99., tile_overlap=0.1, + norm3D=False, smooth3D=1, is3D=False): + """Compute normalization like normalize99 function but in tiles. + + Args: + img (numpy.ndarray): Array of shape (Lz x) Ly x Lx (x nchan) containing the image. + blocksize (float, optional): Size of tiles. Defaults to 100. + lower (float, optional): Lower percentile for normalization. Defaults to 1.0. + upper (float, optional): Upper percentile for normalization. Defaults to 99.0. + tile_overlap (float, optional): Fraction of overlap of tiles. Defaults to 0.1. + norm3D (bool, optional): Use same tiled normalization for each z-plane. Defaults to False. 
+ smooth3D (int, optional): Smoothing factor for 3D normalization. Defaults to 1. + is3D (bool, optional): Set to True if image is a 3D stack. Defaults to False. + + Returns: + numpy.ndarray: Normalized image array of shape (Lz x) Ly x Lx (x nchan). + """ + is1c = True if img.ndim == 2 or (is3D and img.ndim == 3) else False + is3D = True if img.ndim > 3 or (is3D and img.ndim == 3) else False + img = img[..., np.newaxis] if is1c else img + img = img[np.newaxis, ...] if img.ndim == 3 else img + Lz, Ly, Lx, nchan = img.shape + + tile_overlap = min(0.5, max(0.05, tile_overlap)) + blocksizeY, blocksizeX = min(blocksize, Ly), min(blocksize, Lx) + blocksizeY = np.int32(blocksizeY) + blocksizeX = np.int32(blocksizeX) + # tiles overlap by 10% tile size + ny = 1 if Ly <= blocksize else int(np.ceil( + (1. + 2 * tile_overlap) * Ly / blocksize)) + nx = 1 if Lx <= blocksize else int(np.ceil( + (1. + 2 * tile_overlap) * Lx / blocksize)) + ystart = np.linspace(0, Ly - blocksizeY, ny).astype(int) + xstart = np.linspace(0, Lx - blocksizeX, nx).astype(int) + ysub = [] + xsub = [] + for j in range(len(ystart)): + for i in range(len(xstart)): + ysub.append([ystart[j], ystart[j] + blocksizeY]) + xsub.append([xstart[i], xstart[i] + blocksizeX]) + + x01_tiles_z = [] + x99_tiles_z = [] + for z in range(Lz): + IMG = np.zeros((len(ystart), len(xstart), blocksizeY, blocksizeX, nchan), + "float32") + k = 0 + for j in range(len(ystart)): + for i in range(len(xstart)): + IMG[j, i] = img[z, ysub[k][0]:ysub[k][1], xsub[k][0]:xsub[k][1], :] + k += 1 + x01_tiles = np.percentile(IMG, lower, axis=(-3, -2)) + x99_tiles = np.percentile(IMG, upper, axis=(-3, -2)) + + # fill areas with small differences with neighboring squares + to_fill = np.zeros(x01_tiles.shape[:2], "bool") + for c in range(nchan): + to_fill = x99_tiles[:, :, c] - x01_tiles[:, :, c] < +1e-3 + if to_fill.sum() > 0 and to_fill.sum() < x99_tiles[:, :, c].size: + fill_vals = np.nonzero(to_fill) + fill_neigh = np.nonzero(~to_fill) + nearest_neigh = ( + (fill_vals[0] - fill_neigh[0][:, np.newaxis])**2 + + (fill_vals[1] - fill_neigh[1][:, np.newaxis])**2).argmin(axis=0) + x01_tiles[fill_vals[0], fill_vals[1], + c] = x01_tiles[fill_neigh[0][nearest_neigh], + fill_neigh[1][nearest_neigh], c] + x99_tiles[fill_vals[0], fill_vals[1], + c] = x99_tiles[fill_neigh[0][nearest_neigh], + fill_neigh[1][nearest_neigh], c] + elif to_fill.sum() > 0 and to_fill.sum() == x99_tiles[:, :, c].size: + x01_tiles[:, :, c] = 0 + x99_tiles[:, :, c] = 1 + x01_tiles_z.append(x01_tiles) + x99_tiles_z.append(x99_tiles) + + x01_tiles_z = np.array(x01_tiles_z) + x99_tiles_z = np.array(x99_tiles_z) + # do not smooth over z-axis if not normalizing separately per plane + for a in range(2): + x01_tiles_z = gaussian_filter1d(x01_tiles_z, 1, axis=a) + x99_tiles_z = gaussian_filter1d(x99_tiles_z, 1, axis=a) + if norm3D: + smooth3D = 1 if smooth3D == 0 else smooth3D + x01_tiles_z = gaussian_filter1d(x01_tiles_z, smooth3D, axis=a) + x99_tiles_z = gaussian_filter1d(x99_tiles_z, smooth3D, axis=a) + + if not norm3D and Lz > 1: + x01 = np.zeros((len(x01_tiles_z), Ly, Lx, nchan), "float32") + x99 = np.zeros((len(x01_tiles_z), Ly, Lx, nchan), "float32") + for z in range(Lz): + x01_rsz = cv2.resize(x01_tiles_z[z], (Lx, Ly), + interpolation=cv2.INTER_LINEAR) + x01[z] = x01_rsz[..., np.newaxis] if nchan == 1 else x01_rsz + x99_rsz = cv2.resize(x99_tiles_z[z], (Lx, Ly), + interpolation=cv2.INTER_LINEAR) + x99[z] = x99_rsz[..., np.newaxis] if nchan == 1 else x01_rsz + if (x99 - x01).min() < 1e-3: + raise 
ZeroDivisionError( + "cannot use norm3D=False with tile_norm, sample is too sparse; set norm3D=True or tile_norm=0" + ) + else: + x01 = cv2.resize(x01_tiles_z.mean(axis=0), (Lx, Ly), + interpolation=cv2.INTER_LINEAR) + x99 = cv2.resize(x99_tiles_z.mean(axis=0), (Lx, Ly), + interpolation=cv2.INTER_LINEAR) + if x01.ndim < 3: + x01 = x01[..., np.newaxis] + x99 = x99[..., np.newaxis] + + if is1c: + img, x01, x99 = img.squeeze(), x01.squeeze(), x99.squeeze() + elif not is3D: + img, x01, x99 = img[0], x01[0], x99[0] + + # normalize + img -= x01 + img /= (x99 - x01) + + return img + + +def gaussian_kernel(sigma, Ly, Lx, device=torch.device("cpu")): + """ + Generates a 2D Gaussian kernel. + + Args: + sigma (float): Standard deviation of the Gaussian distribution. + Ly (int): Number of pixels in the y-axis. + Lx (int): Number of pixels in the x-axis. + device (torch.device, optional): Device to store the kernel tensor. Defaults to torch.device("cpu"). + + Returns: + torch.Tensor: 2D Gaussian kernel tensor. + + """ + y = torch.linspace(-Ly / 2, Ly / 2 + 1, Ly, device=device) + x = torch.linspace(-Ly / 2, Ly / 2 + 1, Lx, device=device) + y, x = torch.meshgrid(y, x, indexing="ij") + kernel = torch.exp(-(y**2 + x**2) / (2 * sigma**2)) + kernel /= kernel.sum() + return kernel + + +def smooth_sharpen_img(img, smooth_radius=6, sharpen_radius=12, + device=torch.device("cpu"), is3D=False): + """Sharpen blurry images with surround subtraction and/or smooth noisy images. + + Args: + img (float32): Array that's (Lz x) Ly x Lx (x nchan). + smooth_radius (float, optional): Size of gaussian smoothing filter, recommended to be 1/10-1/4 of cell diameter + (if also sharpening, should be 2-3x smaller than sharpen_radius). Defaults to 6. + sharpen_radius (float, optional): Size of gaussian surround filter, recommended to be 1/8-1/2 of cell diameter + (if also smoothing, should be 2-3x larger than smooth_radius). Defaults to 12. + device (torch.device, optional): Device on which to perform sharpening. + Will be faster on GPU but need to ensure GPU has RAM for image. Defaults to torch.device("cpu"). + is3D (bool, optional): If image is 3D stack (only necessary to set if img.ndim==3). Defaults to False. + + Returns: + img_sharpen (float32): Array that's (Lz x) Ly x Lx (x nchan). 
+ """ + img_sharpen = torch.from_numpy(img.astype("float32")).to(device) + shape = img_sharpen.shape + + is1c = True if img_sharpen.ndim == 2 or (is3D and img_sharpen.ndim == 3) else False + is3D = True if img_sharpen.ndim > 3 or (is3D and img_sharpen.ndim == 3) else False + img_sharpen = img_sharpen.unsqueeze(-1) if is1c else img_sharpen + img_sharpen = img_sharpen.unsqueeze(0) if img_sharpen.ndim == 3 else img_sharpen + Lz, Ly, Lx, nchan = img_sharpen.shape + + if smooth_radius > 0: + kernel = gaussian_kernel(smooth_radius, Ly, Lx, device=device) + if sharpen_radius > 0: + kernel += -1 * gaussian_kernel(sharpen_radius, Ly, Lx, device=device) + elif sharpen_radius > 0: + kernel = -1 * gaussian_kernel(sharpen_radius, Ly, Lx, device=device) + kernel[Ly // 2, Lx // 2] = 1 + + fhp = fft2(kernel) + for z in range(Lz): + for c in range(nchan): + img_filt = torch.real(ifft2( + fft2(img_sharpen[z, :, :, c]) * torch.conj(fhp))) + img_filt = fftshift(img_filt) + img_sharpen[z, :, :, c] = img_filt + + img_sharpen = img_sharpen.reshape(shape) + return img_sharpen.cpu().numpy() + + +def move_axis(img, m_axis=-1, first=True): + """ move axis m_axis to first or last position """ + if m_axis == -1: + m_axis = img.ndim - 1 + m_axis = min(img.ndim - 1, m_axis) + axes = np.arange(0, img.ndim) + if first: + axes[1:m_axis + 1] = axes[:m_axis] + axes[0] = m_axis + else: + axes[m_axis:-1] = axes[m_axis + 1:] + axes[-1] = m_axis + img = img.transpose(tuple(axes)) + return img + + +def move_min_dim(img, force=False): + """Move the minimum dimension last as channels if it is less than 10 or force is True. + + Args: + img (ndarray): The input image. + force (bool, optional): If True, the minimum dimension will always be moved. + Defaults to False. + + Returns: + ndarray: The image with the minimum dimension moved to the last axis as channels. + """ + if len(img.shape) > 2: + min_dim = min(img.shape) + if min_dim < 10 or force: + if img.shape[-1] == min_dim: + channel_axis = -1 + else: + channel_axis = (img.shape).index(min_dim) + img = move_axis(img, m_axis=channel_axis, first=False) + return img + + +def update_axis(m_axis, to_squeeze, ndim): + """ + Squeeze the axis value based on the given parameters. + + Args: + m_axis (int): The current axis value. + to_squeeze (numpy.ndarray): An array of indices to squeeze. + ndim (int): The number of dimensions. + + Returns: + int or None: The updated axis value. + """ + if m_axis == -1: + m_axis = ndim - 1 + if (to_squeeze == m_axis).sum() == 1: + m_axis = None + else: + inds = np.ones(ndim, bool) + inds[to_squeeze] = False + m_axis = np.nonzero(np.arange(0, ndim)[inds] == m_axis)[0] + if len(m_axis) > 0: + m_axis = m_axis[0] + else: + m_axis = None + return m_axis + + +def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, nchan=2): + """Converts the image to have the z-axis first, channels last. + + Args: + x (numpy.ndarray or torch.Tensor): The input image. + channels (list or None): The list of channels to use (ones-based, 0=gray). If None, all channels are kept. + channel_axis (int or None): The axis of the channels in the input image. If None, the axis is determined automatically. + z_axis (int or None): The axis of the z-dimension in the input image. If None, the axis is determined automatically. + do_3D (bool): Whether to process the image in 3D mode. Defaults to False. + nchan (int): The number of channels to keep if the input image has more than nchan channels. + + Returns: + numpy.ndarray: The converted image. 
+ + Raises: + ValueError: If the input image has less than two channels and channels are not specified. + ValueError: If the input image is 2D and do_3D is True. + ValueError: If the input image is 4D and do_3D is False. + """ + # check if image is a torch array instead of numpy array + # converts torch to numpy + ndim = x.ndim + if torch.is_tensor(x): + transforms_logger.warning("torch array used as input, converting to numpy") + x = x.cpu().numpy() + + # squeeze image, and if channel_axis or z_axis given, transpose image + if x.ndim > 3: + to_squeeze = np.array([int(isq) for isq, s in enumerate(x.shape) if s == 1]) + # remove channel axis if number of channels is 1 + if len(to_squeeze) > 0: + channel_axis = update_axis( + channel_axis, to_squeeze, + x.ndim) if channel_axis is not None else None + z_axis = update_axis(z_axis, to_squeeze, + x.ndim) if z_axis is not None else None + x = x.squeeze() + + # put z axis first + if z_axis is not None and x.ndim > 2 and z_axis != 0: + x = move_axis(x, m_axis=z_axis, first=True) + if channel_axis is not None: + channel_axis += 1 + z_axis = 0 + elif z_axis is None and x.ndim > 2 and channels is not None and min(x.shape) > 5 : + # if there are > 5 channels and channels!=None, assume first dimension is z + min_dim = min(x.shape) + if min_dim != channel_axis: + z_axis = (x.shape).index(min_dim) + if z_axis != 0: + x = move_axis(x, m_axis=z_axis, first=True) + if channel_axis is not None: + channel_axis += 1 + transforms_logger.warning(f"z_axis not specified, assuming it is dim {z_axis}") + transforms_logger.warning(f"if this is actually the channel_axis, use 'model.eval(channel_axis={z_axis}, ...)'") + z_axis = 0 + + if z_axis is not None: + if x.ndim == 3: + x = x[..., np.newaxis] + + # put channel axis last + if channel_axis is not None and x.ndim > 2: + x = move_axis(x, m_axis=channel_axis, first=False) + elif x.ndim == 2: + x = x[:, :, np.newaxis] + + if do_3D: + if ndim < 3: + transforms_logger.critical("ERROR: cannot process 2D images in 3D mode") + raise ValueError("ERROR: cannot process 2D images in 3D mode") + elif x.ndim < 4: + x = x[..., np.newaxis] + + if channel_axis is None: + x = move_min_dim(x) + + if x.ndim > 3: + transforms_logger.info( + "multi-stack tiff read in as having %d planes %d channels" % + (x.shape[0], x.shape[-1])) + + # convert to float32 + x = x.astype("float32") + + if channels is not None: + channels = channels[0] if len(channels) == 1 else channels + if len(channels) < 2: + transforms_logger.critical("ERROR: two channels not specified") + raise ValueError("ERROR: two channels not specified") + x = reshape(x, channels=channels) + + else: + # code above put channels last + if nchan is not None and x.shape[-1] > nchan: + transforms_logger.warning( + "WARNING: more than %d channels given, use 'channels' input for specifying channels - just using first %d channels to run processing" + % (nchan, nchan)) + x = x[..., :nchan] + + # if not do_3D and x.ndim > 3: + # transforms_logger.critical("ERROR: cannot process 4D images in 2D mode") + # raise ValueError("ERROR: cannot process 4D images in 2D mode") + + if nchan is not None and x.shape[-1] < nchan: + x = np.concatenate((x, np.tile(np.zeros_like(x), (1, 1, nchan - 1))), + axis=-1) + + return x + + +def reshape(data, channels=[0, 0], chan_first=False): + """Reshape data using channels. + + Args: + data (numpy.ndarray): The input data. It should have shape (Z x ) Ly x Lx x nchan + if data.ndim==3 and data.shape[0]<8, it is assumed to be nchan x Ly x Lx. 
+ channels (list of int, optional): The channels to use for reshaping. The first element + of the list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). The + second element of the list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). + For instance, to train on grayscale images, input [0,0]. To train on images with cells + in green and nuclei in blue, input [2,3]. Defaults to [0, 0]. + chan_first (bool, optional): Whether to return the reshaped data with channel as the first + dimension. Defaults to False. + + Returns: + numpy.ndarray: The reshaped data with shape (Z x ) Ly x Lx x nchan (if chan_first==False). + """ + if data.ndim < 3: + data = data[:, :, np.newaxis] + elif data.shape[0] < 8 and data.ndim == 3: + data = np.transpose(data, (1, 2, 0)) + + # use grayscale image + if data.shape[-1] == 1: + data = np.concatenate((data, np.zeros(data.shape, "float32")), axis=-1) + else: + if channels[0] == 0: + data = data.mean(axis=-1, keepdims=True) + data = np.concatenate((data, np.zeros(data.shape, "float32")), axis=-1) + else: + chanid = [channels[0] - 1] + if channels[1] > 0: + chanid.append(channels[1] - 1) + data = data[..., chanid] + for i in range(data.shape[-1]): + if np.ptp(data[..., i]) == 0.0: + if i == 0: + warnings.warn("'chan to seg' to seg has value range of ZERO") + else: + warnings.warn( + "'chan2 (opt)' has value range of ZERO, can instead set chan2 to 0" + ) + if data.shape[-1] == 1: + data = np.concatenate((data, np.zeros(data.shape, "float32")), axis=-1) + if chan_first: + if data.ndim == 4: + data = np.transpose(data, (3, 0, 1, 2)) + else: + data = np.transpose(data, (2, 0, 1)) + return data + + +def normalize_img(img, normalize=True, norm3D=True, invert=False, lowhigh=None, + percentile=(1., 99.), sharpen_radius=0, smooth_radius=0, + tile_norm_blocksize=0, tile_norm_smooth3D=1, axis=-1): + """Normalize each channel of the image with optional inversion, smoothing, and sharpening. + + Args: + img (ndarray): The input image. It should have at least 3 dimensions. + If it is 4-dimensional, it assumes the first non-channel axis is the Z dimension. + normalize (bool, optional): Whether to perform normalization. Defaults to True. + norm3D (bool, optional): Whether to normalize in 3D. If True, the entire 3D stack will + be normalized per channel. If False, normalization is applied per Z-slice. Defaults to False. + invert (bool, optional): Whether to invert the image. Useful if cells are dark instead of bright. + Defaults to False. + lowhigh (tuple or ndarray, optional): The lower and upper bounds for normalization. + Can be a tuple of two values (applied to all channels) or an array of shape (nchan, 2) + for per-channel normalization. Incompatible with smoothing and sharpening. + Defaults to None. + percentile (tuple, optional): The lower and upper percentiles for normalization. If provided, it should be + a tuple of two values. Each value should be between 0 and 100. Defaults to (1.0, 99.0). + sharpen_radius (int, optional): The radius for sharpening the image. Defaults to 0. + smooth_radius (int, optional): The radius for smoothing the image. Defaults to 0. + tile_norm_blocksize (int, optional): The block size for tile-based normalization. Defaults to 0. + tile_norm_smooth3D (int, optional): The smoothness factor for tile-based normalization in 3D. Defaults to 1. + axis (int, optional): The channel axis to loop over for normalization. Defaults to -1. + + Returns: + ndarray: The normalized image of the same size. 
+ + Raises: + ValueError: If the image has less than 3 dimensions. + ValueError: If the provided lowhigh or percentile values are invalid. + ValueError: If the image is inverted without normalization. + + """ + if img.ndim < 3: + error_message = "Image needs to have at least 3 dimensions" + transforms_logger.critical(error_message) + raise ValueError(error_message) + + img_norm = img if img.dtype=="float32" else img.astype(np.float32) + if axis != -1 and axis != img_norm.ndim - 1: + img_norm = np.moveaxis(img_norm, axis, -1) # Move channel axis to last + + nchan = img_norm.shape[-1] + + # Validate and handle lowhigh bounds + if lowhigh is not None: + lowhigh = np.array(lowhigh) + if lowhigh.shape == (2,): + lowhigh = np.tile(lowhigh, (nchan, 1)) # Expand to per-channel bounds + elif lowhigh.shape != (nchan, 2): + error_message = "`lowhigh` must have shape (2,) or (nchan, 2)" + transforms_logger.critical(error_message) + raise ValueError(error_message) + + # Validate percentile + if percentile is None: + percentile = (1.0, 99.0) + elif not (0 <= percentile[0] < percentile[1] <= 100): + error_message = "Invalid percentile range, should be between 0 and 100" + transforms_logger.critical(error_message) + raise ValueError(error_message) + + # Apply normalization based on lowhigh or percentile + cgood = np.zeros(nchan, "bool") + if lowhigh is not None: + for c in range(nchan): + lower = lowhigh[c, 0] + upper = lowhigh[c, 1] + img_norm[..., c] -= lower + img_norm[..., c] /= (upper - lower) + cgood[c] = True + else: + # Apply sharpening and smoothing if specified + if sharpen_radius > 0 or smooth_radius > 0: + img_norm = smooth_sharpen_img( + img_norm, sharpen_radius=sharpen_radius, smooth_radius=smooth_radius + ) + + # Apply tile-based normalization or standard normalization + if tile_norm_blocksize > 0: + img_norm = normalize99_tile( + img_norm, + blocksize=tile_norm_blocksize, + lower=percentile[0], + upper=percentile[1], + smooth3D=tile_norm_smooth3D, + norm3D=norm3D, + ) + elif normalize: + if img_norm.ndim == 3 or norm3D: # i.e. if YXC, or ZYXC with norm3D=True + for c in range(nchan): + if np.ptp(img_norm[..., c]) > 0.: + img_norm[..., c] = normalize99( + img_norm[..., c], + lower=percentile[0], + upper=percentile[1], + copy=False, downsample=True, + ) + cgood[c] = True + else: # i.e. if ZYXC with norm3D=False then per Z-slice + for z in range(img_norm.shape[0]): + for c in range(nchan): + if np.ptp(img_norm[z, ..., c]) > 0.: + img_norm[z, ..., c] = normalize99( + img_norm[z, ..., c], + lower=percentile[0], + upper=percentile[1], + copy=False, downsample=True, + ) + cgood[c] = True + + + if invert: + if lowhigh is not None or tile_norm_blocksize > 0 or normalize: + for c in range(nchan): + if cgood[c]: + img_norm[..., c] = 1 - img_norm[..., c] + else: + error_message = "Cannot invert image without normalization" + transforms_logger.critical(error_message) + raise ValueError(error_message) + + # Move channel axis back to the original position + if axis != -1 and axis != img_norm.ndim - 1: + img_norm = np.moveaxis(img_norm, -1, axis) + + return img_norm + +def resize_safe(img, Ly, Lx, interpolation=cv2.INTER_LINEAR): + """OpenCV resize function does not support uint32. + + This function converts the image to float32 before resizing and then converts it back to uint32. Not safe! + References issue: https://github.com/MouseLand/cellpose/issues/937 + + Implications: + * Runtime: Runtime increases by 5x-50x due to type casting. 
However, with resizing being very efficient, this is not + a big issue. A 10,000x10,000 image takes 0.47s instead of 0.016s to cast and resize on 32 cores on GPU. + * Memory: However, memory usage increases. Not tested by how much. + + Args: + img (ndarray): Image of size [Ly x Lx]. + Ly (int): Desired height of the resized image. + Lx (int): Desired width of the resized image. + interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR. + + Returns: + ndarray: Resized image of size [Ly x Lx]. + + """ + + # cast image + cast = img.dtype == np.uint32 + if cast: + img = img.astype(np.float32) + + # resize + img = cv2.resize(img, (Lx, Ly), interpolation=interpolation) + + # cast back + if cast: + img = img.round().astype(np.uint32) + + return img + + +def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEAR, + no_channels=False): + """Resize image for computing flows / unresize for computing dynamics. + + Args: + img0 (ndarray): Image of size [Y x X x nchan] or [Lz x Y x X x nchan] or [Lz x Y x X]. + Ly (int, optional): Desired height of the resized image. Defaults to None. + Lx (int, optional): Desired width of the resized image. Defaults to None. + rsz (float, optional): Resize coefficient(s) for the image. If Ly is None, rsz is used. Defaults to None. + interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR. + no_channels (bool, optional): Flag indicating whether to treat the third dimension as a channel. + Defaults to False. + + Returns: + ndarray: Resized image of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan]. + + Raises: + ValueError: If Ly is None and rsz is None. + + """ + if Ly is None and rsz is None: + error_message = "must give size to resize to or factor to use for resizing" + transforms_logger.critical(error_message) + raise ValueError(error_message) + + if Ly is None: + # determine Ly and Lx using rsz + if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray): + rsz = [rsz, rsz] + if no_channels: + Ly = int(img0.shape[-2] * rsz[-2]) + Lx = int(img0.shape[-1] * rsz[-1]) + else: + Ly = int(img0.shape[-3] * rsz[-2]) + Lx = int(img0.shape[-2] * rsz[-1]) + + # no_channels useful for z-stacks, so the third dimension is not treated as a channel + # but if this is called for grayscale images, they first become [Ly,Lx,2] so ndim=3 but + if (img0.ndim > 2 and no_channels) or (img0.ndim == 4 and not no_channels): + if Ly == 0 or Lx == 0: + raise ValueError( + "anisotropy too high / low -- not enough pixels to resize to ratio") + for i, img in enumerate(img0): + imgi = resize_safe(img, Ly, Lx, interpolation=interpolation) + if i==0: + if no_channels: + imgs = np.zeros((img0.shape[0], Ly, Lx), imgi.dtype) + else: + imgs = np.zeros((img0.shape[0], Ly, Lx, img0.shape[-1]), imgi.dtype) + imgs[i] = imgi if imgi.ndim > 2 or no_channels else imgi[..., np.newaxis] + else: + imgs = resize_safe(img0, Ly, Lx, interpolation=interpolation) + return imgs + +def get_pad_yx(Ly, Lx, div=16, extra=1, min_size=None): + if min_size is None or Ly >= min_size[-2]: + Lpad = int(div * np.ceil(Ly / div) - Ly) + else: + Lpad = min_size[-2] - Ly + ypad1 = extra * div // 2 + Lpad // 2 + ypad2 = extra * div // 2 + Lpad - Lpad // 2 + if min_size is None or Lx >= min_size[-1]: + Lpad = int(div * np.ceil(Lx / div) - Lx) + else: + Lpad = min_size[-1] - Lx + xpad1 = extra * div // 2 + Lpad // 2 + xpad2 = extra * div // 2 + Lpad - Lpad // 2 + + return ypad1, ypad2, xpad1, xpad2 + + +def pad_image_ND(img0, div=16, extra=1, 
min_size=None, zpad=False): + """Pad image for test-time so that its dimensions are a multiple of 16 (2D or 3D). + + Args: + img0 (ndarray): Image of size [nchan (x Lz) x Ly x Lx]. + div (int, optional): Divisor for padding. Defaults to 16. + extra (int, optional): Extra padding. Defaults to 1. + min_size (tuple, optional): Minimum size of the image. Defaults to None. + + Returns: + A tuple containing (I, ysub, xsub) or (I, ysub, xsub, zsub), I is padded image, -sub are ranges of pixels in the padded image corresponding to img0. + + """ + Ly, Lx = img0.shape[-2:] + ypad1, ypad2, xpad1, xpad2 = get_pad_yx(Ly, Lx, div=div, extra=extra, min_size=min_size) + + if img0.ndim > 3: + if zpad: + Lpad = int(div * np.ceil(img0.shape[-3] / div) - img0.shape[-3]) + zpad1 = extra * div // 2 + Lpad // 2 + zpad2 = extra * div // 2 + Lpad - Lpad // 2 + else: + zpad1, zpad2 = 0, 0 + pads = np.array([[0, 0], [zpad1, zpad2], [ypad1, ypad2], [xpad1, xpad2]]) + else: + pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]]) + + I = np.pad(img0, pads, mode="constant") + + ysub = np.arange(ypad1, ypad1 + Ly) + xsub = np.arange(xpad1, xpad1 + Lx) + if zpad: + zsub = np.arange(zpad1, zpad1 + img0.shape[-3]) + return I, ysub, xsub, zsub + else: + return I, ysub, xsub + + +def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=False, + zcrop=48, do_flip=True, rotate=True, rescale=None, unet=False, + random_per_image=True): + """Augmentation by random rotation and resizing. + + Args: + X (list of ND-arrays, float): List of image arrays of size [nchan x Ly x Lx] or [Ly x Lx]. + Y (list of ND-arrays, float, optional): List of image labels of size [nlabels x Ly x Lx] or [Ly x Lx]. + The 1st channel of Y is always nearest-neighbor interpolated (assumed to be masks or 0-1 representation). + If Y.shape[0]==3 and not unet, then the labels are assumed to be [cell probability, Y flow, X flow]. + If unet, second channel is dist_to_bound. Defaults to None. + scale_range (float, optional): Range of resizing of images for augmentation. + Images are resized by (1-scale_range/2) + scale_range * np.random.rand(). Defaults to 1.0. + xy (tuple, int, optional): Size of transformed images to return. Defaults to (224,224). + do_flip (bool, optional): Whether or not to flip images horizontally. Defaults to True. + rotate (bool, optional): Whether or not to rotate images. Defaults to True. + rescale (array, float, optional): How much to resize images by before performing augmentations. Defaults to None. + unet (bool, optional): Whether or not to use unet. Defaults to False. + random_per_image (bool, optional): Different random rotate and resize per image. Defaults to True. + + Returns: + A tuple containing (imgi, lbl, scale): imgi (ND-array, float): Transformed images in array [nimg x nchan x xy[0] x xy[1]]; + lbl (ND-array, float): Transformed labels in array [nimg x nchan x xy[0] x xy[1]]; + scale (array, float): Amount each image was resized by. 
+ """ + scale_range = max(0, min(2, float(scale_range))) + nimg = len(X) + if X[0].ndim > 2: + nchan = X[0].shape[0] + else: + nchan = 1 + if do_3D and X[0].ndim > 3: + shape = (zcrop, xy[0], xy[1]) + else: + shape = (xy[0], xy[1]) + imgi = np.zeros((nimg, nchan, *shape), "float32") + + lbl = [] + if Y is not None: + if Y[0].ndim > 2: + nt = Y[0].shape[0] + else: + nt = 1 + lbl = np.zeros((nimg, nt, *shape), np.float32) + + scale = np.ones(nimg, np.float32) + + for n in range(nimg): + + if random_per_image or n == 0: + Ly, Lx = X[n].shape[-2:] + # generate random augmentation parameters + flip = np.random.rand() > .5 + theta = np.random.rand() * np.pi * 2 if rotate else 0. + scale[n] = (1 - scale_range / 2) + scale_range * np.random.rand() + if rescale is not None: + scale[n] *= 1. / rescale[n] + dxy = np.maximum(0, np.array([Lx * scale[n] - xy[1], + Ly * scale[n] - xy[0]])) + dxy = (np.random.rand(2,) - .5) * dxy + + # create affine transform + cc = np.array([Lx / 2, Ly / 2]) + cc1 = cc - np.array([Lx - xy[1], Ly - xy[0]]) / 2 + dxy + pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])]) + pts2 = np.float32([ + cc1, + cc1 + scale[n] * np.array([np.cos(theta), np.sin(theta)]), + cc1 + scale[n] * + np.array([np.cos(np.pi / 2 + theta), + np.sin(np.pi / 2 + theta)]) + ]) + M = cv2.getAffineTransform(pts1, pts2) + + img = X[n].copy() + if Y is not None: + labels = Y[n].copy() + if labels.ndim < 3: + labels = labels[np.newaxis, :, :] + + if do_3D: + Lz = X[n].shape[-3] + flip_z = np.random.rand() > .5 + lz = int(np.round(zcrop / scale[n])) + iz = np.random.randint(0, Lz - lz) + img = img[:,iz:iz + lz,:,:] + if Y is not None: + labels = labels[:,iz:iz + lz,:,:] + + if do_flip: + if flip: + img = img[..., ::-1] + if Y is not None: + labels = labels[..., ::-1] + if nt > 1 and not unet: + labels[-1] = -labels[-1] + if do_3D and flip_z: + img = img[:, ::-1] + if Y is not None: + labels = labels[:,::-1] + if nt > 1 and not unet: + labels[-3] = -labels[-3] + + for k in range(nchan): + if do_3D: + img0 = np.zeros((lz, xy[0], xy[1]), "float32") + for z in range(lz): + I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]), + flags=cv2.INTER_LINEAR) + img0[z] = I + if scale[n] != 1.0: + for y in range(imgi.shape[-2]): + imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop), + interpolation=cv2.INTER_LINEAR) + else: + imgi[n, k] = img0 + else: + I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR) + imgi[n, k] = I + + if Y is not None: + for k in range(nt): + flag = cv2.INTER_NEAREST if k == 0 else cv2.INTER_LINEAR + if do_3D: + lbl0 = np.zeros((lz, xy[0], xy[1]), "float32") + for z in range(lz): + I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]), + flags=flag) + lbl0[z] = I + if scale[n] != 1.0: + for y in range(lbl.shape[-2]): + lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop), + interpolation=flag) + else: + lbl[n, k] = lbl0 + else: + lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag) + + if nt > 1 and not unet: + v1 = lbl[n, -1].copy() + v2 = lbl[n, -2].copy() + lbl[n, -2] = (-v1 * np.sin(-theta) + v2 * np.cos(-theta)) + lbl[n, -1] = (v1 * np.cos(-theta) + v2 * np.sin(-theta)) + + return imgi, lbl, scale diff --git a/cellpose/utils.py b/cellpose/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6549608ebd8fd7309fe17322b176bd8598e642 --- /dev/null +++ b/cellpose/utils.py @@ -0,0 +1,659 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 
+""" +import logging +import os, tempfile, shutil, io +from tqdm import tqdm, trange +from urllib.request import urlopen +import cv2 +from scipy.ndimage import find_objects, gaussian_filter, generate_binary_structure, label, maximum_filter1d, binary_fill_holes +from scipy.spatial import ConvexHull +import numpy as np +import colorsys +import fastremap +from multiprocessing import Pool, cpu_count + +from . import metrics + +try: + from skimage.morphology import remove_small_holes + SKIMAGE_ENABLED = True +except: + SKIMAGE_ENABLED = False + + +class TqdmToLogger(io.StringIO): + """ + Output stream for TQDM which will output to logger module instead of + the StdOut. + """ + logger = None + level = None + buf = "" + + def __init__(self, logger, level=None): + super(TqdmToLogger, self).__init__() + self.logger = logger + self.level = level or logging.INFO + + def write(self, buf): + self.buf = buf.strip("\r\n\t ") + + def flush(self): + self.logger.log(self.level, self.buf) + + +def rgb_to_hsv(arr): + rgb_to_hsv_channels = np.vectorize(colorsys.rgb_to_hsv) + r, g, b = np.rollaxis(arr, axis=-1) + h, s, v = rgb_to_hsv_channels(r, g, b) + hsv = np.stack((h, s, v), axis=-1) + return hsv + + +def hsv_to_rgb(arr): + hsv_to_rgb_channels = np.vectorize(colorsys.hsv_to_rgb) + h, s, v = np.rollaxis(arr, axis=-1) + r, g, b = hsv_to_rgb_channels(h, s, v) + rgb = np.stack((r, g, b), axis=-1) + return rgb + + +def download_url_to_file(url, dst, progress=True): + r"""Download object at the given URL to a local path. + Thanks to torch, slightly modified + Args: + url (string): URL of the object to download + dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` + progress (bool, optional): whether or not to display a progress bar to stderr + Default: True + """ + file_size = None + import ssl + ssl._create_default_https_context = ssl._create_unverified_context + u = urlopen(url) + meta = u.info() + if hasattr(meta, "getheaders"): + content_length = meta.getheaders("Content-Length") + else: + content_length = meta.get_all("Content-Length") + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + # We deliberately save it in a temp file and move it after + dst = os.path.expanduser(dst) + dst_dir = os.path.dirname(dst) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) + try: + with tqdm(total=file_size, disable=not progress, unit="B", unit_scale=True, + unit_divisor=1024) as pbar: + while True: + buffer = u.read(8192) + if len(buffer) == 0: + break + f.write(buffer) + pbar.update(len(buffer)) + f.close() + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) + + +def distance_to_boundary(masks): + """Get the distance to the boundary of mask pixels. + + Args: + masks (int, 2D or 3D array): The masks array. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels. + + Returns: + dist_to_bound (2D or 3D array): The distance to the boundary. Size [Ly x Lx] or [Lz x Ly x Lx]. + + Raises: + ValueError: If the masks array is not 2D or 3D. 
+ + """ + if masks.ndim > 3 or masks.ndim < 2: + raise ValueError("distance_to_boundary takes 2D or 3D array, not %dD array" % + masks.ndim) + dist_to_bound = np.zeros(masks.shape, np.float64) + + if masks.ndim == 3: + for i in range(masks.shape[0]): + dist_to_bound[i] = distance_to_boundary(masks[i]) + return dist_to_bound + else: + slices = find_objects(masks) + for i, si in enumerate(slices): + if si is not None: + sr, sc = si + mask = (masks[sr, sc] == (i + 1)).astype(np.uint8) + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T + ypix, xpix = np.nonzero(mask) + min_dist = ((ypix[:, np.newaxis] - pvr)**2 + + (xpix[:, np.newaxis] - pvc)**2).min(axis=1) + dist_to_bound[ypix + sr.start, xpix + sc.start] = min_dist + return dist_to_bound + + +def masks_to_edges(masks, threshold=1.0): + """Get edges of masks as a 0-1 array. + + Args: + masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels. + threshold (float, optional): Threshold value for distance to boundary. Defaults to 1.0. + + Returns: + edges (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are edge pixels. + """ + dist_to_bound = distance_to_boundary(masks) + edges = (dist_to_bound < threshold) * (masks > 0) + return edges + + +def remove_edge_masks(masks, change_index=True): + """Removes masks with pixels on the edge of the image. + + Args: + masks (int, 2D or 3D array): The masks to be processed. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels. + change_index (bool, optional): If True, after removing masks, changes the indexing so that there are no missing label numbers. Defaults to True. + + Returns: + outlines (2D or 3D array): The processed masks. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels. + """ + slices = find_objects(masks.astype(int)) + for i, si in enumerate(slices): + remove = False + if si is not None: + for d, sid in enumerate(si): + if sid.start == 0 or sid.stop == masks.shape[d]: + remove = True + break + if remove: + masks[si][masks[si] == i + 1] = 0 + shape = masks.shape + if change_index: + _, masks = np.unique(masks, return_inverse=True) + masks = np.reshape(masks, shape).astype(np.int32) + + return masks + + +def masks_to_outlines(masks): + """Get outlines of masks as a 0-1 array. + + Args: + masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels. + + Returns: + outlines (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are outlines. + """ + if masks.ndim > 3 or masks.ndim < 2: + raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" % + masks.ndim) + outlines = np.zeros(masks.shape, bool) + + if masks.ndim == 3: + for i in range(masks.shape[0]): + outlines[i] = masks_to_outlines(masks[i]) + return outlines + else: + slices = find_objects(masks.astype(int)) + for i, si in enumerate(slices): + if si is not None: + sr, sc = si + mask = (masks[sr, sc] == (i + 1)).astype(np.uint8) + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T + vr, vc = pvr + sr.start, pvc + sc.start + outlines[vr, vc] = 1 + return outlines + + +def outlines_list(masks, multiprocessing_threshold=1000, multiprocessing=None): + """Get outlines of masks as a list to loop over for plotting. 
+ + Args: + masks (ndarray): Array of masks. + multiprocessing_threshold (int, optional): Threshold for enabling multiprocessing. Defaults to 1000. + multiprocessing (bool, optional): Flag to enable multiprocessing. Defaults to None. + + Returns: + list: List of outlines. + + Raises: + None + + Notes: + - This function is a wrapper for outlines_list_single and outlines_list_multi. + - Multiprocessing is disabled for Windows. + """ + # default to use multiprocessing if not few_masks, but allow user to override + if multiprocessing is None: + few_masks = np.max(masks) < multiprocessing_threshold + multiprocessing = not few_masks + + # disable multiprocessing for Windows + if os.name == "nt": + if multiprocessing: + logging.getLogger(__name__).warning( + "Multiprocessing is disabled for Windows") + multiprocessing = False + + if multiprocessing: + return outlines_list_multi(masks) + else: + return outlines_list_single(masks) + + +def outlines_list_single(masks): + """Get outlines of masks as a list to loop over for plotting. + + Args: + masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + + Returns: + list: List of outlines as pixel coordinates. + + """ + outpix = [] + for n in np.unique(masks)[1:]: + mn = masks == n + if mn.sum() > 0: + contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL, + method=cv2.CHAIN_APPROX_NONE) + contours = contours[-2] + cmax = np.argmax([c.shape[0] for c in contours]) + pix = contours[cmax].astype(int).squeeze() + if len(pix) > 4: + outpix.append(pix) + else: + outpix.append(np.zeros((0, 2))) + return outpix + + +def outlines_list_multi(masks, num_processes=None): + """ + Get outlines of masks as a list to loop over for plotting. + + Args: + masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + + Returns: + list: List of outlines as pixel coordinates. + """ + if num_processes is None: + num_processes = cpu_count() + + unique_masks = np.unique(masks)[1:] + with Pool(processes=num_processes) as pool: + outpix = pool.map(get_outline_multi, [(masks, n) for n in unique_masks]) + return outpix + + +def get_outline_multi(args): + """Get the outline of a specific mask in a multi-mask image. + + Args: + args (tuple): A tuple containing the masks and the mask number. + + Returns: + numpy.ndarray: The outline of the specified mask as an array of coordinates. + + """ + masks, n = args + mn = masks == n + if mn.sum() > 0: + contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL, + method=cv2.CHAIN_APPROX_NONE) + contours = contours[-2] + cmax = np.argmax([c.shape[0] for c in contours]) + pix = contours[cmax].astype(int).squeeze() + return pix if len(pix) > 4 else np.zeros((0, 2)) + return np.zeros((0, 2)) + + +def dilate_masks(masks, n_iter=5): + """Dilate masks by n_iter pixels. + + Args: + masks (ndarray): Array of masks. + n_iter (int, optional): Number of pixels to dilate the masks. Defaults to 5. + + Returns: + ndarray: Dilated masks. 
+ """ + dilated_masks = masks.copy() + for n in range(n_iter): + # define the structuring element to use for dilation + kernel = np.ones((3, 3), "uint8") + # find the distance to each mask (distances are zero within masks) + dist_transform = cv2.distanceTransform((dilated_masks == 0).astype("uint8"), + cv2.DIST_L2, 5) + # dilate each mask and assign to it the pixels along the border of the mask + # (does not allow dilation into other masks since dist_transform is zero there) + for i in range(1, np.max(masks) + 1): + mask = (dilated_masks == i).astype("uint8") + dilated_mask = cv2.dilate(mask, kernel, iterations=1) + dilated_mask = np.logical_and(dist_transform < 2, dilated_mask) + dilated_masks[dilated_mask > 0] = i + return dilated_masks + + +def get_perimeter(points): + """ + Calculate the perimeter of a set of points. + + Parameters: + points (ndarray): An array of points with shape (npoints, ndim). + + Returns: + float: The perimeter of the points. + + """ + if points.shape[0] > 4: + points = np.append(points, points[:1], axis=0) + return ((np.diff(points, axis=0)**2).sum(axis=1)**0.5).sum() + else: + return 0 + + +def get_mask_compactness(masks): + """ + Calculate the compactness of masks. + + Parameters: + masks (ndarray): Binary masks representing objects. + + Returns: + ndarray: Array of compactness values for each mask. + """ + perimeters = get_mask_perimeters(masks) + npoints = np.unique(masks, return_counts=True)[1][1:] + areas = npoints + compactness = 4 * np.pi * areas / perimeters**2 + compactness[perimeters == 0] = 0 + compactness[compactness > 1.0] = 1.0 + return compactness + + +def get_mask_perimeters(masks): + """ + Calculate the perimeters of the given masks. + + Parameters: + masks (numpy.ndarray): Binary masks representing objects. + + Returns: + numpy.ndarray: Array containing the perimeters of each mask. + """ + perimeters = np.zeros(masks.max()) + for n in range(masks.max()): + mn = masks == (n + 1) + if mn.sum() > 0: + contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL, + method=cv2.CHAIN_APPROX_NONE)[-2] + perimeters[n] = np.array( + [get_perimeter(c.astype(int).squeeze()) for c in contours]).sum() + + return perimeters + + +def circleMask(d0): + """ + Creates an array with indices which are the radius of that x,y point. + + Args: + d0 (tuple): Patch of (-d0, d0+1) over which radius is computed. + + Returns: + tuple: A tuple containing: + - rs (ndarray): Array of radii with shape (2*d0[0]+1, 2*d0[1]+1). + - dx (ndarray): Indices of the patch along the x-axis. + - dy (ndarray): Indices of the patch along the y-axis. + """ + dx = np.tile(np.arange(-d0[1], d0[1] + 1), (2 * d0[0] + 1, 1)) + dy = np.tile(np.arange(-d0[0], d0[0] + 1), (2 * d0[1] + 1, 1)) + dy = dy.transpose() + + rs = (dy**2 + dx**2)**0.5 + return rs, dx, dy + + +def get_mask_stats(masks_true): + """ + Calculate various statistics for the given binary masks. + + Parameters: + masks_true (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + + Returns: + convexity (ndarray): Convexity values for each mask. + solidity (ndarray): Solidity values for each mask. + compactness (ndarray): Compactness values for each mask. 
+ """ + mask_perimeters = get_mask_perimeters(masks_true) + + # disk for compactness + rs, dy, dx = circleMask(np.array([100, 100])) + rsort = np.sort(rs.flatten()) + + # area for solidity + npoints = np.unique(masks_true, return_counts=True)[1][1:] + areas = npoints - mask_perimeters / 2 - 1 + + compactness = np.zeros(masks_true.max()) + convexity = np.zeros(masks_true.max()) + solidity = np.zeros(masks_true.max()) + convex_perimeters = np.zeros(masks_true.max()) + convex_areas = np.zeros(masks_true.max()) + for ic in range(masks_true.max()): + points = np.array(np.nonzero(masks_true == (ic + 1))).T + if len(points) > 15 and mask_perimeters[ic] > 0: + med = np.median(points, axis=0) + # compute compactness of ROI + r2 = ((points - med)**2).sum(axis=1)**0.5 + compactness[ic] = (rsort[:r2.size].mean() + 1e-10) / r2.mean() + try: + hull = ConvexHull(points) + convex_perimeters[ic] = hull.area + convex_areas[ic] = hull.volume + except: + convex_perimeters[ic] = 0 + + convexity[mask_perimeters > 0.0] = (convex_perimeters[mask_perimeters > 0.0] / + mask_perimeters[mask_perimeters > 0.0]) + solidity[convex_areas > 0.0] = (areas[convex_areas > 0.0] / + convex_areas[convex_areas > 0.0]) + convexity = np.clip(convexity, 0.0, 1.0) + solidity = np.clip(solidity, 0.0, 1.0) + compactness = np.clip(compactness, 0.0, 1.0) + return convexity, solidity, compactness + + +def get_masks_unet(output, cell_threshold=0, boundary_threshold=0): + """Create masks using cell probability and cell boundary. + + Args: + output (ndarray): The output array containing cell probability and cell boundary. + cell_threshold (float, optional): The threshold value for cell probability. Defaults to 0. + boundary_threshold (float, optional): The threshold value for cell boundary. Defaults to 0. + + Returns: + ndarray: The masks representing the segmented cells. + + """ + cells = (output[..., 1] - output[..., 0]) > cell_threshold + selem = generate_binary_structure(cells.ndim, connectivity=1) + labels, nlabels = label(cells, selem) + + if output.shape[-1] > 2: + slices = find_objects(labels) + dists = 10000 * np.ones(labels.shape, np.float32) + mins = np.zeros(labels.shape, np.int32) + borders = np.logical_and(~(labels > 0), output[..., 2] > boundary_threshold) + pad = 10 + for i, slc in enumerate(slices): + if slc is not None: + slc_pad = tuple([ + slice(max(0, sli.start - pad), min(labels.shape[j], sli.stop + pad)) + for j, sli in enumerate(slc) + ]) + msk = (labels[slc_pad] == (i + 1)).astype(np.float32) + msk = 1 - gaussian_filter(msk, 5) + dists[slc_pad] = np.minimum(dists[slc_pad], msk) + mins[slc_pad][dists[slc_pad] == msk] = (i + 1) + labels[labels == 0] = borders[labels == 0] * mins[labels == 0] + + masks = labels + shape0 = masks.shape + _, masks = np.unique(masks, return_inverse=True) + masks = np.reshape(masks, shape0) + return masks + + +def stitch3D(masks, stitch_threshold=0.25): + """ + Stitch 2D masks into a 3D volume using a stitch_threshold on IOU. + + Args: + masks (list or ndarray): List of 2D masks. + stitch_threshold (float, optional): Threshold value for stitching. Defaults to 0.25. + + Returns: + list: List of stitched 3D masks. 
+ """ + mmax = masks[0].max() + empty = 0 + for i in trange(len(masks) - 1): + iou = metrics._intersection_over_union(masks[i + 1], masks[i])[1:, 1:] + if not iou.size and empty == 0: + masks[i + 1] = masks[i + 1] + mmax = masks[i + 1].max() + elif not iou.size and not empty == 0: + icount = masks[i + 1].max() + istitch = np.arange(mmax + 1, mmax + icount + 1, 1, masks.dtype) + mmax += icount + istitch = np.append(np.array(0), istitch) + masks[i + 1] = istitch[masks[i + 1]] + else: + iou[iou < stitch_threshold] = 0.0 + iou[iou < iou.max(axis=0)] = 0.0 + istitch = iou.argmax(axis=1) + 1 + ino = np.nonzero(iou.max(axis=1) == 0.0)[0] + istitch[ino] = np.arange(mmax + 1, mmax + len(ino) + 1, 1, masks.dtype) + mmax += len(ino) + istitch = np.append(np.array(0), istitch) + masks[i + 1] = istitch[masks[i + 1]] + empty = 1 + + return masks + + +def diameters(masks): + """ + Calculate the diameters of the objects in the given masks. + + Parameters: + masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + + Returns: + tuple: A tuple containing the median diameter and an array of diameters for each object. + + Examples: + >>> masks = np.array([[0, 1, 1], [1, 0, 0], [1, 1, 0]]) + >>> diameters(masks) + (1.0, array([1.41421356, 1.0, 1.0])) + """ + uniq, counts = fastremap.unique(masks.astype("int32"), return_counts=True) + counts = counts[1:] + md = np.median(counts**0.5) + if np.isnan(md): + md = 0 + md /= (np.pi**0.5) / 2 + return md, counts**0.5 + + +def radius_distribution(masks, bins): + """ + Calculate the radius distribution of masks. + + Args: + masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + bins (int): Number of bins for the histogram. + + Returns: + A tuple containing a normalized histogram of radii, median radius, array of radii. + + """ + unique, counts = np.unique(masks, return_counts=True) + counts = counts[unique != 0] + nb, _ = np.histogram((counts**0.5) * 0.5, bins) + nb = nb.astype(np.float32) + if nb.sum() > 0: + nb = nb / nb.sum() + md = np.median(counts**0.5) * 0.5 + if np.isnan(md): + md = 0 + md /= (np.pi**0.5) / 2 + return nb, md, (counts**0.5) / 2 + + +def size_distribution(masks): + """ + Calculates the size distribution of masks. + + Args: + masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + + Returns: + float: The ratio of the 25th percentile of mask sizes to the 75th percentile of mask sizes. + """ + counts = np.unique(masks, return_counts=True)[1][1:] + return np.percentile(counts, 25) / np.percentile(counts, 75) + + +def fill_holes_and_remove_small_masks(masks, min_size=15): + """ Fills holes in masks (2D/3D) and discards masks smaller than min_size. + + This function fills holes in each mask using scipy.ndimage.morphology.binary_fill_holes. + It also removes masks that are smaller than the specified min_size. + + Parameters: + masks (ndarray): Int, 2D or 3D array of labelled masks. + 0 represents no mask, while positive integers represent mask labels. + The size can be [Ly x Lx] or [Lz x Ly x Lx]. + min_size (int, optional): Minimum number of pixels per mask. + Masks smaller than min_size will be removed. + Set to -1 to turn off this functionality. Default is 15. + + Returns: + ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed. + 0 represents no mask, while positive integers represent mask labels. + The size is [Ly x Lx] or [Lz x Ly x Lx]. 
+ """ + + if masks.ndim > 3 or masks.ndim < 2: + raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" % + masks.ndim) + + slices = find_objects(masks) + j = 0 + for i, slc in enumerate(slices): + if slc is not None: + msk = masks[slc] == (i + 1) + npix = msk.sum() + if min_size > 0 and npix < min_size: + masks[slc][msk] = 0 + elif npix > 0: + if msk.ndim == 3: + for k in range(msk.shape[0]): + msk[k] = binary_fill_holes(msk[k]) + else: + msk = binary_fill_holes(msk) + masks[slc][msk] = (j + 1) + j += 1 + return masks diff --git a/cellpose/version.py b/cellpose/version.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bc34f08462dbb2c8e3c41c80240fbbf4334d38 --- /dev/null +++ b/cellpose/version.py @@ -0,0 +1,19 @@ +""" +Copyright Β© 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" + +from importlib.metadata import PackageNotFoundError, version +import sys +from platform import python_version +import torch + +try: + version = version("cellpose") +except PackageNotFoundError: + version = "unknown" + +version_str = f""" +cellpose version: \t{version} +platform: \t{sys.platform} +python version: \t{python_version()} +torch version: \t{torch.__version__}""" diff --git a/example_images/er_cos-7.png b/example_images/er_cos-7.png new file mode 100644 index 0000000000000000000000000000000000000000..3d88935594de634a67da1fb23e62c1ad5cd85f48 --- /dev/null +++ b/example_images/er_cos-7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c56d2010cee1cd6d023979577c59b07ce97b1026b76e7562dd23c7bdf9393ca +size 153569 diff --git a/example_images/er_hela.png b/example_images/er_hela.png new file mode 100644 index 0000000000000000000000000000000000000000..2ba48891af510b6d58d8e94ce4517017bfcf3923 --- /dev/null +++ b/example_images/er_hela.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:823ba07e97664d1887382272ff086796956ecdc567a59cfe6f064b1ad5fb1483 +size 124319 diff --git a/example_images/f-actin_cos-7.png b/example_images/f-actin_cos-7.png new file mode 100644 index 0000000000000000000000000000000000000000..0f2216147fa2feea0982b32374806d552528f3af --- /dev/null +++ b/example_images/f-actin_cos-7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00a8ee7bcfb7d16c9ee8e95aaf8f54417877085bdd819f19f99459931efb170d +size 180910 diff --git a/example_images/microtubules_hela.png b/example_images/microtubules_hela.png new file mode 100644 index 0000000000000000000000000000000000000000..73c3ed75b39ac756fd071865f8228b6d2992acfe --- /dev/null +++ b/example_images/microtubules_hela.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3733f4a2c8de149cb94f91a24e4b7fb02985c62bb573410d94b4575ff4399d7d +size 90439 diff --git a/example_images/mitochondria_bpae.png b/example_images/mitochondria_bpae.png new file mode 100644 index 0000000000000000000000000000000000000000..d31e29ff41053f9e57c1d6ee206e5a16f7d190c4 --- /dev/null +++ b/example_images/mitochondria_bpae.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5a81c59d2aaa187f1737a673b25da8bd2dd58e944d6eb2c38d5e8927b5a4f96 +size 99075 diff --git a/example_images/nucleus_bpae.png b/example_images/nucleus_bpae.png new file mode 100644 index 0000000000000000000000000000000000000000..c8c191dc4491d4cdde1f953f951766af46e4fee9 --- /dev/null +++ b/example_images/nucleus_bpae.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:d0425ddce55927bd67407a356ab16ac5e1c484ecf8258efb6ed7033c9413a1a1 +size 75708 diff --git a/example_images_cls/cls_input_1754670277543.png b/example_images_cls/cls_input_1754670277543.png new file mode 100644 index 0000000000000000000000000000000000000000..522eda5f6bf11465c1493339f361aa4ac825d025 --- /dev/null +++ b/example_images_cls/cls_input_1754670277543.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc2c084aecd4610274280ffc0b8dd2eafdb0899a7cbf70d4b15708b2f13cb733 +size 63959 diff --git a/example_images_cls/cls_input_1754670358223.png b/example_images_cls/cls_input_1754670358223.png new file mode 100644 index 0000000000000000000000000000000000000000..5a4288ee086455ace24309f6033a6c23ebcdae70 --- /dev/null +++ b/example_images_cls/cls_input_1754670358223.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:922489201a874d70d6bed45e30dc98fb75f7d90d4910f868d53b2a380caa8821 +size 37023 diff --git a/example_images_cls/cls_input_1754670363268.png b/example_images_cls/cls_input_1754670363268.png new file mode 100644 index 0000000000000000000000000000000000000000..5a4288ee086455ace24309f6033a6c23ebcdae70 --- /dev/null +++ b/example_images_cls/cls_input_1754670363268.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:922489201a874d70d6bed45e30dc98fb75f7d90d4910f868d53b2a380caa8821 +size 37023 diff --git a/example_images_cls/cls_input_1754670366893.png b/example_images_cls/cls_input_1754670366893.png new file mode 100644 index 0000000000000000000000000000000000000000..4c8bfc632735cb6080ebf6dea426363cebb4ba48 --- /dev/null +++ b/example_images_cls/cls_input_1754670366893.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f704fe2bbd72985491fb419dd9e8998e76b21e0b059cff092cfd1dbf2da9d5d7 +size 42489 diff --git a/example_images_cls/cls_input_1754670372624.png b/example_images_cls/cls_input_1754670372624.png new file mode 100644 index 0000000000000000000000000000000000000000..4c8bfc632735cb6080ebf6dea426363cebb4ba48 --- /dev/null +++ b/example_images_cls/cls_input_1754670372624.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f704fe2bbd72985491fb419dd9e8998e76b21e0b059cff092cfd1dbf2da9d5d7 +size 42489 diff --git a/example_images_cls/erdak_1.png b/example_images_cls/erdak_1.png new file mode 100644 index 0000000000000000000000000000000000000000..0ba4daecdd67b9b3d5ee0d104656367161be541b --- /dev/null +++ b/example_images_cls/erdak_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b37c672e4debb464d1d3756a3a862721b310f762472914236325c44779a13a8a +size 37826 diff --git a/example_images_cls/gidap_1.png b/example_images_cls/gidap_1.png new file mode 100644 index 0000000000000000000000000000000000000000..ccd2b45720a749dba46d6f52be277f785d0dd1f5 --- /dev/null +++ b/example_images_cls/gidap_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6743b29ae6cbfbf118102e15370e8e3510efc81613e535ae4d231af68f84c37 +size 25636 diff --git a/example_images_cls/tubul_1.png b/example_images_cls/tubul_1.png new file mode 100644 index 0000000000000000000000000000000000000000..6ab81138d994d037274e0e0305ed05818e592a80 --- /dev/null +++ b/example_images_cls/tubul_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc21b98eb667f6560f846293a60534e455e619ff8c42f3b66f4d9fb1ab32d0a5 +size 21904 diff --git a/example_images_dn/dn_input_MICE_1754056108129.tif b/example_images_dn/dn_input_MICE_1754056108129.tif new file mode 
100644 index 0000000000000000000000000000000000000000..b0717a6c509111081d2dfa1248d3161d82c10bd8 --- /dev/null +++ b/example_images_dn/dn_input_MICE_1754056108129.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e95984e9259e8b08f653693d692808c0001794fe180a10faf2e2640cfbba92bd +size 262400 diff --git a/example_images_m2i/0002.tif b/example_images_m2i/0002.tif new file mode 100644 index 0000000000000000000000000000000000000000..d9e8d54d32fa4ee24a945176df8b81aa248ce38c --- /dev/null +++ b/example_images_m2i/0002.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:566676c8e6bfe1cb3e952a63504d35fbff43ad78a41985a1d0357c2a0e5210c6 +size 262266 diff --git a/example_images_seg/seg_input_1754626817311.tif b/example_images_seg/seg_input_1754626817311.tif new file mode 100644 index 0000000000000000000000000000000000000000..e51336fff932e7e396c0208d3b3ca8db5b48cdc7 --- /dev/null +++ b/example_images_seg/seg_input_1754626817311.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2de23ffc7eb40791dfea1ec07e92d7295c55e3f5408827b1591e4fcc7a1e489c +size 262400 diff --git a/example_images_sr/00000001.tif b/example_images_sr/00000001.tif new file mode 100644 index 0000000000000000000000000000000000000000..653f4cacdb2d1d85957fbb005ecad12406ecc6f3 --- /dev/null +++ b/example_images_sr/00000001.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:436a899290607814ba5f5c0e9456e9d6551bb2cb986223aff7fc4ae396d7a86a +size 1181232 diff --git a/example_images_t2i/er_cos-7.png b/example_images_t2i/er_cos-7.png new file mode 100644 index 0000000000000000000000000000000000000000..7ee5823b6ce8160f7c87b1c6ca8bd931e88a9e90 --- /dev/null +++ b/example_images_t2i/er_cos-7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67d75d10bd20a95f10e4a4a8993e4260c36f4ff576686321f2582deef738ae1e +size 138638 diff --git a/example_images_t2i/er_hela.png b/example_images_t2i/er_hela.png new file mode 100644 index 0000000000000000000000000000000000000000..31d3f2b71b1c229ab2166210bdfc44f06c0c4b37 --- /dev/null +++ b/example_images_t2i/er_hela.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec846061c99ac794d0bd44292792d762dab9136438d0f7db892630ba9e715573 +size 124240 diff --git a/example_images_t2i/f-actin_cos-7.png b/example_images_t2i/f-actin_cos-7.png new file mode 100644 index 0000000000000000000000000000000000000000..d1a99e4a2451cfb4543404cca83c0e9a9d44455a --- /dev/null +++ b/example_images_t2i/f-actin_cos-7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbfcd14c047749fb3a304c3e70e92ce4f8b890ac2d1f1785260aab3313bcda05 +size 158679 diff --git a/example_images_t2i/microtubules_hela.png b/example_images_t2i/microtubules_hela.png new file mode 100644 index 0000000000000000000000000000000000000000..bafb3034b2fca58d53afff8dc5b9013a90fb2c9e --- /dev/null +++ b/example_images_t2i/microtubules_hela.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:deb5da46a71c3d1796e3353b9fd5127afed6e723de0250ff4eda8a005be21573 +size 136049 diff --git a/example_images_t2i/mitochondria_bpae.png b/example_images_t2i/mitochondria_bpae.png new file mode 100644 index 0000000000000000000000000000000000000000..39e8abad8765766da69515a2a181aad8b06867e3 --- /dev/null +++ b/example_images_t2i/mitochondria_bpae.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4f7b3eb68ac603e8868e196289517a16a00579d5968ad5c4c06d825e8a6d001 +size 94118 diff --git 
a/example_images_t2i/nucleus_bpae.png b/example_images_t2i/nucleus_bpae.png new file mode 100644 index 0000000000000000000000000000000000000000..01add92f86b7e66d821ed7c61164c0d6ed277546 --- /dev/null +++ b/example_images_t2i/nucleus_bpae.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63987c6ccef723b9e7f71a9bc4ebdee54fc9b5c800f89848b212f0a7afd8ffce +size 31676 diff --git a/models/__pycache__/controlnet.cpython-311.pyc b/models/__pycache__/controlnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc14d9367d1441e0607e920112b4cf12a6b2a38b Binary files /dev/null and b/models/__pycache__/controlnet.cpython-311.pyc differ diff --git a/models/__pycache__/pipeline_controlnet.cpython-311.pyc b/models/__pycache__/pipeline_controlnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..949440f11b21a0ff28ce4acc422ede50d575b1e9 Binary files /dev/null and b/models/__pycache__/pipeline_controlnet.cpython-311.pyc differ diff --git a/models/__pycache__/pipeline_ddpm_text_encoder.cpython-311.pyc b/models/__pycache__/pipeline_ddpm_text_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d690cb649898871fe3b65e48ce39537a31c80e75 Binary files /dev/null and b/models/__pycache__/pipeline_ddpm_text_encoder.cpython-311.pyc differ diff --git a/models/__pycache__/unet_2d.cpython-311.pyc b/models/__pycache__/unet_2d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94444c02e9b8cc26287fdab50b9c3c2df08e8b98 Binary files /dev/null and b/models/__pycache__/unet_2d.cpython-311.pyc differ diff --git a/models/__pycache__/unet_2d_condition.cpython-311.pyc b/models/__pycache__/unet_2d_condition.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dbfcf2d6ef7045ee62018481f48dc93a77fc585 Binary files /dev/null and b/models/__pycache__/unet_2d_condition.cpython-311.pyc differ diff --git a/models/controlnet.py b/models/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9df3908b3cccad8e06725f25bfbbc7285b16a3eb --- /dev/null +++ b/models/controlnet.py @@ -0,0 +1,494 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +import pdb +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from models.unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + """ + The output of [`ControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the middle block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 Γ— 512 images into smaller 64 Γ— 64 β€œlatent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 Γ— 64 feature space to match the + convolution size. We use a tiny network E(Β·) of four convolution layers with 4 Γ— 4 kernels and 2 Γ— 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] 
= (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=1)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2D", + block_out_channels: Tuple[int, ...] 
= (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + attention_head_dim: Union[int] = 8, + resnet_time_scale_shift: str = "default", + add_attention: bool = True, + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + use_prompt: bool = False, + encoder_size: Optional[int] = None, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + # prompt + if use_prompt: + self.prompt_embedding = nn.Sequential( + nn.Linear(encoder_size, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim) + ) + else: + self.prompt_embedding = None + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + downsample_padding=downsample_padding, + 
resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + if mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=add_attention, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. + """ + mid_block_type = unet.config.mid_block_type if "mid_block_type" in unet.config else "UNetMidBlock2D" + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + controlnet = cls( + in_channels=unet.config.in_channels, + conditioning_channels=conditioning_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + mid_block_type=mid_block_type, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + transformer_layers_per_block=transformer_layers_per_block, + attention_head_dim=unet.config.attention_head_dim, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + use_prompt=unet.config.use_prompt, + encoder_size=unet.config.encoder_size, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + controlnet.prompt_embedding.load_state_dict(unet.prompt_embedding.state_dict()) + + if hasattr(controlnet, "add_embedding"): + 
controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + controlnet_cond: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + conditioning_scale: float = 1.0, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. 
+ # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + if self.prompt_embedding is not None: + encoder_hidden_states = encoder_hidden_states.reshape(sample.shape[0], -1).contiguous() + prompt_emb = self.prompt_embedding(encoder_hidden_states) + emb = emb + prompt_emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample = sample + controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block(sample, emb) + + # 5. Control net blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/models/pipeline_controlnet.py b/models/pipeline_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f83e12ba7f9f5608c2aec6154ed7c6c7d98a6d15 --- /dev/null +++ b/models/pipeline_controlnet.py @@ -0,0 +1,106 @@ +from typing import List, Optional, Tuple, Union +from diffusers.pipelines.pipeline_utils import ImagePipelineOutput +from diffusers.image_processor import PipelineImageInput +from transformers import CLIPTextModel, CLIPTokenizer +from models.controlnet import ControlNetModel +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput + +from models.pipeline_ddpm_text_encoder import DDPMPipeline + +import torch +import pdb +import skimage +import numpy as np + +class DDPMControlnetPipeline(DiffusionPipeline): + def __init__( + self, + unet, + scheduler, + controlnet, + text_encoder: CLIPTextModel | None = None, + tokenizer: CLIPTokenizer | None = None + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + controlnet=controlnet, + ) + + 
@torch.no_grad() + def __call__( + self, + batch_size: int = 1, + image_cond: PipelineImageInput = None, + generator: torch.Generator | None = None, + num_inference_steps: int = 1000, + output_type: str | None = "pil", + return_dict: bool = True, + prompt: Optional[str] = None, + ) -> ImagePipelineOutput : + text_inputs = self.tokenizer( + prompt.lower(), + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(self.device) + encoder_hidden_states = self.text_encoder(text_input_ids, return_dict=False)[0] + + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype) + image = image.to(self.device) + else: + image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + # denoising loop + for t in self.progress_bar(self.scheduler.timesteps): + # 1. controlnet output + down_block_res_samples, mid_block_res_sample = self.controlnet( + sample=image, + timestep=t, + encoder_hidden_states = encoder_hidden_states, + controlnet_cond=image_cond, + return_dict=False, + ) + # 2. predict noise model_output + model_output = self.unet( + sample=image, + timestep=t, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + encoder_hidden_states = encoder_hidden_states, + return_dict=False, + )[0] + + # 3. compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) \ No newline at end of file diff --git a/models/pipeline_ddpm_text_encoder.py b/models/pipeline_ddpm_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e79be7a8c12a95ea8621475ca322061dd166d184 --- /dev/null +++ b/models/pipeline_ddpm_text_encoder.py @@ -0,0 +1,155 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pdb +from typing import List, Optional, Tuple, Union + +import torch + +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from transformers import CLIPTextModel, CLIPTokenizer + + +class DDPMPipeline(DiffusionPipeline): + r""" + Pipeline for image generation. 
+ + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + model_cpu_offload_seq = "unet" + + def __init__( + self, + unet, + scheduler, + text_encoder: Optional[CLIPTextModel]=None, + tokenizer: Optional[CLIPTokenizer]=None, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 1000, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
+ + Example: + + ```py + >>> from diffusers import DDPMPipeline + + >>> # load model and scheduler + >>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pipe().images[0] + + >>> # save image + >>> image.save("ddpm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # Prepare prompt embedding + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt.lower(), + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(self.device) + encoder_hidden_states = self.text_encoder(text_input_ids, return_dict=False)[0] + + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = randn_tensor(image_shape, generator=generator) + image = image.to(self.device) + else: + image = randn_tensor(image_shape, generator=generator, device=self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t, encoder_hidden_states).sample + # 2. compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) \ No newline at end of file diff --git a/models/unet_2d.py b/models/unet_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..ed20ebb352c5b16aecc0503d940af8384cfc6372 --- /dev/null +++ b/models/unet_2d.py @@ -0,0 +1,340 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block + +import pdb + +@dataclass +class UNet2DOutput(BaseOutput): + """ + The output of [`UNet2DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output from the last layer of the model. + """ + + sample: torch.Tensor + + +class UNet2DModel(ModelMixin, ConfigMixin): + r""" + A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - + 1)`. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): + Tuple of downsample block types. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): + Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. + mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. + downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + downsample_type (`str`, *optional*, defaults to `conv`): + The downsample type for downsampling layers. Choose between "conv" and "resnet" + upsample_type (`str`, *optional*, defaults to `conv`): + The upsample type for upsampling layers. Choose between "conv" and "resnet" + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. 
+ attn_norm_num_groups (`int`, *optional*, defaults to `None`): + If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the + given number of groups. If left as `None`, the group norm layer will only be created if + `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class + conditioning with `class_embed_type` equal to `None`. + """ + + @register_to_config + def __init__( + self, + sample_size: Optional[Union[int, Tuple[int, int]]] = None, + in_channels: int = 3, + out_channels: int = 3, + center_input_sample: bool = False, + time_embedding_type: str = "positional", + freq_shift: int = 0, + flip_sin_to_cos: bool = True, + down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + block_out_channels: Tuple[int, ...] = (224, 448, 672, 896), + layers_per_block: int = 2, + mid_block_scale_factor: float = 1, + downsample_padding: int = 1, + downsample_type: str = "conv", + upsample_type: str = "conv", + dropout: float = 0.0, + act_fn: str = "silu", + attention_head_dim: Optional[int] = 8, + norm_num_groups: int = 32, + attn_norm_num_groups: Optional[int] = None, + norm_eps: float = 1e-5, + resnet_time_scale_shift: str = "default", + add_attention: bool = True, + num_train_timesteps: Optional[int] = None, + use_prompt: bool = False, + encoder_size: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
+ ) + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + elif time_embedding_type == "learned": + self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0]) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # prompt + if use_prompt: + self.prompt_embedding = nn.Sequential( + nn.Linear(encoder_size, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim) + ) + else: + self.prompt_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], + resnet_groups=norm_num_groups, + attn_groups=attn_norm_num_groups, + add_attention=add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, 
eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DOutput, Tuple]: + r""" + The [`UNet2DModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple. + + Returns: + [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. + """ + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.prompt_embedding is not None: + encoder_hidden_states = encoder_hidden_states.reshape(sample.shape[0], -1) + prompt_emb = self.prompt_embedding(encoder_hidden_states) + emb = emb + prompt_emb + + # 2. pre-process + skip_sample = sample + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "skip_conv"): + sample, res_samples, skip_sample = downsample_block( + hidden_states=sample, temb=emb, skip_sample=skip_sample + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb) + + # 5. up + skip_sample = None + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + else: + sample = upsample_block(sample, res_samples, emb) + + # 6. 
post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if skip_sample is not None: + sample += skip_sample + + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) + sample = sample / timesteps + + if not return_dict: + return (sample,) + + return UNet2DOutput(sample=sample) diff --git a/models/unet_2d_condition.py b/models/unet_2d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7a4b4d928882fda3c7e1006b5d360ee0da8e7c --- /dev/null +++ b/models/unet_2d_condition.py @@ -0,0 +1,215 @@ +from typing import Tuple +from diffusers.models.unets.unet_2d import UNet2DOutput +from typing import Any, Dict, List, Optional, Tuple, Union +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers + +from models.unet_2d import UNet2DModel + +import torch + +class UNet2DConditionModel(UNet2DModel): + def __init__( + self, + sample_size: int | Tuple[int] | None = None, + in_channels: int = 3, + out_channels: int = 3, + center_input_sample: bool = False, + time_embedding_type: str = "positional", + freq_shift: int = 0, + flip_sin_to_cos: bool = False, + down_block_types: Tuple[str] = ..., + up_block_types: Tuple[str] = ..., + block_out_channels: Tuple[int] = ..., + layers_per_block: int = 2, + mid_block_scale_factor: float = 1, + downsample_padding: int = 1, + downsample_type: str = "conv", + upsample_type: str = "conv", + dropout: float = 0, + act_fn: str = "silu", + attention_head_dim: int | None = 8, + norm_num_groups: int = 32, + attn_norm_num_groups: int | None = None, + norm_eps: float = 0.00001, + resnet_time_scale_shift: str = "default", + add_attention: bool = False, + num_train_timesteps: int | None = None, + use_prompt: bool = False, + encoder_size: int | None = None + ): + super().__init__( + sample_size, + in_channels, + out_channels, + center_input_sample, + time_embedding_type, + freq_shift, + flip_sin_to_cos, + down_block_types, + up_block_types, + block_out_channels, + layers_per_block, + mid_block_scale_factor, + downsample_padding, + downsample_type, + upsample_type, + dropout, + act_fn, + attention_head_dim, + norm_num_groups, + attn_norm_num_groups, + norm_eps, + resnet_time_scale_shift, + add_attention, + num_train_timesteps, + use_prompt, + encoder_size + ) + + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. 
+ mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.prompt_embedding is not None: + encoder_hidden_states = encoder_hidden_states.reshape(sample.shape[0], -1).to(dtype=self.dtype).contiguous() + prompt_emb = self.prompt_embedding(encoder_hidden_states) + emb = emb + prompt_emb + + # 2. pre-process + skip_sample = sample + sample = self.conv_in(sample) + + # 3. down + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. 
", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + skip_sample = None + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + else: + sample = upsample_block(sample, res_samples, emb) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if skip_sample is not None: + sample += skip_sample + + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) + sample = sample / timesteps + + if not return_dict: + return (sample,) + + return UNet2DOutput(sample=sample) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7b148c94777f29f22c9091dd5be8cc8d5644a6b7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,105 @@ +accelerate==1.8.1 +aiofiles==24.1.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.12.15 +aiosignal==1.4.0 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.9.0 +attrs==25.3.0 +certifi==2025.6.15 +charset-normalizer==3.4.2 +click==8.2.1 +datasets==4.0.0 +diffusers==0.34.0 +dill==0.3.8 +einops==0.8.0 +fastapi==0.116.0 +fastremap==1.17.2 +ffmpy==0.6.0 +filelock==3.18.0 +frozenlist==1.7.0 +fsspec==2025.3.0 +gradio==5.35.0 +gradio_client==1.10.4 +groovy==0.1.2 +h11==0.16.0 +hf-xet==1.1.5 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==0.33.2 +hydra-core==1.3.2 +idna==3.10 +imageio==2.37.0 +importlib_metadata==8.7.0 +iopath==0.1.10 +isat-sam-backend==1.0.0 +Jinja2==3.1.6 +lazy_loader==0.4 +llvmlite==0.44.0 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.6.3 +multiprocess==0.70.16 +natsort==8.4.0 +networkx +numba==0.61.2 +numpy==1.25.2 +omegaconf==2.3.0 +opencv-python==4.10.0.82 +orjson==3.10.18 +packaging==25.0 +pandas==2.3.1 +pillow==11.3.0 +portalocker==3.2.0 +propcache==0.3.2 +psutil==7.0.0 +pyarrow==21.0.0 +pydantic==2.11.7 +pydantic_core==2.33.2 +pydub==0.25.1 +Pygments==2.19.2 +python-dateutil==2.9.0.post0 +python-multipart==0.0.20 +pytorch-msssim==1.0.0 +pytz==2025.2 
+PyYAML==6.0.2 +regex==2024.11.6 +requests==2.32.4 +rich==14.0.0 +roifile==2025.5.10 +ruff==0.12.2 +safehttpx==0.1.6 +safetensors==0.5.3 +scikit-image==0.25.2 +scipy +semantic-version==2.10.0 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +starlette==0.46.2 +sympy==1.14.0 +tabulate==0.9.0 +tifffile +timm==1.0.19 +tokenizers==0.21.2 +tomlkit==0.13.3 +torch==2.7.1 +torchaudio==2.7.1 +torchvision==0.22.1 +tqdm==4.67.1 +transformers==4.53.1 +triton==3.3.1 +typer==0.16.0 +typing-inspection==0.4.1 +typing_extensions==4.14.1 +tzdata==2025.2 +urllib3==2.5.0 +uvicorn==0.35.0 +warmup_scheduler==0.3 +websockets==15.0.1 +xxhash==3.5.0 +yarl==1.20.1 +zipp==3.23.0 diff --git a/utils/logo2_resize.png b/utils/logo2_resize.png new file mode 100644 index 0000000000000000000000000000000000000000..1fe5e7b1f3de97cd7f5547a49148b8f6cf6eb138 --- /dev/null +++ b/utils/logo2_resize.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94f7d69a514ce901b27788782160194cca676b1a44a5565643e08800eb23e3b5 +size 251472 diff --git a/utils/logo2_transparent.png b/utils/logo2_transparent.png new file mode 100644 index 0000000000000000000000000000000000000000..6176eb2063d671f430872cc55d32f6c5f249d7c8 --- /dev/null +++ b/utils/logo2_transparent.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddd12a8a1702ab903916cf99d4c5f555e8ba111f6bedf915f79d87cf54092ae6 +size 211713 diff --git a/utils/logo_0801_1.png b/utils/logo_0801_1.png new file mode 100644 index 0000000000000000000000000000000000000000..f9f200a8af59dce3fd2266bdaf78a291afd8328f --- /dev/null +++ b/utils/logo_0801_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cad83ec9f0fbf178f50dcbfb0661928fba007a06f223b23d940ec904bf8f8569 +size 1529724 diff --git a/utils/logo_0801_2.png b/utils/logo_0801_2.png new file mode 100644 index 0000000000000000000000000000000000000000..27b1f6d343026a72ab4d835c31275bca2426c9da --- /dev/null +++ b/utils/logo_0801_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26465d64eaa2767c32f9b763ccf57bef72bbdf16ef13c82bf36886dc29e30328 +size 1226031
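The text-conditioned `DDPMPipeline` added in `models/pipeline_ddpm_text_encoder.py` registers a UNet, a scheduler, and a CLIP text encoder/tokenizer, encodes the (lower-cased) prompt once, and then runs the standard DDPM denoising loop with the prompt embedding passed to the UNet. A minimal usage sketch under stated assumptions follows: the checkpoint paths, the `subfolder` layout and the example prompt are placeholders for illustration, not the demo's actual loading code.

```python
# Minimal sketch: text-to-image sampling with the custom DDPMPipeline.
# All paths and the "unet"/"tokenizer"/"text_encoder" subfolder layout are assumptions.
import torch
from diffusers import DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from models.unet_2d import UNet2DModel
from models.pipeline_ddpm_text_encoder import DDPMPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

SD_DIR = "path/to/stable-diffusion-v1-5"   # hypothetical: provides CLIP tokenizer/text encoder
UNET_DIR = "path/to/t2i_unet_checkpoint"   # hypothetical: trained text-conditioned UNet2DModel

tokenizer = CLIPTokenizer.from_pretrained(SD_DIR, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(SD_DIR, subfolder="text_encoder")
unet = UNet2DModel.from_pretrained(UNET_DIR, subfolder="unet")
scheduler = DDPMScheduler(num_train_timesteps=1000)

pipe = DDPMPipeline(unet=unet, scheduler=scheduler,
                    text_encoder=text_encoder, tokenizer=tokenizer).to(device)

# The pipeline lower-cases the prompt internally, so pass a single string.
image = pipe(prompt="F-actin of COS-7", num_inference_steps=1000).images[0]
image.save("t2i_sample.png")
```

Per the class docstring, the scheduler can be either a `DDPMScheduler` or a `DDIMScheduler`, so swapping in a DDIM schedule with fewer `num_inference_steps` is the usual way to trade sample fidelity for speed.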
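`DDPMControlnetPipeline` in `models/pipeline_controlnet.py` extends the same loop with a `ControlNetModel`: at every timestep the conditioning image is run through the ControlNet, and its down-block and mid-block residuals are added into the conditional UNet. The sketch below shows one plausible way to assemble and call it; the paths, subfolder names, prompt, and the [-1, 1] normalisation of the conditioning image are assumptions, and the preprocessing must match whatever the ControlNet checkpoint was trained with.

```python
# Minimal sketch: ControlNet-conditioned sampling (e.g. mask-to-image or denoising).
# Paths, subfolders, the prompt and the [-1, 1] scaling are assumptions for illustration.
import numpy as np
import tifffile
import torch
from diffusers import DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from models.controlnet import ControlNetModel
from models.pipeline_controlnet import DDPMControlnetPipeline
from models.unet_2d_condition import UNet2DConditionModel

device = "cuda" if torch.cuda.is_available() else "cpu"

SD_DIR = "path/to/stable-diffusion-v1-5"    # hypothetical
UNET_DIR = "path/to/t2i_unet_checkpoint"    # hypothetical
CN_DIR = "path/to/controlnet_checkpoint"    # hypothetical

tokenizer = CLIPTokenizer.from_pretrained(SD_DIR, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(SD_DIR, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(UNET_DIR, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(CN_DIR, subfolder="controlnet")
scheduler = DDPMScheduler(num_train_timesteps=1000)

pipe = DDPMControlnetPipeline(unet=unet, scheduler=scheduler, controlnet=controlnet,
                              text_encoder=text_encoder, tokenizer=tokenizer).to(device)

# Conditioning image as a (1, C, H, W) tensor; the channel count must match the
# ControlNet's `conditioning_channels` (a single-channel TIFF is assumed here).
cond = tifffile.imread("example_images_m2i/0002.tif").astype(np.float32)
cond = cond / max(float(cond.max()), 1.0) * 2.0 - 1.0   # assumed [-1, 1] scaling
cond = torch.from_numpy(cond)[None, None].to(device, dtype=unet.dtype)

out = pipe(image_cond=cond, prompt="F-actin of COS-7", num_inference_steps=1000)
out.images[0].save("controlnet_sample.png")
```

Because the conditioning embedding in this repository's `ControlNetConditioningEmbedding` uses stride-1 convolutions, `image_cond` must already have the same spatial size as `unet.config.sample_size`; there is no latent-space downsampling as in the original Stable Diffusion ControlNet.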