rayquaza384mega commited on
Commit
9060565
·
1 Parent(s): bd22fc0

Upload example images and assets using LFS

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +5 -5
  2. app.py +664 -0
  3. cache/00000001.tif +3 -0
  4. cellpose/__init__.py +1 -0
  5. cellpose/__main__.py +380 -0
  6. cellpose/__pycache__/__init__.cpython-310.pyc +0 -0
  7. cellpose/__pycache__/__init__.cpython-311.pyc +0 -0
  8. cellpose/__pycache__/__init__.cpython-312.pyc +0 -0
  9. cellpose/__pycache__/core.cpython-310.pyc +0 -0
  10. cellpose/__pycache__/core.cpython-311.pyc +0 -0
  11. cellpose/__pycache__/core.cpython-312.pyc +0 -0
  12. cellpose/__pycache__/dynamics.cpython-310.pyc +0 -0
  13. cellpose/__pycache__/dynamics.cpython-311.pyc +0 -0
  14. cellpose/__pycache__/dynamics.cpython-312.pyc +0 -0
  15. cellpose/__pycache__/dynamics.map_coordinates-414.py310.1.nbc +0 -0
  16. cellpose/__pycache__/dynamics.map_coordinates-414.py310.2.nbc +0 -0
  17. cellpose/__pycache__/dynamics.map_coordinates-414.py310.nbi +0 -0
  18. cellpose/__pycache__/dynamics.map_coordinates-414.py311.1.nbc +0 -0
  19. cellpose/__pycache__/dynamics.map_coordinates-414.py311.2.nbc +0 -0
  20. cellpose/__pycache__/dynamics.map_coordinates-414.py311.nbi +0 -0
  21. cellpose/__pycache__/dynamics.map_coordinates-414.py312.1.nbc +0 -0
  22. cellpose/__pycache__/dynamics.map_coordinates-414.py312.2.nbc +0 -0
  23. cellpose/__pycache__/dynamics.map_coordinates-414.py312.nbi +0 -0
  24. cellpose/__pycache__/io.cpython-310.pyc +0 -0
  25. cellpose/__pycache__/io.cpython-311.pyc +0 -0
  26. cellpose/__pycache__/io.cpython-312.pyc +0 -0
  27. cellpose/__pycache__/metrics.cpython-310.pyc +0 -0
  28. cellpose/__pycache__/metrics.cpython-311.pyc +0 -0
  29. cellpose/__pycache__/metrics.cpython-312.pyc +0 -0
  30. cellpose/__pycache__/models.cpython-310.pyc +0 -0
  31. cellpose/__pycache__/models.cpython-311.pyc +0 -0
  32. cellpose/__pycache__/models.cpython-312.pyc +0 -0
  33. cellpose/__pycache__/plot.cpython-310.pyc +0 -0
  34. cellpose/__pycache__/plot.cpython-311.pyc +0 -0
  35. cellpose/__pycache__/plot.cpython-312.pyc +0 -0
  36. cellpose/__pycache__/resnet_torch.cpython-310.pyc +0 -0
  37. cellpose/__pycache__/resnet_torch.cpython-311.pyc +0 -0
  38. cellpose/__pycache__/resnet_torch.cpython-312.pyc +0 -0
  39. cellpose/__pycache__/train.cpython-310.pyc +0 -0
  40. cellpose/__pycache__/train.cpython-312.pyc +0 -0
  41. cellpose/__pycache__/transforms.cpython-310.pyc +0 -0
  42. cellpose/__pycache__/transforms.cpython-311.pyc +0 -0
  43. cellpose/__pycache__/transforms.cpython-312.pyc +0 -0
  44. cellpose/__pycache__/utils.cpython-310.pyc +0 -0
  45. cellpose/__pycache__/utils.cpython-311.pyc +0 -0
  46. cellpose/__pycache__/utils.cpython-312.pyc +0 -0
  47. cellpose/__pycache__/version.cpython-310.pyc +0 -0
  48. cellpose/__pycache__/version.cpython-311.pyc +0 -0
  49. cellpose/__pycache__/version.cpython-312.pyc +0 -0
  50. cellpose/cli.py +231 -0
README.md CHANGED
@@ -1,14 +1,14 @@
1
  ---
