"""Auxiliary module for bioimageio format export Example usage: ```bash #!/bin/bash # Define default paths and parameters DEFAULT_CHANNELS="1 0" DEFAULT_PATH_PRETRAINED_MODEL="/home/qinyu/models/cp/cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995" DEFAULT_PATH_README="/home/qinyu/models/cp/README.md" DEFAULT_LIST_PATH_COVER_IMAGES="/home/qinyu/images/cp/cellpose_raw_and_segmentation.jpg /home/qinyu/images/cp/cellpose_raw_and_probability.jpg /home/qinyu/images/cp/cellpose_raw.jpg" DEFAULT_MODEL_ID="philosophical-panda" DEFAULT_MODEL_ICON="🐼" DEFAULT_MODEL_VERSION="0.1.0" DEFAULT_MODEL_NAME="My Cool Cellpose" DEFAULT_MODEL_DOCUMENTATION="A cool Cellpose model trained for my cool dataset." DEFAULT_MODEL_AUTHORS='[{"name": "Qin Yu", "affiliation": "EMBL", "github_user": "qin-yu", "orcid": "0000-0002-4652-0795"}]' DEFAULT_MODEL_CITE='[{"text": "For more details of the model itself, see the manuscript", "doi": "10.1242/dev.202800", "url": null}]' DEFAULT_MODEL_TAGS="cellpose 3d 2d" DEFAULT_MODEL_LICENSE="MIT" DEFAULT_MODEL_REPO="https://github.com/kreshuklab/go-nuclear" # Run the Python script with default parameters python export.py \ --channels $DEFAULT_CHANNELS \ --path_pretrained_model "$DEFAULT_PATH_PRETRAINED_MODEL" \ --path_readme "$DEFAULT_PATH_README" \ --list_path_cover_images $DEFAULT_LIST_PATH_COVER_IMAGES \ --model_version "$DEFAULT_MODEL_VERSION" \ --model_name "$DEFAULT_MODEL_NAME" \ --model_documentation "$DEFAULT_MODEL_DOCUMENTATION" \ --model_authors "$DEFAULT_MODEL_AUTHORS" \ --model_cite "$DEFAULT_MODEL_CITE" \ --model_tags $DEFAULT_MODEL_TAGS \ --model_license "$DEFAULT_MODEL_LICENSE" \ --model_repo "$DEFAULT_MODEL_REPO" ``` """ import os import sys import json import argparse from pathlib import Path from urllib.parse import urlparse import torch import numpy as np from cellpose.io import imread from cellpose.utils import download_url_to_file from cellpose.transforms import pad_image_ND, normalize_img, convert_image from cellpose.resnet_torch import CPnetBioImageIO from bioimageio.spec.model.v0_5 import ( ArchitectureFromFileDescr, Author, AxisId, ChannelAxis, CiteEntry, Doi, FileDescr, Identifier, InputTensorDescr, IntervalOrRatioDataDescr, LicenseId, ModelDescr, ModelId, OrcidId, OutputTensorDescr, ParameterizedSize, PytorchStateDictWeightsDescr, SizeReference, SpaceInputAxis, SpaceOutputAxis, TensorId, TorchscriptWeightsDescr, Version, WeightsDescr, ) # Define ARBITRARY_SIZE if it is not available in the module try: from bioimageio.spec.model.v0_5 import ARBITRARY_SIZE except ImportError: ARBITRARY_SIZE = ParameterizedSize(min=1, step=1) from bioimageio.spec.common import HttpUrl from bioimageio.spec import save_bioimageio_package from bioimageio.core import test_model DEFAULT_CHANNELS = [2, 1] DEFAULT_NORMALIZE_PARAMS = { "axis": -1, "lowhigh": None, "percentile": None, "normalize": True, "norm3D": False, "sharpen_radius": 0, "smooth_radius": 0, "tile_norm_blocksize": 0, "tile_norm_smooth3D": 1, "invert": False, } IMAGE_URL = "http://www.cellpose.org/static/data/rgb_3D.tif" def download_and_normalize_image(path_dir_temp, channels=DEFAULT_CHANNELS): """ Download and normalize image. """ filename = os.path.basename(urlparse(IMAGE_URL).path) path_image = path_dir_temp / filename if not path_image.exists(): sys.stderr.write(f'Downloading: "{IMAGE_URL}" to {path_image}\n') download_url_to_file(IMAGE_URL, path_image) img = imread(path_image).astype(np.float32) img = convert_image(img, channels, channel_axis=1, z_axis=0, do_3D=False, nchan=2) img = normalize_img(img, **DEFAULT_NORMALIZE_PARAMS) img = np.transpose(img, (0, 3, 1, 2)) img, _, _ = pad_image_ND(img) return img def load_bioimageio_cpnet_model(path_model_weight, nchan=2): cpnet_kwargs = { "nbase": [nchan, 32, 64, 128, 256], "nout": 3, "sz": 3, "mkldnn": False, "conv_3D": False, "max_pool": True, } cpnet_biio = CPnetBioImageIO(**cpnet_kwargs) state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True) cpnet_biio.load_state_dict(state_dict_cuda) cpnet_biio.eval() # crucial for the prediction results return cpnet_biio, cpnet_kwargs def descr_gen_input(path_test_input, nchan=2): input_axes = [ SpaceInputAxis(id=AxisId("z"), size=ARBITRARY_SIZE), ChannelAxis(channel_names=[Identifier(f"c{i+1}") for i in range(nchan)]), SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=16, step=16)), SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=16, step=16)), ] data_descr = IntervalOrRatioDataDescr(type="float32") path_test_input = Path(path_test_input) descr_input = InputTensorDescr( id=TensorId("raw"), axes=input_axes, test_tensor=FileDescr(source=path_test_input), data=data_descr, ) return descr_input def descr_gen_output_flow(path_test_output): output_axes_output_tensor = [ SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))), ChannelAxis(channel_names=[Identifier("flow1"), Identifier("flow2"), Identifier("flow3")]), SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))), SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"))), ] path_test_output = Path(path_test_output) descr_output = OutputTensorDescr( id=TensorId("flow"), axes=output_axes_output_tensor, test_tensor=FileDescr(source=path_test_output), ) return descr_output def descr_gen_output_downsampled(path_dir_temp, nbase=None): if nbase is None: nbase = [32, 64, 128, 256] output_axes_downsampled_tensors = [ [ SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))), ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(base)]), SpaceOutputAxis( id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")), scale=2**offset, ), SpaceOutputAxis( id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")), scale=2**offset, ), ] for offset, base in enumerate(nbase) ] path_downsampled_tensors = [ Path(path_dir_temp / f"test_downsampled_{i}.npy") for i in range(len(output_axes_downsampled_tensors)) ] descr_output_downsampled_tensors = [ OutputTensorDescr( id=TensorId(f"downsampled_{i}"), axes=axes, test_tensor=FileDescr(source=path), ) for i, (axes, path) in enumerate(zip(output_axes_downsampled_tensors, path_downsampled_tensors)) ] return descr_output_downsampled_tensors def descr_gen_output_style(path_test_style, nchannel=256): output_axes_style_tensor = [ SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))), ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(nchannel)]), ] path_style_tensor = Path(path_test_style) descr_output_style_tensor = OutputTensorDescr( id=TensorId("style"), axes=output_axes_style_tensor, test_tensor=FileDescr(source=path_style_tensor), ) return descr_output_style_tensor def descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper=None): if path_cpnet_wrapper is None: path_cpnet_wrapper = Path(__file__).parent / "resnet_torch.py" pytorch_architecture = ArchitectureFromFileDescr( callable=Identifier("CPnetBioImageIO"), source=Path(path_cpnet_wrapper), kwargs=cpnet_kwargs, ) return pytorch_architecture def descr_gen_documentation(path_doc, markdown_text): with open(path_doc, "w") as f: f.write(markdown_text) def package_to_bioimageio( path_pretrained_model, path_save_trace, path_readme, list_path_cover_images, descr_input, descr_output, descr_output_downsampled_tensors, descr_output_style_tensor, pytorch_version, pytorch_architecture, model_id, model_icon, model_version, model_name, model_documentation, model_authors, model_cite, model_tags, model_license, model_repo, ): """Package model description to BioImage.IO format.""" my_model_descr = ModelDescr( id=ModelId(model_id) if model_id is not None else None, id_emoji=model_icon, version=Version(model_version), name=model_name, description=model_documentation, authors=[ Author( name=author["name"], affiliation=author["affiliation"], github_user=author["github_user"], orcid=OrcidId(author["orcid"]), ) for author in model_authors ], cite=[CiteEntry(text=cite["text"], doi=Doi(cite["doi"]), url=cite["url"]) for cite in model_cite], covers=[Path(img) for img in list_path_cover_images], license=LicenseId(model_license), tags=model_tags, documentation=Path(path_readme), git_repo=HttpUrl(model_repo), inputs=[descr_input], outputs=[descr_output, descr_output_style_tensor] + descr_output_downsampled_tensors, weights=WeightsDescr( pytorch_state_dict=PytorchStateDictWeightsDescr( source=Path(path_pretrained_model), architecture=pytorch_architecture, pytorch_version=pytorch_version, ), torchscript=TorchscriptWeightsDescr( source=Path(path_save_trace), pytorch_version=pytorch_version, parent="pytorch_state_dict", # these weights were converted from the pytorch_state_dict weights. ), ), ) return my_model_descr def parse_args(): # fmt: off parser = argparse.ArgumentParser(description="BioImage.IO model packaging for Cellpose") parser.add_argument("--channels", nargs=2, default=[2, 1], type=int, help="Cyto-only = [2, 0], Cyto + Nuclei = [2, 1], Nuclei-only = [1, 0]") parser.add_argument("--path_pretrained_model", required=True, type=str, help="Path to pretrained model file, e.g., cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995") parser.add_argument("--path_readme", required=True, type=str, help="Path to README file") parser.add_argument("--list_path_cover_images", nargs='+', required=True, type=str, help="List of paths to cover images") parser.add_argument("--model_id", type=str, help="Model ID, provide if already exists", default=None) parser.add_argument("--model_icon", type=str, help="Model icon, provide if already exists", default=None) parser.add_argument("--model_version", required=True, type=str, help="Model version, new model should be 0.1.0") parser.add_argument("--model_name", required=True, type=str, help="Model name, e.g., My Cool Cellpose") parser.add_argument("--model_documentation", required=True, type=str, help="Model documentation, e.g., A cool Cellpose model trained for my cool dataset.") parser.add_argument("--model_authors", required=True, type=str, help="Model authors in JSON format, e.g., '[{\"name\": \"Qin Yu\", \"affiliation\": \"EMBL\", \"github_user\": \"qin-yu\", \"orcid\": \"0000-0002-4652-0795\"}]'") parser.add_argument("--model_cite", required=True, type=str, help="Model citation in JSON format, e.g., '[{\"text\": \"For more details of the model itself, see the manuscript\", \"doi\": \"10.1242/dev.202800\", \"url\": null}]'") parser.add_argument("--model_tags", nargs='+', required=True, type=str, help="Model tags, e.g., cellpose 3d 2d") parser.add_argument("--model_license", required=True, type=str, help="Model license, e.g., MIT") parser.add_argument("--model_repo", required=True, type=str, help="Model repository URL") return parser.parse_args() # fmt: on def main(): args = parse_args() # Parse user-provided paths and arguments channels = args.channels model_cite = json.loads(args.model_cite) model_authors = json.loads(args.model_authors) path_readme = Path(args.path_readme) path_pretrained_model = Path(args.path_pretrained_model) list_path_cover_images = [Path(path_image) for path_image in args.list_path_cover_images] # Auto-generated paths path_cpnet_wrapper = Path(__file__).resolve().parent / "resnet_torch.py" path_dir_temp = Path(__file__).resolve().parent.parent / "models" / path_pretrained_model.stem path_dir_temp.mkdir(parents=True, exist_ok=True) path_save_trace = path_dir_temp / "cp_traced.pt" path_test_input = path_dir_temp / "test_input.npy" path_test_output = path_dir_temp / "test_output.npy" path_test_style = path_dir_temp / "test_style.npy" path_bioimageio_package = path_dir_temp / "cellpose_model.zip" # Download test input image img_np = download_and_normalize_image(path_dir_temp, channels=channels) np.save(path_test_input, img_np) img = torch.tensor(img_np).float() # Load model cpnet_biio, cpnet_kwargs = load_bioimageio_cpnet_model(path_pretrained_model) # Test model and save output tuple_output_tensor = cpnet_biio(img) np.save(path_test_output, tuple_output_tensor[0].detach().numpy()) np.save(path_test_style, tuple_output_tensor[1].detach().numpy()) for i, t in enumerate(tuple_output_tensor[2:]): np.save(path_dir_temp / f"test_downsampled_{i}.npy", t.detach().numpy()) # Save traced model model_traced = torch.jit.trace(cpnet_biio, img) model_traced.save(path_save_trace) # Generate model description descr_input = descr_gen_input(path_test_input) descr_output = descr_gen_output_flow(path_test_output) descr_output_downsampled_tensors = descr_gen_output_downsampled(path_dir_temp, nbase=cpnet_biio.nbase[1:]) descr_output_style_tensor = descr_gen_output_style(path_test_style, cpnet_biio.nbase[-1]) pytorch_version = Version(torch.__version__) pytorch_architecture = descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper) # Package model my_model_descr = package_to_bioimageio( path_pretrained_model, path_save_trace, path_readme, list_path_cover_images, descr_input, descr_output, descr_output_downsampled_tensors, descr_output_style_tensor, pytorch_version, pytorch_architecture, args.model_id, args.model_icon, args.model_version, args.model_name, args.model_documentation, model_authors, model_cite, args.model_tags, args.model_license, args.model_repo, ) # Test model summary = test_model(my_model_descr, weight_format="pytorch_state_dict") summary.display() summary = test_model(my_model_descr, weight_format="torchscript") summary.display() # Save BioImage.IO package package_path = save_bioimageio_package(my_model_descr, output_path=Path(path_bioimageio_package)) print("package path:", package_path) if __name__ == "__main__": main()