2
- title: FluoGen
3
- emoji: 🚀
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: TMP
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: FluoGen Demo
3
+ emoji: 📉
4
+ colorFrom: red
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: 'Demo space of FluoGen: An Open-Source Generative Foundation '
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import os
6
+ import json
7
+ import glob
8
+ import random
9
+ import tifffile
10
+ import re
11
+ import imageio
12
+ from torchvision import transforms, models
13
+ import accelerate
14
+ import shutil
15
+ import time
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from collections import OrderedDict
19
+
20
+
21
+ # --- Imports from both scripts ---
22
+ from diffusers import DDPMScheduler, DDIMScheduler
23
+ from transformers import CLIPTextModel, CLIPTokenizer
24
+ from accelerate.state import AcceleratorState
25
+ from transformers.utils import ContextManagers
26
+
27
+ # --- Custom Model Imports ---
28
+ from models.pipeline_ddpm_text_encoder import DDPMPipeline
29
+ from models.unet_2d import UNet2DModel
30
+ from models.controlnet import ControlNetModel
31
+ from models.unet_2d_condition import UNet2DConditionModel
32
+ from models.pipeline_controlnet import DDPMControlnetPipeline
33
+
34
+ # --- New Import for Segmentation ---
35
+ from cellpose import models as cellpose_models
36
+ from cellpose import plot as cellpose_plot
37
+ from huggingface_hub import hf_hub_download
38
+
39
+ # --- 0. Configuration & Constants ---
40
+ # --- General ---
41
+ MODEL_TITLE = "🔬 FluoGen: AI-Powered Fluorescence Microscopy Suite"
42
+ MODEL_DESCRIPTION = """
43
+ **Paper**: *Generative AI empowering fluorescence microscopy imaging and analysis*
44
+ <br>
45
+ Select a task from the tabs below: generate new images from text, enhance existing images using super-resolution, denoise them, generate training data from segmentation masks, perform cell segmentation, or classify cell images.
46
+ """
47
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
48
+ WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
49
+ LOGO_PATH = "utils/logo_0801_2.png"
50
+
51
+ # --- Global switch to control example saving ---
52
+ SAVE_EXAMPLES = False
53
+
54
+ # --- Base directory for all models ---
55
+ # NOTE: All model paths are now relative.
56
+ # Run the `copy_weights.py` script once to copy all necessary model files into this local directory.
57
+ REPO_ID = "rayquaza384mega/FluoGen-demo-test-ckpts"
58
+ MODELS_ROOT_DIR = hf_hub_download(repo_id=REPO_ID) #"models_collection"
59
+
60
+
61
+ # --- Tab 1: Mask-to-Image Config (Formerly Segmentation-to-Image) ---
62
+ M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
63
+ M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
64
+
65
+ # --- Tab 2: Text-to-Image Config ---
66
+ T2I_PROMPTS = ["F-actin of COS-7", "ER of COS-7", "Mitochondria of BPAE", "Nucleus of BPAE", "ER of HeLa", "Microtubules of HeLa"]
67
+ T2I_EXAMPLE_IMG_DIR = "example_images"
68
+ T2I_CHECKPOINT = 285000
69
+ T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
70
+ T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-{T2I_CHECKPOINT}"
71
+
72
+ # --- Tab 3, 4: ControlNet-based Tasks Config ---
73
+ CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
74
+ CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
75
+
76
+ # Super-Resolution Models
77
+ SR_CONTROLNET_MODELS = {
78
+ "Checkpoint CCPs": f"{MODELS_ROOT_DIR}/ControlNet_SR/CCPs/checkpoint-100000",
79
+ "Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000",
80
+ }
81
+ SR_EXAMPLE_IMG_DIR = "example_images_sr"
82
+
83
+ # Denoising Model
84
+ DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000"
85
+ DN_PROMPT_RULES = {'MICE': 'mouse brain tissues', 'FISH': 'zebrafish embryos', 'BPAE_B': 'nucleus of BPAE', 'BPAE_R': 'mitochondria of BPAE', 'BPAE_G': 'F-actin of BPAE'}
86
+ DN_EXAMPLE_IMG_DIR = "example_images_dn"
87
+
88
+ # --- Tab 5: Cell Segmentation Config ---
89
+ SEG_MODELS = {
90
+ "DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100",
91
+ "DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300",
92
+ "DSB Model": f"{MODELS_ROOT_DIR}/Cellpose/DSB_baseline/CP_dsb_baseline_ratio_1_epoch_0135",
93
+ "DSB Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DSB_FluoGen/CP_dsb_ten_epoch_0135",
94
+ }
95
+ SEG_EXAMPLE_IMG_DIR = "example_images_seg"
96
+
97
+
98
+ # --- Tab 6: Classification Config ---
99
+ CLS_MODEL_PATHS = OrderedDict({
100
+ "5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_re",
101
+ #"10shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_10_shot_re",
102
+ #"15shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_15_shot_re",
103
+ #"20shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_20_shot_re",
104
+ "5shot+FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug_re",
105
+ #"10shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_10_shot_aug_re",
106
+ #"15shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_15_shot_aug_re",
107
+ #"20shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_20_shot_aug_re",
108
+ })
109
+ CLS_CLASS_NAMES = ['dap', 'erdak', 'giant', 'gpp130', 'h4b4', 'mc151', 'nucle', 'phal', 'tfr', 'tubul']
110
+ CLS_EXAMPLE_IMG_DIR = "example_images_cls"
111
+
112
+
113
+ # --- Helper Functions ---
114
+ def sanitize_prompt_for_filename(prompt):
115
+ prompt = prompt.lower(); prompt = re.sub(r'\s+of\s+', '_', prompt); prompt = re.sub(r'[^a-z0-9-_]+', '', prompt)
116
+ return f"{prompt}.png"
117
+
118
+ def min_max_norm(x):
119
+ x = x.astype(np.float32); min_val, max_val = np.min(x), np.max(x)
120
+ if max_val - min_val < 1e-8: return np.zeros_like(x)
121
+ return (x - min_val) / (max_val - min_val)
122
+
123
+ def numpy_to_pil(image_np, target_mode="RGB"):
124
+ # If the input is already a PIL image, just ensure mode and return
125
+ if isinstance(image_np, Image.Image):
126
+ if target_mode == "RGB" and image_np.mode != "RGB": return image_np.convert("RGB")
127
+ if target_mode == "L" and image_np.mode != "L": return image_np.convert("L")
128
+ return image_np
129
+
130
+ # Handle numpy array conversion
131
+ squeezed_np = np.squeeze(image_np);
132
+ if squeezed_np.dtype == np.uint8:
133
+ # If it's already uint8, it's likely in the 0-255 range.
134
+ image_8bit = squeezed_np
135
+ else:
136
+ # Normalize and scale for other types
137
+ normalized_np = min_max_norm(squeezed_np)
138
+ image_8bit = (normalized_np * 255).astype(np.uint8)
139
+
140
+ pil_image = Image.fromarray(image_8bit)
141
+
142
+ if target_mode == "RGB" and pil_image.mode != "RGB": pil_image = pil_image.convert("RGB")
143
+ elif target_mode == "L" and pil_image.mode != "L": pil_image = pil_image.convert("L")
144
+ return pil_image
145
+
146
+ # --- 1. Model Loading ---
147
+ print("--- Initializing FluoGen Application ---")
148
+ t2i_pipe, controlnet_pipe = None, None
149
+ try:
150
+ print("Loading Text-to-Image model...")
151
+ t2i_noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=True, timestep_spacing="trailing")
152
+ t2i_unet = UNet2DModel.from_pretrained(T2I_UNET_PATH, subfolder="unet")
153
+ t2i_text_encoder = CLIPTextModel.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="text_encoder").to(DEVICE)
154
+ t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer")
155
+ t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler, text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer)
156
+ t2i_pipe.to(DEVICE)
157
+ print("✓ Text-to-Image model loaded successfully!")
158
+ except Exception as e:
159
+ print(f"!!!!!! FATAL: Text-to-Image Model Loading Failed !!!!!!\nError: {e}")
160
+
161
+ try:
162
+ print("Loading shared ControlNet pipeline components...")
163
+ controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
164
+ default_controlnet_path = M2I_CONTROLNET_PATH # Start with the first tab's model
165
+ controlnet_controlnet = ControlNetModel.from_pretrained(default_controlnet_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
166
+ controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
167
+ controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
168
+ with ContextManagers([]):
169
+ controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE)
170
+ controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet, scheduler=controlnet_scheduler, text_encoder=controlnet_text_encoder, tokenizer=controlnet_tokenizer)
171
+ controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE)
172
+ controlnet_pipe.current_controlnet_path = default_controlnet_path
173
+ print("✓ Shared ControlNet pipeline loaded successfully!")
174
+ except Exception as e:
175
+ print(f"!!!!!! FATAL: ControlNet Pipeline Loading Failed !!!!!!\nError: {e}")
176
+
177
+ # --- 2. Core Logic Functions ---
178
+ def swap_controlnet(pipe, target_path):
179
+ if os.path.normpath(getattr(pipe, 'current_controlnet_path', '')) != os.path.normpath(target_path):
180
+ print(f"Swapping ControlNet model to: {target_path}")
181
+ try:
182
+ pipe.controlnet = ControlNetModel.from_pretrained(target_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
183
+ pipe.current_controlnet_path = target_path
184
+ except Exception as e:
185
+ raise gr.Error(f"Failed to load ControlNet model '{target_path}'. Error: {e}")
186
+ return pipe
187
+
188
+ def generate_t2i(prompt, num_inference_steps):
189
+ if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
190
+ print(f"\nTask started... | Prompt: '{prompt}' | Steps: {num_inference_steps}")
191
+ image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
192
+ generated_image = numpy_to_pil(image_np)
193
+ print("✓ Image generated")
194
+ if SAVE_EXAMPLES:
195
+ example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
196
+ if not os.path.exists(example_filepath):
197
+ generated_image.save(example_filepath); print(f"✓ New T2I example saved: {example_filepath}")
198
+ return generated_image
199
+
200
+ def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed):
201
+ if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
202
+ if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask TIF file.")
203
+ if not cell_type or not cell_type.strip(): raise gr.Error("Please enter a cell type.")
204
+
205
+ if SAVE_EXAMPLES:
206
+ input_path = mask_file_obj.name
207
+ filename = os.path.basename(input_path)
208
+ dest_path = os.path.join(M2I_EXAMPLE_IMG_DIR, filename)
209
+ if not os.path.exists(dest_path):
210
+ shutil.copy(input_path, dest_path)
211
+ print(f"✓ New Mask-to-Image example saved: {dest_path}")
212
+
213
+ pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
214
+ try:
215
+ mask_np = tifffile.imread(mask_file_obj.name)
216
+ except Exception as e:
217
+ raise gr.Error(f"Failed to read the TIF file. Error: {e}")
218
+
219
+ input_display_image = numpy_to_pil(mask_np, "L")
220
+ mask_normalized = min_max_norm(mask_np)
221
+ image_tensor = torch.from_numpy(mask_normalized.astype(np.float32))
222
+ image_tensor = image_tensor.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
223
+ image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
224
+
225
+ prompt = f"nuclei of {cell_type.strip()}"
226
+ print(f"\nTask started... | Task: Mask-to-Image | Prompt: '{prompt}' | Steps: {steps} | Images: {num_images}")
227
+
228
+ generated_images_pil = []
229
+ for i in range(int(num_images)):
230
+ current_seed = int(seed) + i
231
+ generator = torch.Generator(device=DEVICE).manual_seed(current_seed)
232
+ with torch.autocast("cuda"):
233
+ output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
234
+ pil_image = numpy_to_pil(output_np)
235
+ generated_images_pil.append(pil_image)
236
+ print(f"✓ Generated image {i+1}/{int(num_images)}")
237
+
238
+ return input_display_image, generated_images_pil
239
+
240
+ def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed):
241
+ if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
242
+ if low_res_file_obj is None: raise gr.Error("Please upload a low-resolution TIF file.")
243
+
244
+ if SAVE_EXAMPLES:
245
+ input_path = low_res_file_obj.name
246
+ filename = os.path.basename(input_path)
247
+ dest_path = os.path.join(SR_EXAMPLE_IMG_DIR, filename)
248
+ if not os.path.exists(dest_path):
249
+ shutil.copy(input_path, dest_path)
250
+ print(f"✓ New SR example saved: {dest_path}")
251
+
252
+ target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
253
+ if not target_path: raise gr.Error(f"ControlNet model '{controlnet_model_name}' not found.")
254
+
255
+ pipe = swap_controlnet(controlnet_pipe, target_path)
256
+ try:
257
+ image_stack_np = tifffile.imread(low_res_file_obj.name)
258
+ except Exception as e:
259
+ raise gr.Error(f"Failed to read the TIF file. Error: {e}")
260
+
261
+ if image_stack_np.ndim != 3 or image_stack_np.shape[-3] != 9:
262
+ raise gr.Error(f"Invalid TIF shape. Expected 9 channels (shape 9, H, W), but got {image_stack_np.shape}.")
263
+
264
+ average_projection_np = np.mean(image_stack_np, axis=0)
265
+ input_display_image = numpy_to_pil(average_projection_np, "L")
266
+
267
+ image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
268
+ image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
269
+
270
+ generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
271
+ with torch.autocast("cuda"):
272
+ output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
273
+
274
+ return input_display_image, numpy_to_pil(output_np)
275
+
276
+ def run_denoising(noisy_image_np, image_type, steps, seed):
277
+ if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
278
+ if noisy_image_np is None: raise gr.Error("Please upload a noisy image.")
279
+
280
+ if SAVE_EXAMPLES:
281
+ timestamp = int(time.time() * 1000)
282
+ filename = f"dn_input_{image_type}_{timestamp}.tif"
283
+ dest_path = os.path.join(DN_EXAMPLE_IMG_DIR, filename)
284
+ try:
285
+ img_to_save = noisy_image_np.astype(np.uint8) if noisy_image_np.dtype != np.uint8 else noisy_image_np
286
+ tifffile.imwrite(dest_path, img_to_save)
287
+ print(f"✓ New Denoising example saved: {dest_path}")
288
+ except Exception as e:
289
+ print(f"✗ Failed to save denoising example: {e}")
290
+
291
+ pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
292
+ prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
293
+ print(f"\nTask started... | Task: Denoising | Prompt: '{prompt}' | Steps: {steps}")
294
+
295
+ image_tensor = torch.from_numpy(noisy_image_np.astype(np.float32) / 255.0)
296
+ image_tensor = image_tensor.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
297
+ image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
298
+
299
+ generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
300
+ with torch.autocast("cuda"):
301
+ output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
302
+
303
+ return numpy_to_pil(noisy_image_np, "L"), numpy_to_pil(output_np)
304
+
305
+ def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
306
+ """
307
+ Runs cell segmentation and creates a dark red overlay.
308
+ """
309
+ if input_image_np is None:
310
+ raise gr.Error("Please upload an image to segment.")
311
+
312
+ model_path = SEG_MODELS.get(model_name)
313
+ if not model_path:
314
+ raise gr.Error(f"Segmentation model '{model_name}' not found.")
315
+
316
+ if not os.path.exists(model_path):
317
+ raise gr.Error(f"Model file not found at path: {model_path}. Please check the configuration.")
318
+
319
+ print(f"\nTask started... | Task: Cell Segmentation | Model: '{model_name}'")
320
+
321
+ # 1. Load Cellpose Model
322
+ try:
323
+ use_gpu = torch.cuda.is_available()
324
+ model = cellpose_models.CellposeModel(gpu=use_gpu, pretrained_model=model_path)
325
+ except Exception as e:
326
+ raise gr.Error(f"Failed to load Cellpose model. Error: {e}")
327
+
328
+ diameter_to_use = model.diam_labels if diameter == 0 else float(diameter)
329
+ print(f"Using Diameter: {diameter_to_use}")
330
+
331
+ # 2. Run model evaluation
332
+ try:
333
+ masks, _, _ = model.eval(
334
+ [input_image_np],
335
+ channels=[0, 0],
336
+ diameter=diameter_to_use,
337
+ flow_threshold=flow_threshold,
338
+ cellprob_threshold=cellprob_threshold
339
+ )
340
+ mask_output = masks[0]
341
+ except Exception as e:
342
+ raise gr.Error(f"Cellpose model evaluation failed. Error: {e}")
343
+
344
+ # 3. Create custom dark red overlay
345
+ # Ensure input image is uint8 and 3-channel for blending
346
+ original_rgb = numpy_to_pil(input_image_np, "RGB")
347
+ original_rgb_np = np.array(original_rgb)
348
+
349
+ # Create a blank layer for the red mask
350
+ red_mask_layer = np.zeros_like(original_rgb_np)
351
+ dark_red_color = [139, 0, 0]
352
+
353
+ # Apply the red color where the mask is present
354
+ is_mask_pixels = mask_output > 0
355
+ red_mask_layer[is_mask_pixels] = dark_red_color
356
+
357
+ # Blend the original image with the red mask layer
358
+ alpha = 0.4 # Opacity of the mask
359
+ blended_image_np = ((1 - alpha) * original_rgb_np + alpha * red_mask_layer).astype(np.uint8)
360
+
361
+ # 4. Save example if enabled
362
+ if SAVE_EXAMPLES:
363
+ timestamp = int(time.time() * 1000)
364
+ filename = f"seg_input_{timestamp}.tif"
365
+ dest_path = os.path.join(SEG_EXAMPLE_IMG_DIR, filename)
366
+ try:
367
+ img_to_save = input_image_np.astype(np.uint8) if input_image_np.dtype != np.uint8 else input_image_np
368
+ tifffile.imwrite(dest_path, img_to_save)
369
+ print(f"✓ New Segmentation example saved: {dest_path}")
370
+ except Exception as e:
371
+ print(f"✗ Failed to save segmentation example: {e}")
372
+
373
+ print("✓ Segmentation complete")
374
+
375
+ return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended_image_np, "RGB")
376
+
377
+ def run_classification(input_image_np, model_name):
378
+ """
379
+ Runs classification on a single image using a pre-trained ResNet50 model.
380
+ """
381
+ if input_image_np is None:
382
+ raise gr.Error("Please upload an image to classify.")
383
+
384
+ model_dir = CLS_MODEL_PATHS.get(model_name)
385
+ if not model_dir:
386
+ raise gr.Error(f"Classification model '{model_name}' not found.")
387
+
388
+ model_path = os.path.join(model_dir, "best_resnet50.pth")
389
+ if not os.path.exists(model_path):
390
+ raise gr.Error(f"Model file not found at {model_path}. Please check the configuration.")
391
+
392
+ print(f"\nTask started... | Task: Classification | Model: '{model_name}'")
393
+
394
+ # 1. Load Model
395
+ try:
396
+ model = models.resnet50(weights=None)
397
+ num_features = model.fc.in_features
398
+ model.fc = nn.Linear(num_features, len(CLS_CLASS_NAMES))
399
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
400
+ model.to(DEVICE)
401
+ model.eval()
402
+ except Exception as e:
403
+ raise gr.Error(f"Failed to load classification model. Error: {e}")
404
+
405
+ # 2. Preprocess Image
406
+ # Grayscale numpy -> RGB PIL -> transform -> tensor
407
+ input_pil = numpy_to_pil(input_image_np, "RGB")
408
+
409
+ transform_test = transforms.Compose([
410
+ transforms.ToTensor(),
411
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # ResNet needs 3-channel norm
412
+ ])
413
+ input_tensor = transform_test(input_pil).unsqueeze(0).to(DEVICE)
414
+
415
+ # 3. Perform Inference
416
+ with torch.no_grad():
417
+ outputs = model(input_tensor)
418
+ probabilities = F.softmax(outputs, dim=1).squeeze().cpu().numpy()
419
+
420
+ # 4. Format output for Gradio Label component
421
+ confidences = {name: float(prob) for name, prob in zip(CLS_CLASS_NAMES, probabilities)}
422
+
423
+ # 5. Save example
424
+ if SAVE_EXAMPLES:
425
+ timestamp = int(time.time() * 1000)
426
+ filename = f"cls_input_{timestamp}.png" # Save as png for compatibility
427
+ dest_path = os.path.join(CLS_EXAMPLE_IMG_DIR, filename)
428
+ try:
429
+ input_pil.save(dest_path)
430
+ print(f"✓ New Classification example saved: {dest_path}")
431
+ except Exception as e:
432
+ print(f"✗ Failed to save classification example: {e}")
433
+
434
+ print("✓ Classification complete")
435
+
436
+ return numpy_to_pil(input_image_np, "L"), confidences
437
+
438
+
439
+ # --- 3. Gradio UI Layout ---
440
+ print("Building Gradio interface...")
441
+ # Create directories for all example types
442
+ os.makedirs(M2I_EXAMPLE_IMG_DIR, exist_ok=True)
443
+ os.makedirs(T2I_EXAMPLE_IMG_DIR, exist_ok=True)
444
+ os.makedirs(SR_EXAMPLE_IMG_DIR, exist_ok=True)
445
+ os.makedirs(DN_EXAMPLE_IMG_DIR, exist_ok=True)
446
+ os.makedirs(SEG_EXAMPLE_IMG_DIR, exist_ok=True)
447
+ os.makedirs(CLS_EXAMPLE_IMG_DIR, exist_ok=True)
448
+
449
+ # --- Load examples ---
450
+ filename_to_prompt_map = { sanitize_prompt_for_filename(prompt): prompt for prompt in T2I_PROMPTS }
451
+ t2i_gallery_examples = []
452
+ for filename in os.listdir(T2I_EXAMPLE_IMG_DIR):
453
+ if filename in filename_to_prompt_map:
454
+ filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, filename)
455
+ prompt = filename_to_prompt_map[filename]
456
+ t2i_gallery_examples.append((filepath, prompt))
457
+
458
+ def load_image_examples(example_dir, is_stack=False):
459
+ examples = []
460
+ if not os.path.exists(example_dir): return examples
461
+ for f in sorted(os.listdir(example_dir)):
462
+ if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
463
+ filepath = os.path.join(example_dir, f)
464
+ try:
465
+ if f.lower().endswith(('.tif', '.tiff')):
466
+ img_np = tifffile.imread(filepath)
467
+ else:
468
+ img_np = np.array(Image.open(filepath).convert("L"))
469
+
470
+ if is_stack and img_np.ndim == 3:
471
+ img_np = np.mean(img_np, axis=0)
472
+
473
+ display_img = numpy_to_pil(img_np, "L")
474
+ examples.append((display_img, filepath))
475
+ except Exception as e:
476
+ print(f"Warning: Could not load gallery image {filepath}. Error: {e}")
477
+ return examples
478
+
479
+ m2i_gallery_examples = load_image_examples(M2I_EXAMPLE_IMG_DIR)
480
+ sr_gallery_examples = load_image_examples(SR_EXAMPLE_IMG_DIR, is_stack=True)
481
+ dn_gallery_examples = load_image_examples(DN_EXAMPLE_IMG_DIR)
482
+ seg_gallery_examples = load_image_examples(SEG_EXAMPLE_IMG_DIR)
483
+ cls_gallery_examples = load_image_examples(CLS_EXAMPLE_IMG_DIR)
484
+
485
+ # --- Universal event handlers ---
486
+ def select_example_prompt(evt: gr.SelectData):
487
+ return evt.value['caption']
488
+
489
+ def select_example_input_file(evt: gr.SelectData):
490
+ return evt.value['caption']
491
+
492
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
493
+ with gr.Row():
494
+ gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False)
495
+ gr.Markdown(f"# {MODEL_TITLE}\n{MODEL_DESCRIPTION}")
496
+
497
+ with gr.Tabs():
498
+ # --- TAB 1: Mask-to-Image ---
499
+ with gr.Tab("Mask-to-Image", id="mask2img"):
500
+ gr.Markdown("""
501
+ ### Instructions
502
+ 1. Upload a single-channel segmentation mask (`.tif` file), or select one from the examples gallery below.
503
+ 2. Enter the corresponding 'Cell Type' (e.g., 'CoNSS', 'HeLa') to create the prompt.
504
+ 3. Select how many sample images you want to generate.
505
+ 4. Adjust 'Inference Steps' and 'Seed' as needed.
506
+ 5. Click 'Generate Training Samples' to start the process.
507
+ 6. The 'Generated Samples' will appear in the main gallery, with the 'Input Mask' shown below for reference.
508
+ """) # Content hidden for brevity
509
+ with gr.Row(variant="panel"):
510
+ with gr.Column(scale=1, min_width=350):
511
+ m2i_input_file = gr.File(label="Upload Segmentation Mask (.tif)", file_types=['.tif', '.tiff'])
512
+ m2i_cell_type_input = gr.Textbox(label="Cell Type (for prompt)", placeholder="e.g., CoNSS, HeLa, MCF-7")
513
+ m2i_num_images_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Images to Generate")
514
+ m2i_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
515
+ m2i_seed_input = gr.Number(label="Seed", value=42)
516
+ m2i_generate_button = gr.Button("Generate Training Samples", variant="primary")
517
+ with gr.Column(scale=2):
518
+ m2i_output_gallery = gr.Gallery(label="Generated Samples", columns=5, object_fit="contain", height="auto")
519
+ m2i_input_display = gr.Image(label="Input Mask", type="pil", interactive=False)
520
+ m2i_gallery = gr.Gallery(value=m2i_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
521
+
522
+ # --- TAB 2: Text-to-Image ---
523
+ with gr.Tab("Text-to-Image Generation", id="txt2img"):
524
+ gr.Markdown("""
525
+ ### Instructions
526
+ 1. Select a desired prompt from the dropdown menu.
527
+ 2. Adjust the 'Inference Steps' slider to control generation quality.
528
+ 3. Click the 'Generate' button to create a new image.
529
+ 4. Explore the 'Examples' gallery; clicking an image will load its prompt.
530
+
531
+ **Notice:** This model currently supports 3566 prompt categories. However, data for many cell structures and lines is still lacking. We welcome data source contributions to improve the model.
532
+ """) # Content hidden for brevity
533
+ with gr.Row(variant="panel"):
534
+ with gr.Column(scale=1, min_width=350):
535
+ t2i_prompt_input = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Select a Prompt")
536
+ t2i_steps_slider = gr.Slider(minimum=10, maximum=200, step=1, value=50, label="Inference Steps")
537
+ t2i_generate_button = gr.Button("Generate", variant="primary")
538
+ with gr.Column(scale=2):
539
+ t2i_generated_output = gr.Image(label="Generated Image", type="pil", interactive=False)
540
+ t2i_gallery = gr.Gallery(value=t2i_gallery_examples, label="Examples (Click an image to use its prompt)", columns=6, object_fit="contain", height="auto")
541
+
542
+ # --- TAB 3: Super-Resolution ---
543
+ with gr.Tab("Super-Resolution", id="super_res"):
544
+ gr.Markdown("""
545
+ ### Instructions
546
+ 1. Upload a low-resolution 9-channel TIF stack, or select one from the examples.
547
+ 2. Select a 'Super-Resolution Model' from the dropdown.
548
+ 3. Enter a descriptive 'Prompt' related to the image content (e.g., 'CCPs of COS-7').
549
+ 4. Adjust 'Inference Steps' and 'Seed' as needed.
550
+ 5. Click 'Generate Super-Resolution' to process the image.
551
+
552
+ **Notice:** This model was trained on the **BioSR** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
553
+ """) # Content hidden for brevity
554
+ with gr.Row(variant="panel"):
555
+ with gr.Column(scale=1, min_width=350):
556
+ sr_input_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff'])
557
+ sr_model_selector = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Select Super-Resolution Model")
558
+ sr_prompt_input = gr.Textbox(label="Prompt (e.g., structure name)", value="CCPs of COS-7")
559
+ sr_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
560
+ sr_seed_input = gr.Number(label="Seed", value=42)
561
+ sr_generate_button = gr.Button("Generate Super-Resolution", variant="primary")
562
+ with gr.Column(scale=2):
563
+ with gr.Row():
564
+ sr_input_display = gr.Image(label="Input (Average Projection)", type="pil", interactive=False)
565
+ sr_output_image = gr.Image(label="Super-Resolved Image", type="pil", interactive=False)
566
+ sr_gallery = gr.Gallery(value=sr_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
567
+
568
+ # --- TAB 4: Denoising ---
569
+ with gr.Tab("Denoising", id="denoising"):
570
+ gr.Markdown("""
571
+ ### Instructions
572
+ 1. Upload a noisy single-channel image, or select one from the examples.
573
+ 2. Select the 'Image Type' from the dropdown to provide context for the model.
574
+ 3. Adjust 'Inference Steps' and 'Seed' as needed.
575
+ 4. Click 'Denoise Image' to reduce the noise.
576
+
577
+ **Notice:** This model was trained on the **FMD** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
578
+ """) # Content hidden for brevity
579
+ with gr.Row(variant="panel"):
580
+ with gr.Column(scale=1, min_width=350):
581
+ dn_input_image = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L")
582
+ dn_image_type_selector = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Select Image Type (for Prompt)")
583
+ dn_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
584
+ dn_seed_input = gr.Number(label="Seed", value=42)
585
+ dn_generate_button = gr.Button("Denoise Image", variant="primary")
586
+ with gr.Column(scale=2):
587
+ with gr.Row():
588
+ dn_original_display = gr.Image(label="Original Noisy Image", type="pil", interactive=False)
589
+ dn_output_image = gr.Image(label="Denoised Image", type="pil", interactive=False)
590
+ dn_gallery = gr.Gallery(value=dn_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
591
+
592
+ # --- TAB 5: Cell Segmentation ---
593
+ with gr.Tab("Cell Segmentation", id="segmentation"):
594
+ gr.Markdown("""
595
+ ### Instructions
596
+ 1. Upload a single-channel image for segmentation, or select one from the examples.
597
+ 2. Select a 'Segmentation Model' from the dropdown menu.
598
+ 3. Set the expected 'Diameter' of the cells in pixels. Set to 0 to let the model automatically estimate it.
599
+ 4. Adjust 'Flow Threshold' and 'Cell Probability Threshold' for finer control.
600
+ 5. Click 'Segment Cells'. The result will be shown as a dark red overlay on the original image.
601
+ """)
602
+ with gr.Row(variant="panel"):
603
+ with gr.Column(scale=1, min_width=350):
604
+ gr.Markdown("### 1. Inputs & Controls")
605
+ seg_input_image = gr.Image(type="numpy", label="Upload Image for Segmentation", image_mode="L")
606
+ seg_model_selector = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Select Segmentation Model")
607
+ seg_diameter_input = gr.Number(label="Cell Diameter (pixels, 0=auto)", value=30)
608
+ seg_flow_slider = gr.Slider(minimum=0.0, maximum=3.0, step=0.1, value=0.4, label="Flow Threshold")
609
+ seg_cellprob_slider = gr.Slider(minimum=-6.0, maximum=6.0, step=0.5, value=0.0, label="Cell Probability Threshold")
610
+ seg_generate_button = gr.Button("Segment Cells", variant="primary")
611
+ with gr.Column(scale=2):
612
+ gr.Markdown("### 2. Results")
613
+ with gr.Row():
614
+ seg_original_display = gr.Image(label="Original Image", type="pil", interactive=False)
615
+ seg_output_image = gr.Image(label="Segmented Image (Overlay)", type="pil", interactive=False)
616
+ seg_gallery = gr.Gallery(value=seg_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
617
+
618
+ # --- NEW TAB 6: Classification ---
619
+ with gr.Tab("Classification", id="classification"):
620
+ gr.Markdown("""
621
+ ### Instructions
622
+ 1. Upload a single-channel image for classification, or select an example.
623
+ 2. Select a pre-trained 'Classification Model' from the dropdown menu.
624
+ 3. Click 'Classify Image' to view the prediction probabilities for each class.
625
+
626
+ **Note:** The models provided are ResNet50 trained on the 2D HeLa dataset.
627
+ """)
628
+ with gr.Row(variant="panel"):
629
+ with gr.Column(scale=1, min_width=350):
630
+ gr.Markdown("### 1. Inputs & Controls")
631
+ cls_input_image = gr.Image(type="numpy", label="Upload Image for Classification", image_mode="L")
632
+ cls_model_selector = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Select Classification Model")
633
+ cls_generate_button = gr.Button("Classify Image", variant="primary")
634
+ with gr.Column(scale=2):
635
+ gr.Markdown("### 2. Results")
636
+ cls_original_display = gr.Image(label="Input Image", type="pil", interactive=False)
637
+ cls_output_label = gr.Label(label="Classification Results", num_top_classes=len(CLS_CLASS_NAMES))
638
+ cls_gallery = gr.Gallery(value=cls_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
639
+
640
+
641
+ # --- Event Handlers ---
642
+ m2i_generate_button.click(fn=run_mask_to_image_generation, inputs=[m2i_input_file, m2i_cell_type_input, m2i_num_images_slider, m2i_steps_slider, m2i_seed_input], outputs=[m2i_input_display, m2i_output_gallery])
643
+ m2i_gallery.select(fn=select_example_input_file, outputs=m2i_input_file)
644
+
645
+ t2i_generate_button.click(fn=generate_t2i, inputs=[t2i_prompt_input, t2i_steps_slider], outputs=[t2i_generated_output])
646
+ t2i_gallery.select(fn=select_example_prompt, outputs=t2i_prompt_input)
647
+
648
+ sr_generate_button.click(fn=run_super_resolution, inputs=[sr_input_file, sr_model_selector, sr_prompt_input, sr_steps_slider, sr_seed_input], outputs=[sr_input_display, sr_output_image])
649
+ sr_gallery.select(fn=select_example_input_file, outputs=sr_input_file)
650
+
651
+ dn_generate_button.click(fn=run_denoising, inputs=[dn_input_image, dn_image_type_selector, dn_steps_slider, dn_seed_input], outputs=[dn_original_display, dn_output_image])
652
+ dn_gallery.select(fn=select_example_input_file, outputs=dn_input_image)
653
+
654
+ seg_generate_button.click(fn=run_segmentation, inputs=[seg_input_image, seg_model_selector, seg_diameter_input, seg_flow_slider, seg_cellprob_slider], outputs=[seg_original_display, seg_output_image])
655
+ seg_gallery.select(fn=select_example_input_file, outputs=seg_input_image)
656
+
657
+ cls_generate_button.click(fn=run_classification, inputs=[cls_input_image, cls_model_selector], outputs=[cls_original_display, cls_output_label])
658
+ cls_gallery.select(fn=select_example_input_file, outputs=cls_input_image)
659
+
660
+
661
+ # --- 4. Launch Application ---
662
+ if __name__ == "__main__":
663
+ print("Interface built. Launching server...")
664
+ demo.launch()
cache/00000001.tif ADDED

Git LFS Details

  • SHA256: 436a899290607814ba5f5c0e9456e9d6551bb2cb986223aff7fc4ae396d7a86a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
cellpose/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from cellpose.version import version, version_str
cellpose/__main__.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
3
+ """
4
+
5
+ import sys, os, glob, pathlib, time
6
+ import numpy as np
7
+ from natsort import natsorted
8
+ from tqdm import tqdm
9
+ from cellpose import utils, models, io, version_str, train, denoise
10
+ from cellpose.cli import get_arg_parser
11
+
12
+ try:
13
+ from cellpose.gui import gui3d, gui
14
+ GUI_ENABLED = True
15
+ except ImportError as err:
16
+ GUI_ERROR = err
17
+ GUI_ENABLED = False
18
+ GUI_IMPORT = True
19
+ except Exception as err:
20
+ GUI_ENABLED = False
21
+ GUI_ERROR = err
22
+ GUI_IMPORT = False
23
+ raise
24
+
25
+ import logging
26
+
27
+
28
+ # settings re-grouped a bit
29
+ def main():
30
+ """ Run cellpose from command line
31
+ """
32
+
33
+ args = get_arg_parser().parse_args(
34
+ ) # this has to be in a separate file for autodoc to work
35
+
36
+ if args.version:
37
+ print(version_str)
38
+ return
39
+
40
+ if args.check_mkl:
41
+ mkl_enabled = models.check_mkl()
42
+ else:
43
+ mkl_enabled = True
44
+
45
+ if len(args.dir) == 0 and len(args.image_path) == 0:
46
+ if args.add_model:
47
+ io.add_model(args.add_model)
48
+ else:
49
+ if not GUI_ENABLED:
50
+ print("GUI ERROR: %s" % GUI_ERROR)
51
+ if GUI_IMPORT:
52
+ print(
53
+ "GUI FAILED: GUI dependencies may not be installed, to install, run"
54
+ )
55
+ print(" pip install 'cellpose[gui]'")
56
+ else:
57
+ if args.Zstack:
58
+ gui3d.run()
59
+ else:
60
+ gui.run()
61
+
62
+ else:
63
+ if args.verbose:
64
+ from .io import logger_setup
65
+ logger, log_file = logger_setup()
66
+ else:
67
+ print(
68
+ ">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
69
+ print("No --verbose => no progress or info printed")
70
+ logger = logging.getLogger(__name__)
71
+
72
+ use_gpu = False
73
+ channels = [args.chan, args.chan2]
74
+
75
+ # find images
76
+ if len(args.img_filter) > 0:
77
+ imf = args.img_filter
78
+ else:
79
+ imf = None
80
+
81
+ # Check with user if they REALLY mean to run without saving anything
82
+ if not (args.train or args.train_size):
83
+ saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt
84
+
85
+ device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu,
86
+ device=args.gpu_device)
87
+
88
+ if args.pretrained_model is None or args.pretrained_model == "None" or args.pretrained_model == "False" or args.pretrained_model == "0":
89
+ pretrained_model = False
90
+ else:
91
+ pretrained_model = args.pretrained_model
92
+
93
+ restore_type = args.restore_type
94
+ if restore_type is not None:
95
+ try:
96
+ denoise.model_path(restore_type)
97
+ except Exception as e:
98
+ raise ValueError("restore_type invalid")
99
+ if args.train or args.train_size:
100
+ raise ValueError("restore_type cannot be used with training on CLI yet")
101
+
102
+ if args.transformer and (restore_type is None):
103
+ default_model = "transformer_cp3"
104
+ backbone = "transformer"
105
+ elif args.transformer and restore_type is not None:
106
+ raise ValueError("no transformer based restoration")
107
+ else:
108
+ default_model = "cyto3"
109
+ backbone = "default"
110
+
111
+ if args.norm_percentile is not None:
112
+ value1, value2 = args.norm_percentile
113
+ normalize = {'percentile': (float(value1), float(value2))}
114
+ else:
115
+ normalize = (not args.no_norm)
116
+
117
+
118
+ model_type = None
119
+ if pretrained_model and not os.path.exists(pretrained_model):
120
+ model_type = pretrained_model if pretrained_model is not None else "cyto3"
121
+ model_strings = models.get_user_models()
122
+ all_models = models.MODEL_NAMES.copy()
123
+ all_models.extend(model_strings)
124
+ if ~np.any([model_type == s for s in all_models]):
125
+ model_type = default_model
126
+ logger.warning(
127
+ f"pretrained model has incorrect path, using {default_model}")
128
+ if model_type == "nuclei":
129
+ szmean = 17.
130
+ else:
131
+ szmean = 30.
132
+ builtin_size = (model_type == "cyto" or model_type == "cyto2" or
133
+ model_type == "nuclei" or model_type == "cyto3")
134
+
135
+ if len(args.image_path) > 0 and (args.train or args.train_size):
136
+ raise ValueError("ERROR: cannot train model with single image input")
137
+
138
+ if not args.train and not args.train_size:
139
+ tic = time.time()
140
+ if len(args.dir) > 0:
141
+ image_names = io.get_image_files(
142
+ args.dir, args.mask_filter, imf=imf,
143
+ look_one_level_down=args.look_one_level_down)
144
+ else:
145
+ if os.path.exists(args.image_path):
146
+ image_names = [args.image_path]
147
+ else:
148
+ raise ValueError(f"ERROR: no file found at {args.image_path}")
149
+ nimg = len(image_names)
150
+
151
+ if args.savedir:
152
+ if not os.path.exists(args.savedir):
153
+ raise FileExistsError("--savedir {args.savedir} does not exist")
154
+
155
+ cstr0 = ["GRAY", "RED", "GREEN", "BLUE"]
156
+ cstr1 = ["NONE", "RED", "GREEN", "BLUE"]
157
+ logger.info(
158
+ ">>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s"
159
+ % (nimg, cstr0[channels[0]], cstr1[channels[1]]))
160
+
161
+ # handle built-in model exceptions
162
+ if builtin_size and restore_type is None and not args.pretrained_model_ortho:
163
+ model = models.Cellpose(gpu=gpu, device=device, model_type=model_type,
164
+ backbone=backbone)
165
+ else:
166
+ builtin_size = False
167
+ if args.all_channels:
168
+ channels = None
169
+ img = io.imread(image_names[0])
170
+ if img.ndim == 3:
171
+ nchan = min(img.shape)
172
+ elif img.ndim == 2:
173
+ nchan = 1
174
+ channels = None
175
+ else:
176
+ nchan = 2
177
+
178
+ pretrained_model = None if model_type is not None else pretrained_model
179
+ if restore_type is None:
180
+ pretrained_model_ortho = None if args.pretrained_model_ortho is None else args.pretrained_model_ortho
181
+ model = models.CellposeModel(gpu=gpu, device=device,
182
+ pretrained_model=pretrained_model,
183
+ model_type=model_type,
184
+ nchan=nchan,
185
+ backbone=backbone,
186
+ pretrained_model_ortho=pretrained_model_ortho)
187
+ else:
188
+ model = denoise.CellposeDenoiseModel(
189
+ gpu=gpu, device=device, pretrained_model=pretrained_model,
190
+ model_type=model_type, restore_type=restore_type, nchan=nchan,
191
+ chan2_restore=args.chan2_restore)
192
+
193
+ # handle diameters
194
+ if args.diameter == 0:
195
+ if builtin_size:
196
+ diameter = None
197
+ logger.info(">>>> estimating diameter for each image")
198
+ else:
199
+ if restore_type is None:
200
+ logger.info(
201
+ ">>>> not using cyto3, cyto, cyto2, or nuclei model, cannot auto-estimate diameter"
202
+ )
203
+ else:
204
+ logger.info(
205
+ ">>>> cannot auto-estimate diameter for image restoration")
206
+ diameter = model.diam_labels
207
+ logger.info(">>>> using diameter %0.3f for all images" % diameter)
208
+ else:
209
+ diameter = args.diameter
210
+ logger.info(">>>> using diameter %0.3f for all images" % diameter)
211
+
212
+ tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)
213
+
214
+ for image_name in tqdm(image_names, file=tqdm_out):
215
+ image = io.imread(image_name)
216
+ out = model.eval(
217
+ image, channels=channels, diameter=diameter, do_3D=args.do_3D,
218
+ augment=args.augment, resample=(not args.no_resample),
219
+ flow_threshold=args.flow_threshold,
220
+ cellprob_threshold=args.cellprob_threshold,
221
+ stitch_threshold=args.stitch_threshold, min_size=args.min_size,
222
+ invert=args.invert, batch_size=args.batch_size,
223
+ interp=(not args.no_interp), normalize=normalize,
224
+ channel_axis=args.channel_axis, z_axis=args.z_axis,
225
+ anisotropy=args.anisotropy, niter=args.niter,
226
+ flow3D_smooth=args.flow3D_smooth)
227
+ masks, flows = out[:2]
228
+ if len(out) > 3 and restore_type is None:
229
+ diams = out[-1]
230
+ else:
231
+ diams = diameter
232
+ ratio = 1.
233
+ if restore_type is not None:
234
+ imgs_dn = out[-1]
235
+ ratio = diams / model.dn.diam_mean if "upsample" in restore_type else 1.
236
+ diams = model.dn.diam_mean if "upsample" in restore_type and model.dn.diam_mean > diams else diams
237
+ else:
238
+ imgs_dn = None
239
+ if args.exclude_on_edges:
240
+ masks = utils.remove_edge_masks(masks)
241
+ if not args.no_npy:
242
+ io.masks_flows_to_seg(image, masks, flows, image_name,
243
+ imgs_restore=imgs_dn, channels=channels,
244
+ diams=diams, restore_type=restore_type,
245
+ ratio=1.)
246
+ if saving_something:
247
+ suffix = "_cp_masks"
248
+ if args.output_name is not None:
249
+ # (1) If `savedir` is not defined, then must have a non-zero `suffix`
250
+ if args.savedir is None and len(args.output_name) > 0:
251
+ suffix = args.output_name
252
+ elif args.savedir is not None and not os.path.samefile(args.savedir, args.dir):
253
+ # (2) If `savedir` is defined, and different from `dir` then
254
+ # takes the value passed as a param. (which can be empty string)
255
+ suffix = args.output_name
256
+
257
+ io.save_masks(image, masks, flows, image_name,
258
+ suffix=suffix, png=args.save_png,
259
+ tif=args.save_tif, save_flows=args.save_flows,
260
+ save_outlines=args.save_outlines,
261
+ dir_above=args.dir_above, savedir=args.savedir,
262
+ save_txt=args.save_txt, in_folders=args.in_folders,
263
+ save_mpl=args.save_mpl)
264
+ if args.save_rois:
265
+ io.save_rois(masks, image_name)
266
+ logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))
267
+ else:
268
+ test_dir = None if len(args.test_dir) == 0 else args.test_dir
269
+ images, labels, image_names, train_probs = None, None, None, None
270
+ test_images, test_labels, image_names_test, test_probs = None, None, None, None
271
+ compute_flows = False
272
+ if len(args.file_list) > 0:
273
+ if os.path.exists(args.file_list):
274
+ dat = np.load(args.file_list, allow_pickle=True).item()
275
+ image_names = dat["train_files"]
276
+ image_names_test = dat.get("test_files", None)
277
+ train_probs = dat.get("train_probs", None)
278
+ test_probs = dat.get("test_probs", None)
279
+ compute_flows = dat.get("compute_flows", False)
280
+ load_files = False
281
+ else:
282
+ logger.critical(f"ERROR: {args.file_list} does not exist")
283
+ else:
284
+ output = io.load_train_test_data(args.dir, test_dir, imf,
285
+ args.mask_filter,
286
+ args.look_one_level_down)
287
+ images, labels, image_names, test_images, test_labels, image_names_test = output
288
+ load_files = True
289
+
290
+ # training with all channels
291
+ if args.all_channels:
292
+ img = images[0] if images is not None else io.imread(image_names[0])
293
+ if img.ndim == 3:
294
+ nchan = min(img.shape)
295
+ elif img.ndim == 2:
296
+ nchan = 1
297
+ channels = None
298
+ else:
299
+ nchan = 2
300
+
301
+ # model path
302
+ szmean = args.diam_mean
303
+ if not os.path.exists(pretrained_model) and model_type is None:
304
+ if not args.train:
305
+ error_message = "ERROR: model path missing or incorrect - cannot train size model"
306
+ logger.critical(error_message)
307
+ raise ValueError(error_message)
308
+ pretrained_model = False
309
+ logger.info(">>>> training from scratch")
310
+ if args.train:
311
+ logger.info(
312
+ ">>>> during training rescaling images to fixed diameter of %0.1f pixels"
313
+ % args.diam_mean)
314
+
315
+ # initialize model
316
+ model = models.CellposeModel(
317
+ device=device, model_type=model_type, diam_mean=szmean, nchan=nchan,
318
+ pretrained_model=pretrained_model if model_type is None else None,
319
+ backbone=backbone)
320
+
321
+ # train segmentation model
322
+ if args.train:
323
+ cpmodel_path = train.train_seg(
324
+ model.net, images, labels, train_files=image_names,
325
+ test_data=test_images, test_labels=test_labels,
326
+ test_files=image_names_test, train_probs=train_probs,
327
+ test_probs=test_probs, compute_flows=compute_flows,
328
+ load_files=load_files, normalize=normalize,
329
+ channels=channels, channel_axis=args.channel_axis, rgb=(nchan == 3),
330
+ learning_rate=args.learning_rate, weight_decay=args.weight_decay,
331
+ SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.batch_size,
332
+ min_train_masks=args.min_train_masks,
333
+ nimg_per_epoch=args.nimg_per_epoch,
334
+ nimg_test_per_epoch=args.nimg_test_per_epoch,
335
+ save_path=os.path.realpath(args.dir), save_every=args.save_every,
336
+ model_name=args.model_name_out)[0]
337
+ model.pretrained_model = cpmodel_path
338
+ logger.info(">>>> model trained and saved to %s" % cpmodel_path)
339
+
340
+ # train size model
341
+ if args.train_size:
342
+ sz_model = models.SizeModel(cp_model=model, device=device)
343
+ # data has already been normalized and reshaped
344
+ sz_model.params = train.train_size(
345
+ model.net, model.pretrained_model, images, labels,
346
+ train_files=image_names, test_data=test_images,
347
+ test_labels=test_labels, test_files=image_names_test,
348
+ train_probs=train_probs, test_probs=test_probs,
349
+ load_files=load_files, channels=channels,
350
+ min_train_masks=args.min_train_masks,
351
+ channel_axis=args.channel_axis, rgb=(nchan == 3),
352
+ nimg_per_epoch=args.nimg_per_epoch, normalize=normalize,
353
+ nimg_test_per_epoch=args.nimg_test_per_epoch,
354
+ batch_size=args.batch_size)
355
+ if test_images is not None:
356
+ test_masks = [lbl[0] for lbl in test_labels
357
+ ] if test_labels is not None else test_labels
358
+ predicted_diams, diams_style = sz_model.eval(
359
+ test_images, channels=channels)
360
+ ccs = np.corrcoef(
361
+ diams_style,
362
+ np.array([utils.diameters(lbl)[0] for lbl in test_masks]))[0, 1]
363
+ cc = np.corrcoef(
364
+ predicted_diams,
365
+ np.array([utils.diameters(lbl)[0] for lbl in test_masks]))[0, 1]
366
+ logger.info(
367
+ "style test correlation: %0.4f; final test correlation: %0.4f" %
368
+ (ccs, cc))
369
+ np.save(
370
+ os.path.join(
371
+ args.test_dir,
372
+ "%s_predicted_diams.npy" % os.path.split(cpmodel_path)[1]),
373
+ {
374
+ "predicted_diams": predicted_diams,
375
+ "diams_style": diams_style
376
+ })
377
+
378
+
379
+ if __name__ == "__main__":
380
+ main()
cellpose/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (206 Bytes). View file
 
cellpose/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (241 Bytes). View file
 
cellpose/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (216 Bytes). View file
 
cellpose/__pycache__/core.cpython-310.pyc ADDED
Binary file (11 kB). View file
 
cellpose/__pycache__/core.cpython-311.pyc ADDED
Binary file (18.1 kB). View file
 
cellpose/__pycache__/core.cpython-312.pyc ADDED
Binary file (16.6 kB). View file
 
cellpose/__pycache__/dynamics.cpython-310.pyc ADDED
Binary file (34.5 kB). View file
 
cellpose/__pycache__/dynamics.cpython-311.pyc ADDED
Binary file (63.9 kB). View file
 
cellpose/__pycache__/dynamics.cpython-312.pyc ADDED
Binary file (58.4 kB). View file
 
cellpose/__pycache__/dynamics.map_coordinates-414.py310.1.nbc ADDED
Binary file (55 kB). View file
 
cellpose/__pycache__/dynamics.map_coordinates-414.py310.2.nbc ADDED
Binary file (54.9 kB). View file
 
cellpose/__pycache__/dynamics.map_coordinates-414.py310.nbi ADDED
Binary file (2.09 kB). View file
 
cellpose/__pycache__/dynamics.map_coordinates-414.py311.1.nbc ADDED
Binary file (54.5 kB). View file
 
cellpose/__pycache__/dynamics.map_coordinates-414.py311.2.nbc ADDED
Binary file (54.4 kB). View file
 
cellpose/__pycache__/dynamics.map_coordinates-414.py311.nbi ADDED
Binary file (2.09 kB). View file
 
cellpose/__pycache__/dynamics.map_coordinates-414.py312.1.nbc ADDED
Binary file (54.6 kB). View file
 
cellpose/__pycache__/dynamics.map_coordinates-414.py312.2.nbc ADDED
Binary file (54.5 kB). View file
 
cellpose/__pycache__/dynamics.map_coordinates-414.py312.nbi ADDED
Binary file (2.08 kB). View file
 
cellpose/__pycache__/io.cpython-310.pyc ADDED
Binary file (21.6 kB). View file
 
cellpose/__pycache__/io.cpython-311.pyc ADDED
Binary file (41.6 kB). View file
 
cellpose/__pycache__/io.cpython-312.pyc ADDED
Binary file (35.8 kB). View file
 
cellpose/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (10 kB). View file
 
cellpose/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (16.1 kB). View file
 
cellpose/__pycache__/metrics.cpython-312.pyc ADDED
Binary file (14.8 kB). View file
 
cellpose/__pycache__/models.cpython-310.pyc ADDED
Binary file (32.4 kB). View file
 
cellpose/__pycache__/models.cpython-311.pyc ADDED
Binary file (48.3 kB). View file
 
cellpose/__pycache__/models.cpython-312.pyc ADDED
Binary file (44.9 kB). View file
 
cellpose/__pycache__/plot.cpython-310.pyc ADDED
Binary file (8.84 kB). View file
 
cellpose/__pycache__/plot.cpython-311.pyc ADDED
Binary file (16.6 kB). View file
 
cellpose/__pycache__/plot.cpython-312.pyc ADDED
Binary file (15.6 kB). View file
 
cellpose/__pycache__/resnet_torch.cpython-310.pyc ADDED
Binary file (12.3 kB). View file
 
cellpose/__pycache__/resnet_torch.cpython-311.pyc ADDED
Binary file (21.9 kB). View file
 
cellpose/__pycache__/resnet_torch.cpython-312.pyc ADDED
Binary file (19.1 kB). View file
 
cellpose/__pycache__/train.cpython-310.pyc ADDED
Binary file (24.4 kB). View file
 
cellpose/__pycache__/train.cpython-312.pyc ADDED
Binary file (37.9 kB). View file
 
cellpose/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (32.2 kB). View file
 
cellpose/__pycache__/transforms.cpython-311.pyc ADDED
Binary file (57.9 kB). View file
 
cellpose/__pycache__/transforms.cpython-312.pyc ADDED
Binary file (53.7 kB). View file
 
cellpose/__pycache__/utils.cpython-310.pyc ADDED
Binary file (20.9 kB). View file
 
cellpose/__pycache__/utils.cpython-311.pyc ADDED
Binary file (36.1 kB). View file
 
cellpose/__pycache__/utils.cpython-312.pyc ADDED
Binary file (33.3 kB). View file
 
cellpose/__pycache__/version.cpython-310.pyc ADDED
Binary file (646 Bytes). View file
 
cellpose/__pycache__/version.cpython-311.pyc ADDED
Binary file (883 Bytes). View file
 
cellpose/__pycache__/version.cpython-312.pyc ADDED
Binary file (801 Bytes). View file
 
cellpose/cli.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden.
3
+ """
4
+
5
+ import argparse
6
+
7
+
8
+ def get_arg_parser():
9
+ """ Parses command line arguments for cellpose main function
10
+
11
+ Note: this function has to be in a separate file to allow autodoc to work for CLI.
12
+ The autodoc_mock_imports in conf.py does not work for sphinx-argparse sometimes,
13
+ see https://github.com/ashb/sphinx-argparse/issues/9#issue-1097057823
14
+ """
15
+
16
+ parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters")
17
+
18
+ # misc settings
19
+ parser.add_argument("--version", action="store_true",
20
+ help="show cellpose version info")
21
+ parser.add_argument(
22
+ "--verbose", action="store_true",
23
+ help="show information about running and settings and save to log")
24
+ parser.add_argument("--Zstack", action="store_true", help="run GUI in 3D mode")
25
+
26
+ # settings for CPU vs GPU
27
+ hardware_args = parser.add_argument_group("Hardware Arguments")
28
+ hardware_args.add_argument("--use_gpu", action="store_true",
29
+ help="use gpu if torch with cuda installed")
30
+ hardware_args.add_argument(
31
+ "--gpu_device", required=False, default="0", type=str,
32
+ help="which gpu device to use, use an integer for torch, or mps for M1")
33
+ hardware_args.add_argument("--check_mkl", action="store_true",
34
+ help="check if mkl working")
35
+
36
+ # settings for locating and formatting images
37
+ input_img_args = parser.add_argument_group("Input Image Arguments")
38
+ input_img_args.add_argument("--dir", default=[], type=str,
39
+ help="folder containing data to run or train on.")
40
+ input_img_args.add_argument(
41
+ "--image_path", default=[], type=str, help=
42
+ "if given and --dir not given, run on single image instead of folder (cannot train with this option)"
43
+ )
44
+ input_img_args.add_argument(
45
+ "--look_one_level_down", action="store_true",
46
+ help="run processing on all subdirectories of current folder")
47
+ input_img_args.add_argument("--img_filter", default=[], type=str,
48
+ help="end string for images to run on")
49
+ input_img_args.add_argument(
50
+ "--channel_axis", default=None, type=int,
51
+ help="axis of image which corresponds to image channels")
52
+ input_img_args.add_argument("--z_axis", default=None, type=int,
53
+ help="axis of image which corresponds to Z dimension")
54
+ input_img_args.add_argument(
55
+ "--chan", default=0, type=int, help=
56
+ "channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s")
57
+ input_img_args.add_argument(
58
+ "--chan2", default=0, type=int, help=
59
+ "nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s"
60
+ )
61
+ input_img_args.add_argument("--invert", action="store_true",
62
+ help="invert grayscale channel")
63
+ input_img_args.add_argument(
64
+ "--all_channels", action="store_true", help=
65
+ "use all channels in image if using own model and images with special channels")
66
+
67
+ # model settings
68
+ model_args = parser.add_argument_group("Model Arguments")
69
+ model_args.add_argument("--pretrained_model", required=False, default="cyto3",
70
+ type=str,
71
+ help="model to use for running or starting training")
72
+ model_args.add_argument("--restore_type", required=False, default=None, type=str,
73
+ help="model to use for image restoration")
74
+ model_args.add_argument("--chan2_restore", action="store_true",
75
+ help="use nuclei restore model for second channel")
76
+ model_args.add_argument(
77
+ "--add_model", required=False, default=None, type=str,
78
+ help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
79
+ model_args.add_argument(
80
+ "--transformer", action="store_true", help=
81
+ "use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)")
82
+ model_args.add_argument("--pretrained_model_ortho", required=False, default=None,
83
+ type=str,
84
+ help="model to use for running 3D ortho views (ZY and ZX)")
85
+
86
+ # algorithm settings
87
+ algorithm_args = parser.add_argument_group("Algorithm Arguments")
88
+ algorithm_args.add_argument(
89
+ "--no_resample", action="store_true", help=
90
+ "disable dynamics on full image (makes algorithm faster for images with large diameters)"
91
+ )
92
+ algorithm_args.add_argument(
93
+ "--no_interp", action="store_true",
94
+ help="do not interpolate when running dynamics (was default)")
95
+ algorithm_args.add_argument("--no_norm", action="store_true",
96
+ help="do not normalize images (normalize=False)")
97
+ parser.add_argument(
98
+ '--norm_percentile',
99
+ nargs=2, # Require exactly two values
100
+ metavar=('VALUE1', 'VALUE2'),
101
+ help="Provide two float values to set norm_percentile (e.g., --norm_percentile 1 99)"
102
+ )
103
+ algorithm_args.add_argument(
104
+ "--do_3D", action="store_true",
105
+ help="process images as 3D stacks of images (nplanes x nchan x Ly x Lx")
106
+ algorithm_args.add_argument(
107
+ "--diameter", required=False, default=30., type=float, help=
108
+ "cell diameter, if 0 will use the diameter of the training labels used in the model, or with built-in model will estimate diameter for each image"
109
+ )
110
+ algorithm_args.add_argument(
111
+ "--stitch_threshold", required=False, default=0.0, type=float,
112
+ help="compute masks in 2D then stitch together masks with IoU>0.9 across planes"
113
+ )
114
+ algorithm_args.add_argument(
115
+ "--min_size", required=False, default=15, type=int,
116
+ help="minimum number of pixels per mask, can turn off with -1")
117
+ algorithm_args.add_argument(
118
+ "--flow3D_smooth", required=False, default=0, type=float,
119
+ help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing")
120
+
121
+ algorithm_args.add_argument(
122
+ "--flow_threshold", default=0.4, type=float, help=
123
+ "flow error threshold, 0 turns off this optional QC step. Default: %(default)s")
124
+ algorithm_args.add_argument(
125
+ "--cellprob_threshold", default=0, type=float,
126
+ help="cellprob threshold, default is 0, decrease to find more and larger masks")
127
+ algorithm_args.add_argument(
128
+ "--niter", default=0, type=int, help=
129
+ "niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs"
130
+ )
131
+
132
+ algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
133
+ help="anisotropy of volume in 3D")
134
+ algorithm_args.add_argument("--exclude_on_edges", action="store_true",
135
+ help="discard masks which touch edges of image")
136
+ algorithm_args.add_argument(
137
+ "--augment", action="store_true",
138
+ help="tiles image with overlapping tiles and flips overlapped regions to augment"
139
+ )
140
+
141
+ # output settings
142
+ output_args = parser.add_argument_group("Output Arguments")
143
+ output_args.add_argument(
144
+ "--save_png", action="store_true",
145
+ help="save masks as png")
146
+ output_args.add_argument(
147
+ "--save_tif", action="store_true",
148
+ help="save masks as tif")
149
+ output_args.add_argument(
150
+ "--output_name", default=None, type=str,
151
+ help="suffix for saved masks, default is _cp_masks, can be empty if `savedir` used and different of `dir`")
152
+ output_args.add_argument("--no_npy", action="store_true",
153
+ help="suppress saving of npy")
154
+ output_args.add_argument(
155
+ "--savedir", default=None, type=str, help=
156
+ "folder to which segmentation results will be saved (defaults to input image directory)"
157
+ )
158
+ output_args.add_argument(
159
+ "--dir_above", action="store_true", help=
160
+ "save output folders adjacent to image folder instead of inside it (off by default)"
161
+ )
162
+ output_args.add_argument("--in_folders", action="store_true",
163
+ help="flag to save output in folders (off by default)")
164
+ output_args.add_argument(
165
+ "--save_flows", action="store_true", help=
166
+ "whether or not to save RGB images of flows when masks are saved (disabled by default)"
167
+ )
168
+ output_args.add_argument(
169
+ "--save_outlines", action="store_true", help=
170
+ "whether or not to save RGB outline images when masks are saved (disabled by default)"
171
+ )
172
+ output_args.add_argument(
173
+ "--save_rois", action="store_true",
174
+ help="whether or not to save ImageJ compatible ROI archive (disabled by default)"
175
+ )
176
+ output_args.add_argument(
177
+ "--save_txt", action="store_true",
178
+ help="flag to enable txt outlines for ImageJ (disabled by default)")
179
+ output_args.add_argument(
180
+ "--save_mpl", action="store_true",
181
+ help="save a figure of image/mask/flows using matplotlib (disabled by default). "
182
+ "This is slow, especially with large images.")
183
+
184
+ # training settings
185
+ training_args = parser.add_argument_group("Training Arguments")
186
+ training_args.add_argument("--train", action="store_true",
187
+ help="train network using images in dir")
188
+ training_args.add_argument("--train_size", action="store_true",
189
+ help="train size network at end of training")
190
+ training_args.add_argument("--test_dir", default=[], type=str,
191
+ help="folder containing test data (optional)")
192
+ training_args.add_argument(
193
+ "--file_list", default=[], type=str, help=
194
+ "path to list of files for training and testing and probabilities for each image (optional)"
195
+ )
196
+ training_args.add_argument(
197
+ "--mask_filter", default="_masks", type=str, help=
198
+ "end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s"
199
+ )
200
+ training_args.add_argument(
201
+ "--diam_mean", default=30., type=float, help=
202
+ "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0"
203
+ )
204
+ training_args.add_argument("--learning_rate", default=0.2, type=float,
205
+ help="learning rate. Default: %(default)s")
206
+ training_args.add_argument("--weight_decay", default=0.00001, type=float,
207
+ help="weight decay. Default: %(default)s")
208
+ training_args.add_argument("--n_epochs", default=500, type=int,
209
+ help="number of epochs. Default: %(default)s")
210
+ training_args.add_argument("--batch_size", default=8, type=int,
211
+ help="batch size. Default: %(default)s")
212
+ training_args.add_argument(
213
+ "--nimg_per_epoch", default=None, type=int,
214
+ help="number of train images per epoch. Default is to use all train images.")
215
+ training_args.add_argument(
216
+ "--nimg_test_per_epoch", default=None, type=int,
217
+ help="number of test images per epoch. Default is to use all test images.")
218
+ training_args.add_argument(
219
+ "--min_train_masks", default=5, type=int, help=
220
+ "minimum number of masks a training image must have to be used. Default: %(default)s"
221
+ )
222
+ training_args.add_argument("--SGD", default=1, type=int, help="use SGD")
223
+ training_args.add_argument(
224
+ "--save_every", default=100, type=int,
225
+ help="number of epochs to skip between saves. Default: %(default)s")
226
+ training_args.add_argument(
227
+ "--model_name_out", default=None, type=str,
228
+ help="Name of model to save as, defaults to name describing model architecture. "
229
+ "Model is saved in the folder specified by --dir in models subfolder.")
230
+
231
+ return parser