"""
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""
import sys, os, pathlib, warnings, datetime, time, copy
from qtpy import QtGui, QtCore
from superqt import QRangeSlider, QCollapsible
from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox
import pyqtgraph as pg
import numpy as np
from scipy.stats import mode
import cv2
from . import guiparts, menus, io
from .. import models, core, dynamics, version, denoise, train
from ..utils import download_url_to_file, masks_to_outlines, diameters
from ..io import get_image_files, imsave, imread
from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img
from ..models import normalize_default
from ..plot import disk
# Optional dependency: matplotlib. MATPLOTLIB records whether the import succeeded.
try:
    import matplotlib.pyplot as plt
    MATPLOTLIB = True
except Exception:
    # Best-effort: any import failure only disables the matplotlib code path.
    # `Exception` (rather than a bare except) lets KeyboardInterrupt/SystemExit through.
    MATPLOTLIB = False

# Optional dependency: google-cloud-storage. When it imports, the credentials
# environment variable is pointed at the key file located next to this module.
try:
    from google.cloud import storage
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "key/cellpose-data-writer.json")
    SERVER_UPLOAD = True
except Exception:
    # Best-effort: SERVER_UPLOAD stays False when the package is unavailable.
    SERVER_UPLOAD = False
# Orientation constant handed to the QRangeSlider constructor in Slider.
Horizontal = QtCore.Qt.Orientation.Horizontal


class Slider(QRangeSlider):
    """Horizontal range slider that reports value changes to a parent object.

    The slider is created disabled, with a transparent background, and is shown
    immediately. Every ``valueChanged`` emission calls
    ``parent.level_change(self.name)``.
    """

    def __init__(self, parent, name, color):
        """Create the slider.

        Args:
            parent: object providing ``level_change(name)``.
            name: identifier passed back to ``parent.level_change``.
            color: accepted for interface compatibility; not used in this constructor.
        """
        super().__init__(Horizontal)
        self.setEnabled(False)
        # The zero-argument lambda deliberately ignores the emitted value.
        self.valueChanged.connect(lambda: self.levelChanged(parent))
        self.name = name
        self.setStyleSheet(" QSlider{\n"
                           "background-color: transparent;\n"
                           "}\n")
        self.show()

    def levelChanged(self, parent):
        """Forward a value change to ``parent.level_change`` using this slider's name."""
        parent.level_change(self.name)
class QHLine(QFrame):
    """QFrame drawn as a horizontal line (``QFrame.HLine``) with line width 8."""

    def __init__(self):
        super().__init__()
        self.setFrameShape(QFrame.HLine)
        self.setLineWidth(8)
def make_bwr():
    """Build a blue-white-red ``pg.ColorMap`` defined at 256 positions in [0, 255].

    The first entry is pure blue (0, 0, 255), the colors ramp up to white in the
    middle and the last entry is pure red (255, 0, 0).

    Returns:
        pg.ColorMap: colormap with uint8 RGB colors.
    """
    rising = np.linspace(0, 255, 128)
    falling = rising[::-1]
    saturated = 255 * np.ones(128)

    red = np.concatenate((rising, saturated))
    green = np.concatenate((rising, falling))
    blue = np.concatenate((saturated, falling))

    # one row per position, columns ordered (r, g, b)
    rgb = np.stack((red, green, blue), axis=-1).astype(np.uint8)
    return pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=rgb)
def make_spectral():
    """Build the "spectral" ``pg.ColorMap`` from hard-coded per-channel tables.

    The red, green and blue tables below are combined column-wise into uint8 RGB
    rows and paired with 256 evenly spaced positions in [0, 255].

    Returns:
        pg.ColorMap: the spectral colormap.
    """
    red = np.array([
        0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80,
        84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 120, 112, 104, 96, 88,
        80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 7, 11, 15, 19, 23,
        27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103,
        107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167,
        171, 175, 179, 183, 187, 191, 195, 199, 203, 207, 211, 215, 219, 223, 227, 231,
        235, 239, 243, 247, 251, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255
    ])
    green = np.array([
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 5, 4, 4, 3, 3,
        2, 2, 1, 1, 0, 0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111,
        119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239,
        247, 255, 247, 239, 231, 223, 215, 207, 199, 191, 183, 175, 167, 159, 151, 143,
        135, 128, 129, 131, 132, 134, 135, 137, 139, 140, 142, 143, 145, 147, 148, 150,
        151, 153, 154, 156, 158, 159, 161, 162, 164, 166, 167, 169, 170, 172, 174, 175,
        177, 178, 180, 181, 183, 185, 186, 188, 189, 191, 193, 194, 196, 197, 199, 201,
        202, 204, 205, 207, 208, 210, 212, 213, 215, 216, 218, 220, 221, 223, 224, 226,
        228, 229, 231, 232, 234, 235, 237, 239, 240, 242, 243, 245, 247, 248, 250, 251,
        253, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199,
        195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135,
        131, 127, 123, 119, 115, 111, 107, 103, 99, 95, 91, 87, 83, 79, 75, 71, 67, 63,
        59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 0, 8, 16, 24, 32, 41,
        49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180,
        189, 197, 205, 213, 222, 230, 238, 246, 254
    ])
    blue = np.array([
        0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143,
        151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 251, 247,
        243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183,
        179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 128, 126, 124,
        122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98, 96, 94, 92, 90,
        88, 86, 84, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50,
        48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10,
        8, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74,
        82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205,
        213, 222, 230, 238, 246, 254
    ])
    # one row per position, columns ordered (r, g, b)
    rgb = np.column_stack((red, green, blue)).astype(np.uint8)
    return pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=rgb)
def make_cmap(cm=0):
    """Build a single-channel ``pg.ColorMap`` ramping from black to full intensity.

    Args:
        cm: column of the RGB table that receives the 0..255 ramp
            (0=red, 1=green, 2=blue); the other two columns stay 0.

    Returns:
        pg.ColorMap: colormap defined at 256 positions in [0, 255].
    """
    rgb = np.zeros((256, 3), dtype=np.uint8)
    rgb[:, cm] = np.arange(256)
    return pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=rgb)
def run(image=None):
    """Launch the cellpose GUI and exit the process when the event loop ends.

    Sets up logging, creates the ``QApplication``, downloads the logo and the
    help-window image into ``~/.cellpose`` when they are missing, applies the
    application icon, the "Fusion" style and the dark palette, opens the main
    window and finally calls ``sys.exit`` with the event loop's return code.

    Args:
        image: optional value forwarded to ``MainW(image=...)``.
    """
    from ..io import logger_setup
    logger, _log_file = logger_setup()
    # Always start by initializing Qt (only once per application)
    warnings.filterwarnings("ignore")
    app = QApplication(sys.argv)

    icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
    guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png")
    if not icon_path.is_file():
        # the logo lives directly inside ~/.cellpose, so its parent is that directory
        icon_path.parent.mkdir(exist_ok=True)
        print("downloading logo")
        download_url_to_file(
            "https://www.cellpose.org/static/images/cellpose_transparent.png",
            icon_path, progress=True)
    if not guip_path.is_file():
        print("downloading help window image")
        download_url_to_file("https://www.cellpose.org/static/images/cellpose_gui.png",
                             guip_path, progress=True)

    # register the same logo file for each icon size
    icon_path = str(icon_path.resolve())
    app_icon = QtGui.QIcon()
    for edge in (16, 24, 32, 48, 64, 256):
        app_icon.addFile(icon_path, QtCore.QSize(edge, edge))
    app.setWindowIcon(app_icon)
    app.setStyle("Fusion")
    app.setPalette(guiparts.DarkPalette())

    # Hold a Python reference to the window for the lifetime of the event loop so
    # it cannot be garbage-collected (and torn down) while the application runs.
    main_window = MainW(image=image, logger=logger)
    ret = app.exec_()
    sys.exit(ret)
class MainW(QMainWindow):
def __init__(self, image=None, logger=None):
    """Build the main window: menus, control panel, drawing area and defaults.

    Args:
        image: optional value; when not None it is stored as ``self.filename``
            and loaded with ``io._load_image`` after the window state is reset.
        logger: stored on the instance as ``self.logger``.
    """
    super().__init__()
    self.logger = logger

    pg.setConfigOptions(imageAxisOrder="row-major")
    self.setGeometry(50, 50, 1200, 1000)
    self.setWindowTitle(f"cellpose v{version}")
    self.cp_path = os.path.dirname(os.path.realpath(__file__))

    # window icon: the same logo file registered for each size
    window_icon = QtGui.QIcon()
    logo_file = str(pathlib.Path.home().joinpath(".cellpose", "logo.png").resolve())
    for edge in (16, 24, 32, 48, 64, 256):
        window_icon.addFile(logo_file, QtCore.QSize(edge, edge))
    self.setWindowIcon(window_icon)

    self.setStyleSheet(guiparts.stylesheet())

    menus.mainmenu(self)
    menus.editmenu(self)
    menus.modelmenu(self)
    menus.helpmenu(self)

    # button style sheets; the two differ only in the background color
    tooltip_css = ("QToolTip {\n"
                   "background-color: black;\n"
                   "color: white;\n"
                   "border: black solid 1px\n"
                   "}")
    self.stylePressed = ("QPushButton {Text-align: center;\n"
                         "background-color: rgb(150,50,150);\n"
                         "border-color: white;\n"
                         "color:white;}\n" + tooltip_css)
    self.styleUnpressed = ("QPushButton {Text-align: center;\n"
                           "background-color: rgb(50,50,50);\n"
                           "border-color: white;\n"
                           "color:white;}\n" + tooltip_css)
    self.loaded = False

    # ---- MAIN WIDGET LAYOUT ---- #
    self.cwidget = QWidget(self)
    self.lmain = QGridLayout()
    self.cwidget.setLayout(self.lmain)
    self.setCentralWidget(self.cwidget)
    self.lmain.setVerticalSpacing(0)
    self.lmain.setContentsMargins(0, 0, 0, 10)

    self.imask = 0

    # scrollable left-hand panel; make_buttons fills its grid layout self.l0
    self.scrollarea = QScrollArea()
    self.scrollarea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
    self.scrollarea.setStyleSheet("""QScrollArea { border: none }""")
    self.scrollarea.setWidgetResizable(True)
    self.swidget = QWidget(self)
    self.scrollarea.setWidget(self.swidget)
    self.l0 = QGridLayout()
    self.swidget.setLayout(self.l0)
    self.make_buttons()
    self.lmain.addWidget(self.scrollarea, 0, 0, 39, 9)

    # ---- drawing area ---- #
    self.win = pg.GraphicsLayoutWidget()
    self.lmain.addWidget(self.win, 0, 9, 40, 30)
    self.win.scene().sigMouseClicked.connect(self.plot_clicked)
    self.win.scene().sigMouseMoved.connect(self.mouse_moved)
    self.make_viewbox()
    self.lmain.setColumnStretch(10, 1)

    # lookup tables: bwr separately; self.cmap holds spectral, then one per RGB channel
    lut_args = dict(start=0.0, stop=255.0, alpha=False)
    self.bwr = make_bwr().getLookupTable(**lut_args)
    self.cmap = [make_spectral().getLookupTable(**lut_args)]
    self.cmap.extend(
        make_cmap(channel).getLookupTable(**lut_args) for channel in range(3))

    # table of 1,000,000 colors; the fixed seed makes colors stable
    n_colors = 1000000
    if MATPLOTLIB:
        palette = plt.get_cmap("gist_ncar")(np.linspace(0.0, .9, n_colors)) * 255
        self.colormap = palette.astype(np.uint8)
        np.random.seed(42)
        self.colormap = self.colormap[np.random.permutation(n_colors)]
    else:
        np.random.seed(42)
        random_rgb = (np.random.rand(n_colors, 3) * 0.8 + 0.1) * 255
        self.colormap = random_rgb.astype(np.uint8)

    self.NZ = 1
    self.restore = None
    self.ratio = 1.
    self.reset()

    # if called with image, load it
    if image is not None:
        self.filename = image
        io._load_image(self, self.filename)

    # training settings; the default model name carries the current timestamp
    timestamp = datetime.datetime.now().strftime("_%Y%m%d_%H%M%S")
    self.training_params = {
        "model_index": 0,
        "learning_rate": 0.1,
        "weight_decay": 0.0001,
        "n_epochs": 100,
        "SGD": True,
        "model_name": "CP" + timestamp,
    }

    self.load_3D = False
    self.stitch_threshold = 0.
    self.flow3D_smooth = 0.
    self.anisotropy = 1.
    self.min_size = 15
    self.resample = True

    self.setAcceptDrops(True)
    self.win.show()
    self.show()
def help_window(self):
    """Create a ``guiparts.HelpWindow`` with this window as parent and show it."""
    help_dialog = guiparts.HelpWindow(self)
    help_dialog.show()
def train_help_window(self):
    """Create a ``guiparts.TrainHelpWindow`` with this window as parent and show it."""
    train_help_dialog = guiparts.TrainHelpWindow(self)
    train_help_dialog.show()
def gui_window(self):
    """Create a ``guiparts.ExampleGUI`` with this window as parent and show it."""
    example_dialog = guiparts.ExampleGUI(self)
    example_dialog.show()
def make_buttons(self):
    """Create the widgets of the control panel and add them to ``self.l0``.

    Builds, from top to bottom, the "Views", "Drawing", "Segmentation",
    "Other models" and "Image restoration" group boxes, a stretchable spacer row
    and the "scale disk on" checkbox. Widgets and their default values are stored
    as attributes on ``self``. ``self.l0`` and ``self.model_strings`` must already
    exist when this is called.

    Returns:
        int: index of the last ``self.l0`` row used (the "scale disk on" row).
    """
    self.boldfont = QtGui.QFont("Arial", 11, QtGui.QFont.Bold)
    self.boldmedfont = QtGui.QFont("Arial", 9, QtGui.QFont.Bold)
    self.medfont = QtGui.QFont("Arial", 9)
    self.smallfont = QtGui.QFont("Arial", 8)

    # ------------------------------ Views ------------------------------ #
    b = 0  # row in self.l0
    self.satBox = QGroupBox("Views")
    self.satBox.setFont(self.boldfont)
    self.satBoxG = QGridLayout()
    self.satBox.setLayout(self.satBoxG)
    self.l0.addWidget(self.satBox, b, 0, 1, 9)

    b0 = 0  # row inside the current group box
    # self.view / self.color mirror the current index of the two combo boxes below
    self.view = 0  # 0=image, 1=gradXY, 2=cellprob, 3=restored
    self.color = 0  # 0=RGB, 1=red, 2=green, 3=blue, 4=gray, 5=spectral
    self.RGBDropDown = QComboBox()
    self.RGBDropDown.addItems(
        ["RGB", "red=R", "green=G", "blue=B", "gray", "spectral"])
    self.RGBDropDown.setFont(self.medfont)
    self.RGBDropDown.currentIndexChanged.connect(self.color_choose)
    self.satBoxG.addWidget(self.RGBDropDown, b0, 0, 1, 3)

    # NOTE(review): this literal was split across physical lines inside ordinary
    # quotes (a syntax error); it is rejoined here as a single line. Confirm whether
    # the original carried markup around the text.
    label = QLabel("[↑ / ↓ or W/S]")
    label.setFont(self.smallfont)
    self.satBoxG.addWidget(label, b0, 3, 1, 3)
    label = QLabel("[R / G / B \n toggles color ]")
    label.setFont(self.smallfont)
    self.satBoxG.addWidget(label, b0, 6, 1, 3)

    b0 += 1
    self.ViewDropDown = QComboBox()
    self.ViewDropDown.addItems(["image", "gradXY", "cellprob", "restored"])
    self.ViewDropDown.setFont(self.medfont)
    # the "restored" entry starts out disabled
    self.ViewDropDown.model().item(3).setEnabled(False)
    self.ViewDropDown.currentIndexChanged.connect(self.update_plot)
    self.satBoxG.addWidget(self.ViewDropDown, b0, 0, 2, 3)

    label = QLabel("[pageup / pagedown]")
    label.setFont(self.smallfont)
    self.satBoxG.addWidget(label, b0, 3, 1, 5)

    b0 += 2
    saturation_tip = "NOTE: manually changing the saturation bars does not affect normalization in segmentation"
    label = QLabel("")
    label.setToolTip(saturation_tip)
    self.satBoxG.addWidget(label, b0, 0, 1, 5)

    self.autobtn = QCheckBox("auto-adjust saturation")
    self.autobtn.setToolTip("sets scale-bars as normalized for segmentation")
    self.autobtn.setFont(self.medfont)
    self.autobtn.setChecked(True)
    self.satBoxG.addWidget(self.autobtn, b0, 1, 1, 8)

    b0 += 1
    # one saturation range slider per color channel
    self.sliders = []
    colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [100, 100, 100]]
    colornames = ["red", "Chartreuse", "DodgerBlue"]
    names = ["red", "green", "blue"]
    for r in range(3):
        b0 += 1
        if r == 0:
            # Literal was split across two physical lines in the source; the line
            # break is kept as an explicit escape.
            label = QLabel("gray/\nred")
        else:
            label = QLabel(names[r] + ":")
        label.setStyleSheet(f"color: {colornames[r]}")
        label.setFont(self.boldmedfont)
        self.satBoxG.addWidget(label, b0, 0, 1, 2)
        self.sliders.append(Slider(self, names[r], colors[r]))
        self.sliders[-1].setMinimum(-.1)
        self.sliders[-1].setMaximum(255.1)
        self.sliders[-1].setValue([0, 255])
        self.sliders[-1].setToolTip(saturation_tip)
        self.satBoxG.addWidget(self.sliders[-1], b0, 2, 1, 7)

    # ----------------------------- Drawing ----------------------------- #
    b += 1
    self.drawBox = QGroupBox("Drawing")
    self.drawBox.setFont(self.boldfont)
    self.drawBoxG = QGridLayout()
    self.drawBox.setLayout(self.drawBoxG)
    self.l0.addWidget(self.drawBox, b, 0, 1, 9)
    self.autosave = True

    b0 = 0
    self.brush_size = 3
    self.BrushChoose = QComboBox()
    self.BrushChoose.addItems(["1", "3", "5", "7", "9"])
    self.BrushChoose.currentIndexChanged.connect(self.brush_choose)
    self.BrushChoose.setFixedWidth(40)
    self.BrushChoose.setFont(self.medfont)
    self.drawBoxG.addWidget(self.BrushChoose, b0, 3, 1, 2)
    label = QLabel("brush size:")
    label.setFont(self.medfont)
    self.drawBoxG.addWidget(label, b0, 0, 1, 3)

    b0 += 1
    # turn off masks
    self.layer_off = False
    self.masksOn = True
    self.MCheckBox = QCheckBox("MASKS ON [X]")
    self.MCheckBox.setFont(self.medfont)
    self.MCheckBox.setChecked(True)
    self.MCheckBox.toggled.connect(self.toggle_masks)
    self.drawBoxG.addWidget(self.MCheckBox, b0, 0, 1, 5)

    b0 += 1
    # turn off outlines
    self.outlinesOn = False  # turn off by default
    self.OCheckBox = QCheckBox("outlines on [Z]")
    self.OCheckBox.setFont(self.medfont)
    self.drawBoxG.addWidget(self.OCheckBox, b0, 0, 1, 5)
    self.OCheckBox.setChecked(False)
    self.OCheckBox.toggled.connect(self.toggle_masks)

    b0 += 1
    self.SCheckBox = QCheckBox("single stroke")
    self.SCheckBox.setFont(self.medfont)
    self.SCheckBox.setChecked(True)
    self.SCheckBox.toggled.connect(self.autosave_on)
    self.SCheckBox.setEnabled(True)
    self.drawBoxG.addWidget(self.SCheckBox, b0, 0, 1, 5)

    # buttons for deleting multiple cells
    self.deleteBox = QGroupBox("delete multiple ROIs")
    self.deleteBox.setStyleSheet("color: rgb(200, 200, 200)")
    self.deleteBox.setFont(self.medfont)
    self.deleteBoxG = QGridLayout()
    self.deleteBox.setLayout(self.deleteBoxG)
    self.drawBoxG.addWidget(self.deleteBox, 0, 5, 4, 4)
    self.MakeDeletionRegionButton = QPushButton("region-select")
    self.MakeDeletionRegionButton.clicked.connect(self.remove_region_cells)
    self.deleteBoxG.addWidget(self.MakeDeletionRegionButton, 0, 0, 1, 4)
    self.MakeDeletionRegionButton.setFont(self.smallfont)
    self.MakeDeletionRegionButton.setFixedWidth(70)
    self.DeleteMultipleROIButton = QPushButton("click-select")
    self.DeleteMultipleROIButton.clicked.connect(self.delete_multiple_cells)
    self.deleteBoxG.addWidget(self.DeleteMultipleROIButton, 1, 0, 1, 4)
    self.DeleteMultipleROIButton.setFont(self.smallfont)
    self.DeleteMultipleROIButton.setFixedWidth(70)
    self.DoneDeleteMultipleROIButton = QPushButton("done")
    self.DoneDeleteMultipleROIButton.clicked.connect(
        self.done_remove_multiple_cells)
    self.deleteBoxG.addWidget(self.DoneDeleteMultipleROIButton, 2, 0, 1, 2)
    self.DoneDeleteMultipleROIButton.setFont(self.smallfont)
    self.DoneDeleteMultipleROIButton.setFixedWidth(35)
    self.CancelDeleteMultipleROIButton = QPushButton("cancel")
    self.CancelDeleteMultipleROIButton.clicked.connect(self.cancel_remove_multiple)
    self.deleteBoxG.addWidget(self.CancelDeleteMultipleROIButton, 2, 2, 1, 2)
    self.CancelDeleteMultipleROIButton.setFont(self.smallfont)
    self.CancelDeleteMultipleROIButton.setFixedWidth(35)

    # --------------------------- Segmentation --------------------------- #
    b += 1
    b0 = 0
    self.segBox = QGroupBox("Segmentation")
    self.segBoxG = QGridLayout()
    self.segBox.setLayout(self.segBoxG)
    self.l0.addWidget(self.segBox, b, 0, 1, 9)
    self.segBox.setFont(self.boldfont)

    self.diameter = 30
    label = QLabel("diameter (pixels):")
    label.setFont(self.medfont)
    label.setToolTip(
        'you can manually enter the approximate diameter for your cells, \nor press “calibrate” to let the model estimate it. \nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking “scale disk on”)'
    )
    self.segBoxG.addWidget(label, b0, 0, 1, 4)
    self.Diameter = QLineEdit()
    self.Diameter.setToolTip(
        'you can manually enter the approximate diameter for your cells, \nor press “calibrate” to let the "cyto3" model estimate it. \nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking “scale disk on”)'
    )
    self.Diameter.setText(str(self.diameter))
    self.Diameter.setFont(self.medfont)
    self.Diameter.returnPressed.connect(self.update_scale)
    self.Diameter.setFixedWidth(50)
    self.segBoxG.addWidget(self.Diameter, b0, 4, 1, 2)

    # compute diameter
    self.SizeButton = QPushButton("calibrate")
    self.SizeButton.setFont(self.medfont)
    self.SizeButton.clicked.connect(self.calibrate_size)
    self.segBoxG.addWidget(self.SizeButton, b0, 6, 1, 3)
    self.SizeButton.setEnabled(False)
    self.SizeButton.setToolTip(
        'you can manually enter the approximate diameter for your cells, \nor press “calibrate” to let the cyto3 model estimate it. \nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking “scale disk on”)'
    )

    b0 += 1
    # choose channel
    self.ChannelChoose = [QComboBox(), QComboBox()]
    self.ChannelChoose[0].addItems(["0: gray", "1: red", "2: green", "3: blue"])
    self.ChannelChoose[1].addItems(["0: none", "1: red", "2: green", "3: blue"])
    cstr = ["chan to segment:", "chan2 (optional): "]
    channel_tips = [
        "this is the channel in which the cytoplasm or nuclei exist that you want to segment",
        "if cytoplasm model is chosen, and you also have a nuclear channel, then choose the nuclear channel for this option"
    ]
    for i in range(2):
        self.ChannelChoose[i].setFont(self.medfont)
        label = QLabel(cstr[i])
        label.setFont(self.medfont)
        # label and combo box share the same tooltip
        label.setToolTip(channel_tips[i])
        self.ChannelChoose[i].setToolTip(channel_tips[i])
        self.segBoxG.addWidget(label, b0 + i, 0, 1, 4)
        self.segBoxG.addWidget(self.ChannelChoose[i], b0 + i, 4, 1, 5)

    b0 += 2
    # use GPU
    self.useGPU = QCheckBox("use GPU")
    self.useGPU.setToolTip(
        "if you have specially installed the cuda version of torch, then you can activate this"
    )
    self.useGPU.setFont(self.medfont)
    self.check_gpu()
    self.segBoxG.addWidget(self.useGPU, b0, 0, 1, 3)

    # compute segmentation with general models
    self.net_text = ["run cyto3"]
    nett = ["cellpose super-generalist model"]
    self.StyleButtons = []
    jj = 4
    w = 5
    for text, tip in zip(self.net_text, nett):
        self.StyleButtons.append(guiparts.ModelButton(self, text, text))
        self.segBoxG.addWidget(self.StyleButtons[-1], b0, jj, 1, w)
        jj += w
        self.StyleButtons[-1].setToolTip(tip)

    b0 += 1
    self.roi_count = QLabel("0 ROIs")
    self.roi_count.setFont(self.boldfont)
    self.roi_count.setAlignment(QtCore.Qt.AlignLeft)
    self.segBoxG.addWidget(self.roi_count, b0, 0, 1, 4)

    self.progress = QProgressBar(self)
    self.segBoxG.addWidget(self.progress, b0, 4, 1, 5)

    b0 += 1
    # collapsible "additional settings" section; its content starts at height 0
    self.segaBox = QCollapsible("additional settings")
    self.segaBox.setFont(self.medfont)
    self.segaBox._toggle_btn.setFont(self.medfont)
    self.segaBoxG = QGridLayout()
    _content = QWidget()
    _content.setLayout(self.segaBoxG)
    _content.setMaximumHeight(0)
    _content.setMinimumHeight(0)
    self.segaBox.setContent(_content)
    self.segBoxG.addWidget(self.segaBox, b0, 0, 1, 9)

    b0 = 0  # rows below are inside self.segaBoxG
    # post-hoc parameter tuning
    flow_tip = "threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run"
    label = QLabel("flow\nthreshold:")
    label.setToolTip(flow_tip)
    label.setFont(self.medfont)
    self.segaBoxG.addWidget(label, b0, 0, 1, 2)
    self.flow_threshold = QLineEdit()
    self.flow_threshold.setText("0.4")
    self.flow_threshold.returnPressed.connect(self.compute_cprob)
    self.flow_threshold.setFixedWidth(40)
    self.flow_threshold.setFont(self.medfont)
    self.segaBoxG.addWidget(self.flow_threshold, b0, 2, 1, 2)
    self.flow_threshold.setToolTip(flow_tip)

    cellprob_tip = "threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run"
    label = QLabel("cellprob\nthreshold:")
    label.setToolTip(cellprob_tip)
    label.setFont(self.medfont)
    self.segaBoxG.addWidget(label, b0, 4, 1, 2)
    self.cellprob_threshold = QLineEdit()
    self.cellprob_threshold.setText("0.0")
    self.cellprob_threshold.returnPressed.connect(self.compute_cprob)
    self.cellprob_threshold.setFixedWidth(40)
    self.cellprob_threshold.setFont(self.medfont)
    self.cellprob_threshold.setToolTip(cellprob_tip)
    self.segaBoxG.addWidget(self.cellprob_threshold, b0, 6, 1, 2)

    b0 += 1
    label = QLabel("norm percentiles:")
    label.setToolTip(
        "sets normalization percentiles for segmentation and denoising\n(pixels at lower percentile set to 0.0 and at upper set to 1.0 for network)"
    )
    label.setFont(self.medfont)
    self.segaBoxG.addWidget(label, b0, 0, 1, 8)

    b0 += 1
    self.norm_vals = [1., 99.]
    self.norm_edits = []
    labels = ["lower", "upper"]
    tooltips = [
        "pixels at this percentile set to 0 (default 1.0)",
        "pixels at this percentile set to 1 (default 99.0)"
    ]
    for p in range(2):
        label = QLabel(f"{labels[p]}:")
        label.setToolTip(tooltips[p])
        label.setFont(self.medfont)
        self.segaBoxG.addWidget(label, b0, 4 * (p % 2), 1, 2)
        self.norm_edits.append(QLineEdit())
        self.norm_edits[p].setText(str(self.norm_vals[p]))
        self.norm_edits[p].setFixedWidth(40)
        self.norm_edits[p].setFont(self.medfont)
        self.segaBoxG.addWidget(self.norm_edits[p], b0, 4 * (p % 2) + 2, 1, 2)
        self.norm_edits[p].setToolTip(tooltips[p])

    b0 += 1
    niter_tip = "number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria"
    label = QLabel("niter dynamics:")
    label.setFont(self.medfont)
    label.setToolTip(niter_tip)
    self.segaBoxG.addWidget(label, b0, 0, 1, 4)
    self.niter = QLineEdit()
    self.niter.setText("0")
    self.niter.setFixedWidth(40)
    self.niter.setFont(self.medfont)
    self.niter.setToolTip(niter_tip)
    self.segaBoxG.addWidget(self.niter, b0, 4, 1, 2)

    # --------------------------- Other models --------------------------- #
    b += 1
    b0 = 0
    self.modelBox = QGroupBox("Other models")
    self.modelBoxG = QGridLayout()
    self.modelBox.setLayout(self.modelBoxG)
    self.l0.addWidget(self.modelBox, b, 0, 1, 9)
    self.modelBox.setFont(self.boldfont)

    # choose models; the first entry is a caption, followed by self.model_strings
    self.ModelChooseC = QComboBox()
    self.ModelChooseC.setFont(self.medfont)
    self.ModelChooseC.addItems(["custom models"])
    if len(self.model_strings) > 0:
        self.ModelChooseC.addItems(self.model_strings)
    self.ModelChooseC.setFixedWidth(175)
    self.ModelChooseC.setCurrentIndex(0)
    tipstr = 'add or train your own models in the "Models" file menu and choose model here'
    self.ModelChooseC.setToolTip(tipstr)
    self.ModelChooseC.activated.connect(lambda: self.model_choose(custom=True))
    self.modelBoxG.addWidget(self.ModelChooseC, b0, 0, 1, 8)

    # compute segmentation w/ custom model
    self.ModelButtonC = QPushButton("run")
    self.ModelButtonC.setFont(self.medfont)
    self.ModelButtonC.setFixedWidth(35)
    self.ModelButtonC.clicked.connect(
        lambda: self.compute_segmentation(custom=True))
    self.modelBoxG.addWidget(self.ModelButtonC, b0, 8, 1, 1)
    self.ModelButtonC.setEnabled(False)

    self.net_names = [
        "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3",
        "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3",
        "cyto", "cyto2", "CPx"]
    # display strings shown in the combo box, listed in the same order as net_names
    nett = [
        "nuclei", "cellpose (cyto2_cp3)", "tissuenet_cp3", "livecell_cp3",
        "yeast_PhC_cp3", "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3",
        "deepbacs_cp3", "cyto", "cyto2",
        "CPx (from Cellpose2)"
    ]

    b0 += 1
    self.ModelChooseB = QComboBox()
    self.ModelChooseB.setFont(self.medfont)
    self.ModelChooseB.addItems(["dataset-specific models"])
    self.ModelChooseB.addItems(nett)
    self.ModelChooseB.setFixedWidth(175)
    tipstr = "dataset-specific models"
    self.ModelChooseB.setToolTip(tipstr)
    self.ModelChooseB.activated.connect(lambda: self.model_choose(custom=False))
    self.modelBoxG.addWidget(self.ModelChooseB, b0, 0, 1, 8)

    # compute segmentation w/ cp model
    self.ModelButtonB = QPushButton("run")
    self.ModelButtonB.setFont(self.medfont)
    self.ModelButtonB.setFixedWidth(35)
    self.ModelButtonB.clicked.connect(
        lambda: self.compute_segmentation(custom=False))
    self.modelBoxG.addWidget(self.ModelButtonB, b0, 8, 1, 1)
    self.ModelButtonB.setEnabled(False)

    # ------------------------- Image restoration ------------------------- #
    b += 1
    self.denoiseBox = QGroupBox("Image restoration")
    self.denoiseBox.setFont(self.boldfont)
    self.denoiseBoxG = QGridLayout()
    self.denoiseBox.setLayout(self.denoiseBoxG)
    self.l0.addWidget(self.denoiseBox, b, 0, 1, 9)

    b0 = 0
    # DENOISING
    self.DenoiseButtons = []
    nett = [
        "clear restore/filter",
        "filter image (settings below)",
        "denoise (please set cell diameter first)",
        "deblur (please set cell diameter first)",
        "upsample to 30. diameter (cyto3) or 17. diameter (nuclei) (please set cell diameter first) (disabled in 3D)",
        "one-click model trained to denoise+deblur+upsample (please set cell diameter first)"
    ]
    self.denoise_text = ["none", "filter", "denoise", "deblur", "upsample", "one-click"]
    self.restore = None
    self.ratio = 1.
    jj = 0
    w = 3
    # buttons are laid out two per row, each spanning w columns
    for j, (text, tip) in enumerate(zip(self.denoise_text, nett)):
        self.DenoiseButtons.append(guiparts.DenoiseButton(self, text))
        self.denoiseBoxG.addWidget(self.DenoiseButtons[-1], b0, jj, 1, w)
        self.DenoiseButtons[-1].setFixedWidth(75)
        self.DenoiseButtons[-1].setToolTip(tip)
        self.DenoiseButtons[-1].setFont(self.medfont)
        if j % 2 == 1:
            # second button of the row: move to the start of the next row
            b0 += 1
            jj = 0
        else:
            jj += w

    # created and checked, but not added to any layout in this method
    self.save_norm = QCheckBox("save restored/filtered image")
    self.save_norm.setFont(self.medfont)
    self.save_norm.setToolTip("save restored/filtered image in _seg.npy file")
    self.save_norm.setChecked(True)

    # step back to the first button row; the dataset chooser sits to its right
    b0 -= 3
    label = QLabel("restore-dataset:")
    label.setToolTip(
        "choose dataset and click [denoise], [deblur], [upsample], or [one-click]")
    label.setFont(self.medfont)
    self.denoiseBoxG.addWidget(label, b0, 6, 1, 3)

    b0 += 1
    self.DenoiseChoose = QComboBox()
    self.DenoiseChoose.setFont(self.medfont)
    self.DenoiseChoose.addItems(["cyto3", "cyto2", "nuclei"])
    self.DenoiseChoose.setFixedWidth(85)
    tipstr = "choose model type and click [denoise], [deblur], or [upsample]"
    self.DenoiseChoose.setToolTip(tipstr)
    self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 6, 1, 3)

    b0 += 2
    # FILTERING: collapsible section; its content starts at height 0
    self.filtBox = QCollapsible("custom filter settings")
    self.filtBox._toggle_btn.setFont(self.medfont)
    self.filtBoxG = QGridLayout()
    _content = QWidget()
    _content.setLayout(self.filtBoxG)
    _content.setMaximumHeight(0)
    _content.setMinimumHeight(0)
    self.filtBox.setContent(_content)
    self.denoiseBoxG.addWidget(self.filtBox, b0, 0, 1, 9)

    self.filt_vals = [0., 0., 0., 0.]
    self.filt_edits = []
    labels = [
        "sharpen\nradius", "smooth\nradius", "tile_norm\nblocksize",
        "tile_norm\nsmooth3D"
    ]
    tooltips = [
        "set size of surround-subtraction filter for sharpening image",
        "set size of gaussian filter for smoothing image",
        "set size of tiles to use to normalize image",
        "set amount of smoothing of normalization values across planes"
    ]
    # four label/edit pairs arranged two per row inside self.filtBoxG
    for p in range(4):
        label = QLabel(f"{labels[p]}:")
        label.setToolTip(tooltips[p])
        label.setFont(self.medfont)
        self.filtBoxG.addWidget(label, b0 + p // 2, 4 * (p % 2), 1, 2)
        self.filt_edits.append(QLineEdit())
        self.filt_edits[p].setText(str(self.filt_vals[p]))
        self.filt_edits[p].setFixedWidth(40)
        self.filt_edits[p].setFont(self.medfont)
        self.filtBoxG.addWidget(self.filt_edits[p], b0 + p // 2, 4 * (p % 2) + 2, 1,
                                2)
        self.filt_edits[p].setToolTip(tooltips[p])

    b0 += 3
    self.norm3D_cb = QCheckBox("norm3D")
    self.norm3D_cb.setFont(self.medfont)
    self.norm3D_cb.setChecked(True)
    self.norm3D_cb.setToolTip("run same normalization across planes")
    self.filtBoxG.addWidget(self.norm3D_cb, b0, 0, 1, 3)

    self.invert_cb = QCheckBox("invert")
    self.invert_cb.setFont(self.medfont)
    self.invert_cb.setToolTip("invert image")
    self.filtBoxG.addWidget(self.invert_cb, b0, 3, 1, 3)

    # ---------------------- spacer and scale toggle ---------------------- #
    b += 1
    # empty label in a row with a large stretch factor
    self.l0.addWidget(QLabel(""), b, 0, 1, 9)
    self.l0.setRowStretch(b, 100)

    b += 1
    # scale toggle
    self.scale_on = True
    self.ScaleOn = QCheckBox("scale disk on")
    self.ScaleOn.setFont(self.medfont)
    self.ScaleOn.setStyleSheet("color: rgb(150,50,150);")
    self.ScaleOn.setChecked(True)
    self.ScaleOn.setToolTip("see current diameter as red disk at bottom")
    self.ScaleOn.toggled.connect(self.toggle_scale)
    self.l0.addWidget(self.ScaleOn, b, 0, 1, 5)
    return b
def level_change(self, r):
    """Store a slider's new range for the current plane and redraw.

    Does nothing beyond the name lookup unless ``self.loaded`` is true. When the
    "auto-adjust saturation" box is unchecked, the current plane's range of each
    of the three channels is copied to every plane of that channel.

    Args:
        r: slider name, one of "red", "green" or "blue".

    Raises:
        ValueError: if ``r`` is not one of the three names.
    """
    channel = ["red", "green", "blue"].index(r)
    if not self.loaded:
        return

    self.saturation[channel][self.currentZ] = self.sliders[channel].value()

    if not self.autobtn.isChecked():
        for chan in range(3):
            for plane in range(len(self.saturation[chan])):
                self.saturation[chan][plane] = self.saturation[chan][self.currentZ]

    self.update_plot()
def keyPressEvent(self, event):
    """
    Handle keyboard shortcuts of the main window.

    Shortcuts other than -/= are only processed when an image is loaded, none
    of Ctrl/Shift/Alt is held and no stroke is in progress. The -/= keys are
    always forwarded to the view box ``self.p0``.

    Args:
        event: Qt key event; only ``event.modifiers()`` and ``event.key()``
            are read here.
    """
    if self.loaded:
        if not (event.modifiers() &
                (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
                 QtCore.Qt.AltModifier) or self.in_stroke):
            # NOTE(review): `updated` is never set to True below, so the plot
            # refresh at the end of this branch always runs
            updated = False
            if len(self.current_point_set) > 0:
                # points of an unfinished cell exist: Return turns them into a mask
                if event.key() == QtCore.Qt.Key_Return:
                    self.add_set()
            else:
                # number of views to cycle through; the last dropdown entry only
                # counts when it is enabled
                nviews = self.ViewDropDown.count() - 1
                nviews += int(
                    self.ViewDropDown.model().item(self.ViewDropDown.count() -
                                                   1).isEnabled())
                # X / Z toggle the two checkboxes MCheckBox / OCheckBox
                if event.key() == QtCore.Qt.Key_X:
                    self.MCheckBox.toggle()
                if event.key() == QtCore.Qt.Key_Z:
                    self.OCheckBox.toggle()
                # Left/A and Right/D step through images, PageDown/PageUp through views
                if event.key() == QtCore.Qt.Key_Left or event.key(
                ) == QtCore.Qt.Key_A:
                    self.get_prev_image()
                elif event.key() == QtCore.Qt.Key_Right or event.key(
                ) == QtCore.Qt.Key_D:
                    self.get_next_image()
                elif event.key() == QtCore.Qt.Key_PageDown:
                    self.view = (self.view + 1) % (nviews)
                    self.ViewDropDown.setCurrentIndex(self.view)
                elif event.key() == QtCore.Qt.Key_PageUp:
                    self.view = (self.view - 1) % (nviews)
                    self.ViewDropDown.setCurrentIndex(self.view)
            # can change background or stroke size if cell not finished
            if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
                # Up/W and Down/S cycle through the 6 entries of the color dropdown
                self.color = (self.color - 1) % (6)
                self.RGBDropDown.setCurrentIndex(self.color)
            elif event.key() == QtCore.Qt.Key_Down or event.key(
            ) == QtCore.Qt.Key_S:
                self.color = (self.color + 1) % (6)
                self.RGBDropDown.setCurrentIndex(self.color)
            elif event.key() == QtCore.Qt.Key_R:
                # R / G / B select color index 1 / 2 / 3; pressing the key of the
                # index already active switches back to index 0
                if self.color != 1:
                    self.color = 1
                else:
                    self.color = 0
                self.RGBDropDown.setCurrentIndex(self.color)
            elif event.key() == QtCore.Qt.Key_G:
                if self.color != 2:
                    self.color = 2
                else:
                    self.color = 0
                self.RGBDropDown.setCurrentIndex(self.color)
            elif event.key() == QtCore.Qt.Key_B:
                if self.color != 3:
                    self.color = 3
                else:
                    self.color = 0
                self.RGBDropDown.setCurrentIndex(self.color)
            elif (event.key() == QtCore.Qt.Key_Comma or
                  event.key() == QtCore.Qt.Key_Period):
                # comma / period move the brush dropdown down / up by one entry,
                # clamped to the valid index range
                count = self.BrushChoose.count()
                gci = self.BrushChoose.currentIndex()
                if event.key() == QtCore.Qt.Key_Comma:
                    gci = max(0, gci - 1)
                else:
                    gci = min(count - 1, gci + 1)
                self.BrushChoose.setCurrentIndex(gci)
                self.brush_choose()
            if not updated:
                self.update_plot()
    # -/= are handled by the view box, independent of the load state
    if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
        self.p0.keyPressEvent(event)
def autosave_on(self):
    """ copy the state of the autosave checkbox (SCheckBox) into self.autosave """
    self.autosave = bool(self.SCheckBox.isChecked())
def check_gpu(self, torch=True):
    """
    Enable and tick the GPU checkbox when ``core.use_gpu`` reports a GPU.

    Args:
        torch (bool): not read by this method; ``core.use_gpu`` is always
            called with ``use_torch=True``.

    Without a GPU the checkbox stays unchecked and disabled and is greyed out.
    """
    gpu_box = self.useGPU
    gpu_box.setChecked(False)
    gpu_box.setEnabled(False)
    if not core.use_gpu(use_torch=True):
        gpu_box.setStyleSheet("color: rgb(80,80,80);")
        return
    gpu_box.setEnabled(True)
    gpu_box.setChecked(True)
def get_channels(self):
    """
    Read the two channel dropdowns and return a channel pair valid for the image.

    Returns:
        list: ``[first, second]`` channel indices after these adjustments:
        the second channel is 0 for the "nuclei" model or when the first
        channel is 0; single-channel images give ``[0, 0]``; for two-channel
        images index 3 is replaced by 1 or 2 (whichever the other dropdown does
        not use). The dropdowns are set to the returned values.
    """
    warning = "GUI_WARNING: only two channels in image, cannot use blue channel, changing channels"
    first = self.ChannelChoose[0].currentIndex()
    second = self.ChannelChoose[1].currentIndex()
    if getattr(self, "current_model", None) == "nuclei":
        second = 0
    if first == 0:
        second = 0
    if self.nchan == 1:
        first, second = 0, 0
    elif self.nchan == 2:
        if first == 3:
            first = 2 if second == 1 else 1
            print(warning)
        if second == 3:
            second = 2 if first == 1 else 1
            print(warning)
    self.ChannelChoose[0].setCurrentIndex(first)
    self.ChannelChoose[1].setCurrentIndex(second)
    return [first, second]
def model_choose(self, custom=False):
    """
    Load the model picked in a model dropdown and adopt its diameter.

    Args:
        custom (bool): read the custom-model dropdown (ModelChooseC) instead
            of the built-in one (ModelChooseB).

    Index 0 of either dropdown is ignored. Otherwise the model is initialized
    and ``self.diameter`` / the Diameter text box are set from
    ``self.model.diam_labels``.
    """
    chooser = self.ModelChooseC if custom else self.ModelChooseB
    index = chooser.currentIndex()
    if not index > 0:
        return
    # custom models are named by the dropdown text, built-in ones by net_names
    model_name = chooser.currentText() if custom else self.net_names[index - 1]
    print(f"GUI_INFO: selected model {model_name}, loading now")
    self.initialize_model(model_name=model_name, custom=custom)
    self.diameter = self.model.diam_labels
    self.Diameter.setText("%0.2f" % self.diameter)
    print(
        f"GUI_INFO: diameter set to {self.diameter: 0.2f} (but can be changed)")
def calibrate_size(self):
    """
    Estimate the cell diameter on the current plane with the "cyto3" model.

    The estimate from ``self.model.sz.eval`` is floored at 5.0, logged, written
    to the Diameter text box and ``self.diameter``; then the scale disk is
    redrawn and the progress bar set to 100.
    """
    self.initialize_model(model_name="cyto3")
    size_model = self.model.sz
    plane = self.stack[self.currentZ].copy()
    diams, _ = size_model.eval(plane, channels=self.get_channels(),
                               progress=self.progress)
    # never report a diameter below 5 pixels
    diams = np.maximum(5.0, diams)
    message = "estimated diameter of cells using %s model = %0.1f pixels" % (
        self.current_model, diams)
    self.logger.info(message)
    self.Diameter.setText("%0.1f" % diams)
    self.diameter = diams
    self.update_scale()
    self.progress.setValue(100)
def toggle_scale(self):
    """ show or hide the scale disk item in the view box and flip self.scale_on """
    if not self.scale_on:
        self.p0.addItem(self.scale)
    else:
        self.p0.removeItem(self.scale)
    self.scale_on = not self.scale_on
def enable_buttons(self):
    """
    Enable the widgets that need a loaded image, then refresh the display.

    The custom-model button is only enabled when model strings exist, and the
    second-to-last denoise button is disabled again for 3D data. Finishes by
    updating the mask controls, the plot and the window title (the filename).
    """
    if len(self.model_strings) > 0:
        self.ModelButtonC.setEnabled(True)
    for button in self.StyleButtons:
        button.setEnabled(True)
    for button in self.DenoiseButtons:
        button.setEnabled(True)
    if self.load_3D:
        self.DenoiseButtons[-2].setEnabled(False)
    for widget in (self.ModelButtonB, self.SizeButton, self.newmodel,
                   self.loadMasks):
        widget.setEnabled(True)
    # sliders of present channels first, then those of the remaining channels
    for n in [*range(self.nchan), *range(self.nchan, 3)]:
        self.sliders[n].setEnabled(True)
    self.toggle_mask_ops()
    self.update_plot()
    self.setWindowTitle(self.filename)
def disable_buttons_removeROIs(self):
    """
    Put the controls into multi-ROI deletion mode.

    Model, size, loading, saving and deletion-start controls are disabled;
    only the "done" and "cancel" buttons of the multi-deletion are enabled.
    """
    if len(self.model_strings) > 0:
        self.ModelButtonC.setEnabled(False)
    for button in self.StyleButtons:
        button.setEnabled(False)
    disabled = (
        self.ModelButtonB, self.SizeButton, self.newmodel, self.loadMasks,
        self.saveSet, self.savePNG, self.saveFlows, self.saveOutlines,
        self.saveROIs, self.MakeDeletionRegionButton,
        self.DeleteMultipleROIButton,
    )
    for widget in disabled:
        widget.setEnabled(False)
    self.DoneDeleteMultipleROIButton.setEnabled(True)
    self.CancelDeleteMultipleROIButton.setEnabled(True)
def toggle_mask_ops(self):
    """ refresh the mask layer, then the saving and the removal controls """
    for refresh in (self.update_layer, self.toggle_saving, self.toggle_removals):
        refresh()
def toggle_saving(self):
    """ enable the five save actions only when at least one cell exists """
    has_cells = bool(self.ncells > 0)
    save_actions = (self.saveSet, self.savePNG, self.saveFlows,
                    self.saveOutlines, self.saveROIs)
    for action in save_actions:
        action.setEnabled(has_cells)
def toggle_removals(self):
    """
    Enable the removal controls only when at least one cell exists.

    The "done" and "cancel" buttons of the multi-deletion are disabled in
    either case.
    """
    has_cells = bool(self.ncells > 0)
    for widget in (self.ClearButton, self.remcell, self.undo,
                   self.MakeDeletionRegionButton, self.DeleteMultipleROIButton):
        widget.setEnabled(has_cells)
    self.DoneDeleteMultipleROIButton.setEnabled(False)
    self.CancelDeleteMultipleROIButton.setEnabled(False)
def remove_action(self):
    """ remove the currently selected cell; no-op when nothing is selected """
    if not self.selected > 0:
        return
    self.remove_cell(self.selected)
def undo_action(self):
    """
    Undo the last drawing step.

    Removes the last stroke when it lies on the current plane; otherwise
    removes the cell with the highest index, if any cell exists.
    """
    last_stroke_here = (len(self.strokes) > 0 and
                        self.strokes[-1][0][0] == self.currentZ)
    if last_stroke_here:
        self.remove_stroke()
    elif self.ncells > 0:
        # remove previous cell
        self.remove_cell(self.ncells)
def undo_remove_action(self):
    """ action handler that delegates to undo_remove_cell """
    restore_cell = self.undo_remove_cell
    restore_cell()
def get_files(self):
    """
    List the images in the folder of the current file.

    Returns:
        tuple: ``(images, idx)`` where ``images`` is the result of
        ``get_image_files`` for the folder (called with the "_masks" filter)
        and ``idx`` is the position of the current file in it, matched by
        base name.
    """
    folder = os.path.dirname(self.filename)
    images = get_image_files(folder, "_masks")
    basenames = [os.path.split(path)[-1] for path in images]
    current = os.path.split(self.filename)[-1]
    # first position whose base name equals the current file's base name
    idx = np.nonzero(np.array(basenames) == current)[0][0]
    return images, idx
def get_prev_image(self):
    """ load the image before the current one in its folder, wrapping around """
    images, current = self.get_files()
    previous = (current - 1) % len(images)
    io._load_image(self, filename=images[previous])
def get_next_image(self, load_seg=True):
    """
    Load the image after the current one in its folder, wrapping around.

    Args:
        load_seg (bool): forwarded unchanged to ``io._load_image``.
    """
    images, current = self.get_files()
    following = (current + 1) % len(images)
    io._load_image(self, filename=images[following], load_seg=load_seg)
def dragEnterEvent(self, event):
    """ accept a drag only when its mime data carries URLs """
    respond = event.accept if event.mimeData().hasUrls() else event.ignore
    respond()
def dropEvent(self, event):
    """
    Load the first dropped file.

    A file with the ".npy" extension is loaded as a segmentation, anything
    else as an image (together with its segmentation).
    """
    dropped = [url.toLocalFile() for url in event.mimeData().urls()]
    first = dropped[0]
    extension = os.path.splitext(first)[-1]
    if extension != ".npy":
        io._load_image(self, filename=first, load_seg=True, load_3D=self.load_3D)
    else:
        io._load_seg(self, filename=first, load_3D=self.load_3D)
def toggle_masks(self):
    """
    Apply the state of the masks (MCheckBox) and outlines (OCheckBox) boxes.

    With both boxes unchecked the overlay layer is removed from the view box
    and ``self.layer_off`` is set. Otherwise the layer is added back if it had
    been removed, and redrawn. When an image is loaded the plot and the layer
    are refreshed afterwards.
    """
    self.masksOn = bool(self.MCheckBox.isChecked())
    self.outlinesOn = bool(self.OCheckBox.isChecked())
    if not self.masksOn and not self.outlinesOn:
        self.p0.removeItem(self.layer)
        self.layer_off = True
    else:
        if self.layer_off:
            self.p0.addItem(self.layer)
            # the layer is attached again: clear the flag, otherwise every later
            # toggle would add the same item to the view box once more
            self.layer_off = False
        self.draw_layer()
        self.update_layer()
    if self.loaded:
        self.update_plot()
        self.update_layer()
def make_viewbox(self):
    """
    Create the main view box and the three image items drawn in it.

    Sets ``self.p0`` (view box), ``self.img`` (image), ``self.layer`` (drawable
    overlay), ``self.scale`` (scale item), the initial ``self.brush_size`` and
    the initial ``self.Ly`` / ``self.Lx``.
    """
    self.p0 = guiparts.ViewBoxNoRightDrag(parent=self, lockAspect=True,
                                          name="plot1", border=[100, 100,
                                                                100], invertY=True)
    self.p0.setCursor(QtCore.Qt.CrossCursor)
    self.brush_size = 3
    # the view box occupies cell (0, 0) of the graphics layout self.win
    self.win.addItem(self.p0, 0, 0, rowspan=1, colspan=1)
    self.p0.setMenuEnabled(False)
    self.p0.setMouseEnabled(x=True, y=True)
    self.img = pg.ImageItem(viewbox=self.p0, parent=self)
    self.img.autoDownsample = False
    self.layer = guiparts.ImageDraw(viewbox=self.p0, parent=self)
    self.layer.setLevels([0, 255])
    self.scale = pg.ImageItem(viewbox=self.p0, parent=self)
    self.scale.setLevels([0, 255])
    self.p0.scene().contextMenuItem = self.p0
    #self.p0.setMouseEnabled(x=False,y=False)
    self.Ly, self.Lx = 512, 512
    # items are added in the order image, overlay layer, scale item
    self.p0.addItem(self.img)
    self.p0.addItem(self.layer)
    self.p0.addItem(self.scale)
def reset(self):
    """
    Reset the GUI state to an empty session without a loaded image.

    Re-initializes selection, strokes, cell bookkeeping, the saturation levels
    and sliders, a blank single-plane 224 x 224 stack and its mask arrays, the
    color/view dropdowns and the multi-deletion state. Reads ``self.restore``,
    so that attribute must exist before this is called.
    """
    # ---- start sets of points ---- #
    self.selected = 0
    self.nchan = 3
    self.loaded = False
    self.channel = [0, 1]
    self.current_point_set = []
    self.in_stroke = False
    self.strokes = []
    self.stroke_appended = True
    self.resize = False
    self.ncells = 0
    self.zdraw = []
    self.removed_cell = []
    # row 0 of cellcolors is the color looked up for label 0
    self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
    # -- zero out image stack -- #
    self.opacity = 128  # how opaque masks should be
    self.outcolor = [200, 200, 255, 200]
    self.NZ, self.Ly, self.Lx = 1, 224, 224
    # one [low, high] pair per plane for each of the three channels
    self.saturation = []
    for r in range(3):
        self.saturation.append([[0, 255] for n in range(self.NZ)])
        self.sliders[r].setValue([0, 255])
        self.sliders[r].setEnabled(False)
        self.sliders[r].show()
    self.currentZ = 0
    self.flows = [[], [], [], [], [[]]]
    # masks matrix
    # image matrix with a scale disk
    self.stack = np.zeros((1, self.Ly, self.Lx, 3))
    self.Lyr, self.Lxr = self.Ly, self.Lx
    self.Ly0, self.Lx0 = self.Ly, self.Lx
    self.radii = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
    self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
    self.cellpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
    self.outpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
    if self.restore and "upsample" in self.restore:
        # NOTE(review): all four attributes alias self.cellpix (the outpix_* ones
        # too); they are consumed by delete_restore() below, which copies and then
        # deletes them
        self.cellpix_resize = self.cellpix
        self.cellpix_orig = self.cellpix
        self.outpix_resize = self.cellpix
        self.outpix_orig = self.cellpix
    self.ismanual = np.zeros(0, "bool")
    # -- set menus to default -- #
    self.color = 0
    self.RGBDropDown.setCurrentIndex(self.color)
    self.view = 0
    self.ViewDropDown.setCurrentIndex(0)
    # the last view entry stays disabled until something enables it again
    self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
    self.delete_restore()
    self.clear_all()
    #self.update_plot()
    self.filename = []
    self.loaded = False
    self.recompute_masks = False
    # state of the multi-cell deletion workflow
    self.deleting_multiple = False
    self.removing_cells_list = []
    self.removing_region = False
    self.remove_roi_obj = None
def delete_restore(self):
    """ delete restored imgs but don't reset settings

    When original-size copies of the masks exist, cellpix / outpix are set to
    copies of them before the *_orig and *_resize attributes are deleted.
    """
    if hasattr(self, "stack_filtered"):
        del self.stack_filtered
    if not hasattr(self, "cellpix_orig"):
        return
    self.cellpix = self.cellpix_orig.copy()
    self.outpix = self.outpix_orig.copy()
    for name in ("outpix_orig", "outpix_resize", "cellpix_orig", "cellpix_resize"):
        delattr(self, name)
def clear_restore(self):
    """ delete restored imgs and reset settings

    Disables the last entry of the view dropdown (switching to view 0 if it was
    selected), clears ``self.restore``, sets ``self.ratio`` to 1 and re-applies
    the normalization settings.
    """
    print("GUI_INFO: clearing restored image")
    views = self.ViewDropDown
    last_view = views.count() - 1
    views.model().item(last_view).setEnabled(False)
    if views.currentIndex() == last_view:
        views.setCurrentIndex(0)
    self.delete_restore()
    self.restore = None
    self.ratio = 1.
    params = self.get_normalize_params()
    self.set_normalize_params(params)
def brush_choose(self):
    """ set the brush size to 2 * dropdown index + 1 and apply it to the layer """
    self.brush_size = 2 * self.BrushChoose.currentIndex() + 1
    if not self.loaded:
        return
    self.layer.setDrawKernel(kernel_size=self.brush_size)
    self.update_layer()
def clear_all(self):
    """
    Remove all cells and blank the mask arrays.

    With an "upsample" restore active the working arrays use the resized shape
    (Lyr, Lxr) and separate *_resize / *_orig arrays are allocated; otherwise
    the arrays use (Ly, Lx). Afterwards the removal controls, the scale disk
    and the overlay layer are refreshed.
    """
    self.prev_selected = 0
    self.selected = 0

    def blank_labels(ly, lx):
        return np.zeros((self.NZ, ly, lx), np.uint16)

    if self.restore and "upsample" in self.restore:
        self.layerz = np.zeros((self.Lyr, self.Lxr, 4), np.uint8)
        self.cellpix = blank_labels(self.Lyr, self.Lxr)
        self.outpix = blank_labels(self.Lyr, self.Lxr)
        self.cellpix_resize = self.cellpix.copy()
        self.outpix_resize = self.outpix.copy()
        self.cellpix_orig = blank_labels(self.Ly0, self.Lx0)
        self.outpix_orig = blank_labels(self.Ly0, self.Lx0)
    else:
        self.layerz = np.zeros((self.Ly, self.Lx, 4), np.uint8)
        self.cellpix = blank_labels(self.Ly, self.Lx)
        self.outpix = blank_labels(self.Ly, self.Lx)
    # only the color row for label 0 remains
    self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
    self.ncells = 0
    self.toggle_removals()
    self.update_scale()
    self.update_layer()
def select_cell(self, idx):
    """
    Mark cell `idx` as selected and paint it white in the overlay.

    The previous selection is remembered in ``self.prev_selected``. For
    ``idx <= 0`` only the selection state changes.
    """
    self.prev_selected, self.selected = self.selected, idx
    if self.selected > 0:
        in_cell = self.cellpix[self.currentZ] == idx
        self.layerz[in_cell] = np.array([255, 255, 255, self.opacity])
        self.update_layer()
def select_cell_multi(self, idx):
    """ paint cell `idx` white in the overlay without changing self.selected """
    if not idx > 0:
        return
    in_cell = self.cellpix[self.currentZ] == idx
    self.layerz[in_cell] = np.array([255, 255, 255, self.opacity])
    self.update_layer()
def unselect_cell(self):
    """
    Clear the current selection.

    When the selected index refers to an existing cell, its pixels on the
    current plane get their own color back (and the outline color when
    outlines are shown) before ``self.selected`` is reset to 0.
    """
    idx = self.selected
    if idx > 0 and idx < self.ncells + 1:
        z = self.currentZ
        own_color = np.append(self.cellcolors[idx], self.opacity)
        self.layerz[self.cellpix[z] == idx] = own_color
        if self.outlinesOn:
            outline_color = np.array(self.outcolor).astype(np.uint8)
            self.layerz[self.outpix[z] == idx] = outline_color
        self.update_layer()
    self.selected = 0
def unselect_cell_multi(self, idx):
    """
    Repaint cell `idx` on the current plane with its own color.

    The outline color is drawn on top when outlines are shown;
    ``self.selected`` is left untouched.
    """
    plane = self.currentZ
    own_color = np.append(self.cellcolors[idx], self.opacity)
    self.layerz[self.cellpix[plane] == idx] = own_color
    if self.outlinesOn:
        outline_color = np.array(self.outcolor).astype(np.uint8)
        self.layerz[self.outpix[plane] == idx] = outline_color
    self.update_layer()
def remove_cell(self, idx):
    """
    Remove one or several cells.

    Args:
        idx: a single cell index (int or numpy integer) or an iterable of
            cell indices (list, tuple or array). The caller's container is
            not modified.

    Cells are removed from the highest index to the lowest. Afterwards
    ``self.ncells`` is reduced, the clear button is disabled when no cell is
    left, single-plane data are saved via ``io._save_sets_with_check`` and the
    overlay layer is refreshed.
    """
    if isinstance(idx, (int, np.integer)):
        idx = [idx]
    # because the function remove_single_cell updates the state of the cellpix and outpix arrays
    # by reindexing cells to avoid gaps in the indices, we need to remove the cells in reverse order
    # so that the indices are correct; sorted() works on a copy so the caller's
    # container keeps its order, and it accepts tuples and arrays as well
    ordered = sorted(idx, reverse=True)
    for cell in ordered:
        self.remove_single_cell(cell)
    self.ncells -= len(ordered)  # _save_sets uses ncells
    if self.ncells == 0:
        self.ClearButton.setEnabled(False)
    if self.NZ == 1:
        io._save_sets_with_check(self)
    self.update_layer()
def remove_single_cell(self, idx):
    """
    Delete cell `idx` from the label arrays and the per-cell lists.

    Labels above `idx` are shifted down by one so the numbering stays gap-free.
    For single-plane data the removed cell is remembered in
    ``self.removed_cell`` (so it can be restored), the redo action is enabled
    and the change is appended to ``self.track_changes``. ``self.ncells`` is
    not changed here.
    """
    # remove from manual array
    self.selected = 0
    if self.NZ > 1:
        # planes that contain at least one pixel of this cell
        planes = ((self.cellpix == idx).sum(axis=(1, 2)) > 0).nonzero()[0]
    else:
        planes = [0]
    for z in planes:
        in_cell = self.cellpix[z] == idx
        on_outline = self.outpix[z] == idx
        # remove from self.cellpix and self.outpix
        self.cellpix[z, in_cell] = 0
        self.outpix[z, on_outline] = 0
        if z == self.currentZ:
            # remove from mask layer
            self.layerz[in_cell] = np.array([0, 0, 0, 0])
    # reduce other pixels by -1
    self.cellpix[self.cellpix > idx] -= 1
    self.outpix[self.outpix > idx] -= 1
    if self.NZ == 1:
        self.removed_cell = [
            self.ismanual[idx - 1], self.cellcolors[idx],
            np.nonzero(in_cell),
            np.nonzero(on_outline)
        ]
        self.redo.setEnabled(True)
        ar, ac = self.removed_cell[2]
        stamp = datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
        self.track_changes.append([stamp, "removed mask", [ar, ac]])
    # remove cell from lists
    self.ismanual = np.delete(self.ismanual, idx - 1)
    self.cellcolors = np.delete(self.cellcolors, [idx], axis=0)
    del self.zdraw[idx - 1]
    print("GUI_INFO: removed cell %d" % (idx - 1))
def remove_region_cells(self):
    """
    Start selecting cells for deletion with a rectangular ROI.

    Any earlier multi-selection is cleared, the controls are switched to
    deletion mode and a removable rectangle of half the view size is placed in
    the center of the current view. The cells under it are selected right away.
    """
    if self.removing_cells_list:
        for cell in self.removing_cells_list:
            self.unselect_cell_multi(cell)
        self.removing_cells_list.clear()
    self.disable_buttons_removeROIs()
    self.removing_region = True
    self.clear_multi_selected_cells()
    # make roi region here in center of view, making ROI half the size of the view
    view = self.p0.viewRect()
    roi_width = view.width() / 2
    roi_height = view.height() / 2
    pos = [view.x() + (roi_width / 2), view.y() + (roi_height / 2)]
    roi = pg.RectROI(pos, [roi_width, roi_height], pen=pg.mkPen("y", width=2),
                     removable=True)
    roi.sigRemoveRequested.connect(self.remove_roi)
    roi.sigRegionChangeFinished.connect(self.roi_changed)
    self.p0.addItem(roi)
    self.remove_roi_obj = roi
    self.roi_changed(roi)
def delete_multiple_cells(self):
    """
    Enter multi-cell deletion mode.

    The current selection is cleared and the controls are switched to deletion
    mode, with the "done", "make region" and "cancel" buttons enabled.
    """
    self.unselect_cell()
    self.disable_buttons_removeROIs()
    for button in (self.DoneDeleteMultipleROIButton,
                   self.MakeDeletionRegionButton,
                   self.CancelDeleteMultipleROIButton):
        button.setEnabled(True)
    self.deleting_multiple = True
def done_remove_multiple_cells(self):
    """
    Leave multi-cell deletion mode and delete the collected cells.

    Duplicate entries of ``self.removing_cells_list`` are dropped before the
    cells are removed. The regular controls are re-enabled and a deletion ROI
    that is still present is removed.
    """
    self.deleting_multiple = False
    self.removing_region = False
    for button in (self.DoneDeleteMultipleROIButton,
                   self.MakeDeletionRegionButton,
                   self.CancelDeleteMultipleROIButton):
        button.setEnabled(False)
    if self.removing_cells_list:
        unique_cells = list(set(self.removing_cells_list))
        self.removing_cells_list = unique_cells
        # cells are reported with 0-based numbers
        display_remove_list = [cell - 1 for cell in unique_cells]
        print(f"GUI_INFO: removing cells: {display_remove_list}")
        self.remove_cell(unique_cells)
        self.removing_cells_list.clear()
        self.unselect_cell()
    self.enable_buttons()
    if self.remove_roi_obj is not None:
        self.remove_roi(self.remove_roi_obj)
def merge_cells(self, idx):
    """
    Merge cell `idx` into the previously selected cell.

    Args:
        idx: label of the cell that was just clicked; it becomes
            ``self.selected`` and is removed after the merge.

    On every plane the pixels of both cells are redrawn under the label and
    color of the previously selected cell. When the two cells touch on a plane
    (some pair of pixels is less than 3 apart in both row and column) the
    outline is recomputed as the external contour of the union; otherwise the
    two existing outlines are kept. Nothing happens when `idx` equals the
    current selection.
    """
    self.prev_selected = self.selected
    self.selected = idx
    if self.selected != self.prev_selected:
        for z in range(self.NZ):
            ar0, ac0 = np.nonzero(self.cellpix[z] == self.prev_selected)
            ar1, ac1 = np.nonzero(self.cellpix[z] == self.selected)
            # pairwise pixel distances must be absolute: a signed difference
            # is "< 3" for every pixel of cell 0 lying above/left of cell 1, no
            # matter how far apart the cells are
            touching = np.logical_and(
                np.abs(ar0[:, np.newaxis] - ar1) < 3,
                np.abs(ac0[:, np.newaxis] - ac1) < 3).sum()
            ar = np.hstack((ar0, ar1))
            ac = np.hstack((ac0, ac1))
            vr0, vc0 = np.nonzero(self.outpix[z] == self.prev_selected)
            vr1, vc1 = np.nonzero(self.outpix[z] == self.selected)
            # clear both old outlines; draw_mask writes the merged one
            self.outpix[z, vr0, vc0] = 0
            self.outpix[z, vr1, vc1] = 0
            if touching > 0:
                # rasterize the union with a 2-pixel margin and trace its contour
                mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8)
                mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
                contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                            cv2.CHAIN_APPROX_NONE)
                # contours[-2] is the contour list for both the 2- and the
                # 3-tuple return convention of cv2.findContours
                pvc, pvr = contours[-2][0].squeeze().T
                vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
            else:
                vr = np.hstack((vr0, vr1))
                vc = np.hstack((vc0, vc1))
            color = self.cellcolors[self.prev_selected]
            self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected)
        self.remove_cell(self.selected)
        print("GUI_INFO: merged two cells")
        self.update_layer()
        io._save_sets_with_check(self)
        self.undo.setEnabled(False)
        self.redo.setEnabled(False)
def undo_remove_cell(self):
    """
    Restore the cell stored in ``self.removed_cell``.

    The cell is redrawn on plane 0 as a new cell at the end of the numbering,
    the per-cell lists are extended, the result is saved and the stored cell
    is cleared. No-op when nothing is stored.
    """
    if not len(self.removed_cell) > 0:
        return
    was_manual, color, (ar, ac), (vr, vc) = self.removed_cell[:4]
    self.draw_mask(0, ar, ac, vr, vc, color)
    self.toggle_mask_ops()
    self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], axis=0)
    self.ncells += 1
    self.ismanual = np.append(self.ismanual, self.removed_cell[0])
    self.zdraw.append([])
    print(">>> added back removed cell")
    self.update_layer()
    io._save_sets_with_check(self)
    self.removed_cell = []
    self.redo.setEnabled(False)
def remove_stroke(self, delete_points=True, stroke_ind=-1):
    """
    Delete one stroke and restore the overlay pixels it covered.

    Args:
        delete_points (bool): also delete the entry at `stroke_ind` from
            ``self.current_point_set`` (only done when the stroke lies on the
            current plane).
        stroke_ind (int): index into ``self.strokes``; the default removes
            the last stroke.

    Column 0 of a stroke is compared with the current plane, columns 1 and 2
    are used as row and column indices into the overlay.
    """
    stroke = np.array(self.strokes[stroke_ind])
    cZ = self.currentZ
    # the plane of the first point decides whether the stroke is on display
    inZ = stroke[0, 0] == cZ
    if inZ:
        # stroke points that lie on an existing outline
        outpix = self.outpix[cZ, stroke[:, 1], stroke[:, 2]] > 0
        self.layerz[stroke[~outpix, 1], stroke[~outpix, 2]] = np.array([0, 0, 0, 0])
        # labels under the stroke; label 0 maps to row 0 of cellcolors
        cellpix = self.cellpix[cZ, stroke[:, 1], stroke[:, 2]]
        ccol = self.cellcolors.copy()
        if self.selected > 0:
            # the selected cell is shown in white
            ccol[self.selected] = np.array([255, 255, 255])
        col2mask = ccol[cellpix]
        # append the alpha column: opacity on cells when masks are shown, else 0
        if self.masksOn:
            col2mask = np.concatenate(
                (col2mask, self.opacity * (cellpix[:, np.newaxis] > 0)), axis=-1)
        else:
            col2mask = np.concatenate((col2mask, 0 * (cellpix[:, np.newaxis] > 0)),
                                      axis=-1)
        self.layerz[stroke[:, 1], stroke[:, 2], :] = col2mask
        if self.outlinesOn:
            self.layerz[stroke[outpix, 1], stroke[outpix,
                                                  2]] = np.array(self.outcolor)
        if delete_points:
            # self.current_point_set = self.current_point_set[:-1*(stroke[:,-1]==1).sum()]
            del self.current_point_set[stroke_ind]
        self.update_layer()
    # the stroke itself is dropped whether or not it was on the current plane
    del self.strokes[stroke_ind]
def plot_clicked(self, event):
    """
    Reset the view to the full image on a left double-click.

    Clicks with Shift or Alt held, and clicks while a deletion region is
    active, are ignored. The vertical range includes ``self.pr`` extra pixels
    when that can be applied, and falls back to the plain image height
    otherwise.
    """
    if event.button()==QtCore.Qt.LeftButton \
            and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\
            and not self.removing_region:
        if event.double():
            try:
                self.p0.setYRange(0, self.Ly + self.pr)
            except Exception:
                # best-effort fallback (e.g. self.pr not computed yet); catching
                # Exception instead of a bare except lets KeyboardInterrupt and
                # SystemExit propagate
                self.p0.setYRange(0, self.Ly)
            self.p0.setXRange(0, self.Lx)
def cancel_remove_multiple(self):
    """ leave multi-cell deletion mode without deleting: the pending
    selection is cleared before the deletion is finished """
    for step in (self.clear_multi_selected_cells, self.done_remove_multiple_cells):
        step()
def clear_multi_selected_cells(self):
    """ repaint every cell marked for deletion with its own color and empty
    the list of marked cells """
    # unselect all previously selected cells:
    pending = self.removing_cells_list
    for cell in pending:
        self.unselect_cell_multi(cell)
    self.removing_cells_list.clear()
def add_roi(self, roi):
    """ add `roi` to the view box and remember it as the deletion ROI """
    view_box = self.p0
    view_box.addItem(roi)
    self.remove_roi_obj = roi
def remove_roi(self, roi):
    """
    Remove the deletion ROI from the view box.

    The pending multi-selection is cleared first. `roi` has to be the ROI
    stored in ``self.remove_roi_obj`` (checked with an assert).
    """
    self.clear_multi_selected_cells()
    stored_roi = self.remove_roi_obj
    assert roi == stored_roi
    self.remove_roi_obj = None
    self.p0.removeItem(roi)
    self.removing_region = False
def roi_changed(self, roi):
    """
    Select the cells under the deletion ROI.

    Args:
        roi: object providing ``pos()`` and ``size()``, each with ``x()`` and
            ``y()``.

    All four edges of the rectangle are clamped to the image before the label
    array is sliced, so an ROI that lies partly or completely outside the image
    only selects the cells it really overlaps. The earlier multi-selection is
    replaced by the cells found.
    """
    # find the overlapping cells and make them selected
    pos = roi.pos()
    size = roi.size()
    # clamp BOTH ends of each axis to [0, L]: an unclamped negative stop would
    # be interpreted by the slice as counting from the far border
    x0 = min(max(int(pos.x()), 0), self.Lx)
    y0 = min(max(int(pos.y()), 0), self.Ly)
    x1 = min(max(int(pos.x() + size.x()), 0), self.Lx)
    y1 = min(max(int(pos.y() + size.y()), 0), self.Ly)
    # find cells in that region
    cell_idxs = np.unique(self.cellpix[self.currentZ, y0:y1, x0:x1])
    cell_idxs = np.trim_zeros(cell_idxs)
    # deselect cells not in region by deselecting all and then selecting the ones in the region
    self.clear_multi_selected_cells()
    # plain ints keep the list free of fixed-width numpy scalars
    for idx in cell_idxs.tolist():
        self.select_cell_multi(idx)
        self.removing_cells_list.append(idx)
    self.update_layer()
def mouse_moved(self, pos):
    """ query the scene items at `pos`; the result is not used further """
    scene = self.win.scene()
    items_at_pos = scene.items(pos)
def color_choose(self):
    """ take the color mode from the color dropdown, switch back to view 0
    and redraw """
    self.color = self.RGBDropDown.currentIndex()
    self.view = 0
    self.ViewDropDown.setCurrentIndex(0)
    self.update_plot()
def update_plot(self):
    """
    Redraw the image item for the current view, color mode and plane.

    View 0 shows ``self.stack``, the last view entry shows
    ``self.stack_filtered``; every other view shows ``self.flows[view - 1]``.
    Also syncs the three sliders with the saturation levels of the current
    plane and shows the window.
    """
    self.view = self.ViewDropDown.currentIndex()
    self.Ly, self.Lx, _ = self.stack[self.currentZ].shape
    if self.restore and "upsample" in self.restore:
        # decide whether the resized or the original-size masks are displayed
        if self.view != 0:
            if self.view == 3:
                self.resize = True
            elif len(self.flows[0]) > 0 and self.flows[0].shape[1] == self.Lyr:
                self.resize = True
            else:
                self.resize = False
        else:
            self.resize = False
        self.draw_layer()
        self.update_scale()
        self.update_layer()
    if self.view == 0 or self.view == self.ViewDropDown.count() - 1:
        image = self.stack[
            self.currentZ] if self.view == 0 else self.stack_filtered[self.currentZ]
        if self.nchan == 1:
            # show single channel
            image = image[..., 0]
        if self.color == 0:
            # color 0: image shown without a lookup table
            self.img.setImage(image, autoLevels=False, lut=None)
            if self.nchan > 1:
                # one [low, high] pair per channel
                levels = np.array([
                    self.saturation[0][self.currentZ],
                    self.saturation[1][self.currentZ],
                    self.saturation[2][self.currentZ]
                ])
                self.img.setLevels(levels)
            else:
                self.img.setLevels(self.saturation[0][self.currentZ])
        elif self.color > 0 and self.color < 4:
            # colors 1-3: a single channel with the colormap of that color
            if self.nchan > 1:
                image = image[:, :, self.color - 1]
            self.img.setImage(image, autoLevels=False, lut=self.cmap[self.color])
            if self.nchan > 1:
                self.img.setLevels(self.saturation[self.color - 1][self.currentZ])
            else:
                self.img.setLevels(self.saturation[0][self.currentZ])
        elif self.color == 4:
            # color 4: channel mean without a lookup table
            if self.nchan > 1:
                image = image.mean(axis=-1)
            self.img.setImage(image, autoLevels=False, lut=None)
            self.img.setLevels(self.saturation[0][self.currentZ])
        elif self.color == 5:
            # color 5: channel mean with self.cmap[0]
            if self.nchan > 1:
                image = image.mean(axis=-1)
            self.img.setImage(image, autoLevels=False, lut=self.cmap[0])
            self.img.setLevels(self.saturation[0][self.currentZ])
    else:
        # blank image unless an output exists for this view
        image = np.zeros((self.Ly, self.Lx), np.uint8)
        # NOTE(review): `>=` admits view - 1 == len(self.flows), which would be
        # out of range for the indexing that follows; `>` looks intended
        if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0:
            image = self.flows[self.view - 1][self.currentZ]
        if self.view > 1:
            self.img.setImage(image, autoLevels=False, lut=self.bwr)
        else:
            self.img.setImage(image, autoLevels=False, lut=None)
        self.img.setLevels([0.0, 255.0])
    # keep the sliders in sync with the levels of the plane on display
    for r in range(3):
        self.sliders[r].setValue([
            self.saturation[r][self.currentZ][0],
            self.saturation[r][self.currentZ][1]
        ])
    self.win.show()
    self.show()
def update_layer(self):
    """ push self.layerz to the overlay item (when masks or outlines are
    shown), refresh the ROI counter and show the window """
    overlay_visible = self.masksOn or self.outlinesOn
    if overlay_visible:
        self.layer.setImage(self.layerz, autoLevels=False)
    self.update_roi_count()
    self.win.show()
    self.show()
def update_roi_count(self):
    """ show the current number of cells in the ROI counter label """
    label_text = f"{self.ncells} ROIs"
    self.roi_count.setText(label_text)
def add_set(self):
    """
    Turn the points drawn so far into a new cell.

    All displayed strokes are removed first. The mask is only created when the
    first entry of ``self.current_point_set`` has more than 8 items; on success
    the per-cell lists are extended and single-plane data are saved. The
    drawing state is cleared in every case.
    """
    if len(self.current_point_set) > 0:
        while len(self.strokes) > 0:
            self.remove_stroke(delete_points=False)
        if len(self.current_point_set[0]) <= 8:
            print("GUI_ERROR: cell too small, not drawn")
        else:
            color = self.colormap[self.ncells, :3]
            median = self.add_mask(points=self.current_point_set, color=color)
            if median is not None:
                # a new cell invalidates the stored removed cell
                self.removed_cell = []
                self.toggle_mask_ops()
                self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :],
                                            axis=0)
                self.ncells += 1
                self.ismanual = np.append(self.ismanual, True)
                if self.NZ == 1:
                    # only save after each cell if single image
                    io._save_sets_with_check(self)
    self.current_stroke = []
    self.strokes = []
    self.current_point_set = []
    self.update_layer()
def add_mask(self, points=None, color=(100, 200, 50), dense=True):
    """
    Create a new mask from drawn strokes and draw it.

    Args:
        points: list of strokes; each stroke is concatenated and reshaped to
            rows of 4 values, of which column 0 is used for the plane list
            and columns 1 / 2 as row / column of the outline.
        color: color passed on to ``draw_mask``.
        dense: not used in this method.

    Returns:
        A one-element list holding the median (row, col) of the mask pixels,
        or None when a stroke leaves fewer than 10 pixels that do not overlap
        an existing cell (nothing is drawn in that case).
    """
    # points is list of strokes
    points_all = np.concatenate(points, axis=0)
    # loop over z values
    median = []
    zdraw = np.unique(points_all[:, 0])
    # NOTE(review): the mask is always checked against and drawn on plane 0
    z = 0
    ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
        0, "int"), np.zeros(0, "int")
    for stroke in points:
        stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
        vr = stroke[:, 1]
        vc = stroke[:, 2]
        # get points inside drawn points
        # (rasterized in a bounding box with a 2-pixel margin on each side)
        mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
        pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
                       axis=-1)[:, np.newaxis, :]
        mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
        ar, ac = np.nonzero(mask)
        ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
        # get dense outline
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # contours[-2] is the contour list for both the 2- and the 3-tuple
        # return convention of cv2.findContours
        pvc, pvr = contours[-2][0][:,0].T
        vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
        # concatenate all points
        ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
        # if these pixels are overlapping with another cell, reassign them
        ioverlap = self.cellpix[z][ar, ac] > 0
        if (~ioverlap).sum() < 10:
            print("GUI_ERROR: cell < 10 pixels without overlaps, not drawn")
            return None
        elif ioverlap.sum() > 0:
            # keep only the free pixels
            ar, ac = ar[~ioverlap], ac[~ioverlap]
            # compute outline of new mask
            mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
            mask[ar - vr.min() + 2, ac - vc.min() + 2] = 1
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                        cv2.CHAIN_APPROX_NONE)
            pvc, pvr = contours[-2][0][:,0].T
            vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
        # accumulate area and outline pixels over all strokes
        ars = np.concatenate((ars, ar), axis=0)
        acs = np.concatenate((acs, ac), axis=0)
        vrs = np.concatenate((vrs, vr), axis=0)
        vcs = np.concatenate((vcs, vc), axis=0)
    self.draw_mask(z, ars, acs, vrs, vcs, color)
    median.append(np.array([np.median(ars), np.median(acs)]))
    self.zdraw.append(zdraw)
    d = datetime.datetime.now()
    # NOTE(review): [ar, ac] are the pixels of the LAST stroke only, not the
    # accumulated ars / acs
    self.track_changes.append(
        [d.strftime("%m/%d/%Y, %H:%M:%S"), "added mask", [ar, ac]])
    return median
def draw_mask(self, z, ar, ac, vr, vc, color, idx=None):
    """ draw single mask using outlines and area

    Args:
        z: plane index into the label arrays.
        ar, ac: row / column indices of the mask area.
        vr, vc: row / column indices of the mask outline.
        color: value written to the first three channels of the overlay.
        idx: label to write; defaults to ``self.ncells + 1``.

    With an "upsample" restore active the mask is also written to the
    *_resize and *_orig label arrays, converting the coordinates with
    ``self.ratio`` in the direction given by ``self.resize``.
    """
    if idx is None:
        idx = self.ncells + 1
    self.cellpix[z, vr, vc] = idx
    self.cellpix[z, ar, ac] = idx
    self.outpix[z, vr, vc] = idx
    if self.restore and "upsample" in self.restore:
        if self.resize:
            # coordinates are in the resized frame: write them as they are and
            # divide by the ratio for the original-size arrays
            self.cellpix_resize[z, vr, vc] = idx
            self.cellpix_resize[z, ar, ac] = idx
            self.outpix_resize[z, vr, vc] = idx
            self.cellpix_orig[z, (vr / self.ratio).astype(int),
                              (vc / self.ratio).astype(int)] = idx
            self.cellpix_orig[z, (ar / self.ratio).astype(int),
                              (ac / self.ratio).astype(int)] = idx
            self.outpix_orig[z, (vr / self.ratio).astype(int),
                             (vc / self.ratio).astype(int)] = idx
        else:
            # coordinates are in the original frame
            self.cellpix_orig[z, vr, vc] = idx
            self.cellpix_orig[z, ar, ac] = idx
            self.outpix_orig[z, vr, vc] = idx
            # get upsampled mask
            # (the scaled outline is filled again because scaling leaves gaps)
            vrr = (vr.copy() * self.ratio).astype(int)
            vcr = (vc.copy() * self.ratio).astype(int)
            mask = np.zeros((np.ptp(vrr) + 4, np.ptp(vcr) + 4), np.uint8)
            pts = np.stack((vcr - vcr.min() + 2, vrr - vrr.min() + 2),
                           axis=-1)[:, np.newaxis, :]
            mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
            arr, acr = np.nonzero(mask)
            arr, acr = arr + vrr.min() - 2, acr + vcr.min() - 2
            # get dense outline
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                        cv2.CHAIN_APPROX_NONE)
            pvc, pvr = contours[-2][0].squeeze().T
            vrr, vcr = pvr + vrr.min() - 2, pvc + vcr.min() - 2
            # concatenate all points
            arr, acr = np.hstack((np.vstack((vrr, vcr)), np.vstack((arr, acr))))
            self.cellpix_resize[z, vrr, vcr] = idx
            self.cellpix_resize[z, arr, acr] = idx
            self.outpix_resize[z, vrr, vcr] = idx
    if z == self.currentZ:
        # only the plane on display is painted into the overlay
        self.layerz[ar, ac, :3] = color
        if self.masksOn:
            self.layerz[ar, ac, -1] = self.opacity
        if self.outlinesOn:
            self.layerz[vr, vc] = np.array(self.outcolor)
def compute_scale(self):
    """
    Build the scale-disk image for the diameter in the Diameter text box.

    Sets ``self.diameter``, ``self.pr`` (diameter truncated to int),
    ``self.radii_padding`` and ``self.radii``, an RGBA image that is
    ``radii_padding`` rows taller than the image and holds the disk in color
    rgb(150,50,150). The view range is set to cover the padded image.
    """
    self.diameter = float(self.Diameter.text())
    self.pr = int(self.diameter)
    self.radii_padding = int(self.pr * 1.25)
    padded_height = self.Ly + self.radii_padding
    self.radii = np.zeros((padded_height, self.Lx, 4), np.uint8)
    center = [self.Ly + self.radii_padding / 2 - 1, self.pr / 2 + 1]
    yy, xx = disk(center, self.pr / 2, padded_height, self.Lx)
    # rgb(150,50,150)
    for channel, value in enumerate((150, 50, 150, 255)):
        self.radii[yy, xx, channel] = value
    self.p0.setYRange(0, padded_height)
    self.p0.setXRange(0, self.Lx)
def update_scale(self):
    """ recompute the scale disk, show it in the scale item with levels
    0-255 and show the window """
    self.compute_scale()
    scale_item = self.scale
    scale_item.setImage(self.radii, autoLevels=False)
    scale_item.setLevels([0.0, 255.0])
    self.win.show()
    self.show()
def redraw_masks(self, masks=True, outlines=True, draw=True):
    """ rebuild the overlay via draw_layer; the three arguments are not used """
    rebuild = self.draw_layer
    rebuild()
def draw_masks(self):
    """ rebuild the overlay via draw_layer """
    rebuild = self.draw_layer
    rebuild()
def draw_layer(self):
    """
    Rebuild the RGBA overlay ``self.layerz`` for the current plane.

    The overlay size follows ``self.resize`` (resized or original shape). With
    an "upsample" restore active the matching label arrays are copied into
    cellpix / outpix first. Masks are colored per cell with the selected cell
    in white and the strokes of the current plane on top; outlines are drawn
    last.
    """
    if self.resize:
        self.Ly, self.Lx = self.Lyr, self.Lxr
    else:
        self.Ly, self.Lx = self.Ly0, self.Lx0
    overlay_visible = self.masksOn or self.outlinesOn
    if overlay_visible and self.restore and "upsample" in self.restore:
        if self.resize:
            self.cellpix = self.cellpix_resize.copy()
            self.outpix = self.outpix_resize.copy()
        else:
            self.cellpix = self.cellpix_orig.copy()
            self.outpix = self.outpix_orig.copy()
    z = self.currentZ
    overlay = np.zeros((self.Ly, self.Lx, 4), np.uint8)
    self.layerz = overlay
    if self.masksOn:
        labels = self.cellpix[z]
        overlay[..., :3] = self.cellcolors[labels, :]
        overlay[..., 3] = self.opacity * (labels > 0).astype(np.uint8)
        if self.selected > 0:
            overlay[labels == self.selected] = np.array(
                [255, 255, 255, self.opacity])
        # strokes whose first point lies on this plane are drawn on top
        stroke_planes = np.array([s[0][0] for s in self.strokes])
        for i in np.nonzero(stroke_planes == z)[0]:
            stroke = np.array(self.strokes[i])
            overlay[stroke[:, 1], stroke[:, 2]] = np.array([255, 0, 255, 100])
    else:
        overlay[..., 3] = 0
    if self.outlinesOn:
        overlay[self.outpix[z] > 0] = np.array(self.outcolor).astype(np.uint8)
def set_restore_button(self):
    """
    Style the denoise buttons according to ``self.restore``.

    A button is shown pressed when its key occurs in ``self.restore``, or when
    its key is "none" and no restore is active. Every other button that is
    enabled is shown unpressed; disabled ones keep their style.
    """
    for position, key in enumerate(self.denoise_text):
        button = self.DenoiseButtons[position]
        is_none_key = key == "none"
        active = (not is_none_key and self.restore and key in self.restore) or (
            is_none_key and self.restore is None)
        if active:
            button.setStyleSheet(self.stylePressed)
        elif button.isEnabled():
            button.setStyleSheet(self.styleUnpressed)
def set_normalize_params(self, normalize_params):
    """
    Write normalization settings into the GUI widgets.

    Args:
        normalize_params (dict): settings keyed like ``normalize_default``;
            missing keys are filled from ``normalize_default``. The dict is
            not modified.

    Unless the "filter" restore is active, every setting other than
    "percentile" is replaced by its default before the widgets are updated
    through ``check_percentile_params`` and ``check_filter_params``.
    """
    if self.restore != "filter":
        # build a new dict instead of overwriting the caller's entries
        normalize_params = {
            key: (value if key == "percentile" else normalize_default[key])
            for key, value in normalize_params.items()
        }
    normalize_params = {**normalize_default, **normalize_params}
    self.check_percentile_params(normalize_params["percentile"])
    self.check_filter_params(normalize_params["sharpen_radius"],
                             normalize_params["smooth_radius"],
                             normalize_params["tile_norm_blocksize"],
                             normalize_params["tile_norm_smooth3D"],
                             normalize_params["norm3D"],
                             normalize_params["invert"])
def check_percentile_params(self, percentile):
    """Validate the normalization percentiles and mirror them into the GUI.

    A valid pair satisfies ``0 <= lower < 100``, ``0 < upper <= 100`` and
    ``upper > lower``. ``None`` or an invalid pair is replaced by
    ``[1., 99.]`` (an invalid pair also prints a GUI_ERROR message).

    Args:
        percentile: ``[lower, upper]`` pair or None.

    Returns:
        The pair that was written to the two ``norm_edits`` fields.
    """
    if percentile is None:
        percentile = [1., 99.]
    elif not (percentile[0] >= 0 and percentile[1] > 0 and percentile[0] < 100
              and percentile[1] <= 100 and percentile[1] > percentile[0]):
        print(
            "GUI_ERROR: percentiles need be between 0 and 100, and upper > lower, using defaults"
        )
        self.norm_edits[0].setText("1.")
        self.norm_edits[1].setText("99.")
        percentile = [1., 99.]

    # always show the values that will actually be used
    lower_edit, upper_edit = self.norm_edits[0], self.norm_edits[1]
    lower_edit.setText(str(percentile[0]))
    upper_edit.setText(str(percentile[1]))
    return percentile
def check_filter_params(self, sharpen, smooth, tile_norm, smooth3D, norm3D, invert):
    """Sanitize the filter settings and write them into the GUI widgets.

    Negative numeric settings are replaced by 0, ``norm3D`` and ``invert``
    are coerced to bool, and ``tile_norm`` is disabled (set to 0, with a
    GUI_ERROR message) when it exceeds both ``self.Ly`` and ``self.Lx``.

    Returns:
        tuple: ``(sharpen, smooth, tile_norm, smooth3D, norm3D, invert)``
        after sanitizing.
    """

    def _non_negative(value):
        # negative values become 0; everything else passes through unchanged
        return 0 if value < 0 else value

    sharpen = _non_negative(sharpen)
    smooth = _non_negative(smooth)
    tile_norm = _non_negative(tile_norm)
    smooth3D = _non_negative(smooth3D)
    norm3D, invert = bool(norm3D), bool(invert)

    if tile_norm > self.Ly and tile_norm > self.Lx:
        print(
            "GUI_ERROR: tile size (tile_norm) bigger than both image dimensions, disabling"
        )
        tile_norm = 0

    # filt_edits order: sharpen, smooth, tile_norm, smooth3D
    for idx, value in enumerate((sharpen, smooth, tile_norm, smooth3D)):
        self.filt_edits[idx].setText(str(value))
    self.norm3D_cb.setChecked(norm3D)
    self.invert_cb.setChecked(invert)
    return sharpen, smooth, tile_norm, smooth3D, norm3D, invert
def get_normalize_params(self):
    """Collect the normalization settings currently entered in the GUI.

    The percentile fields are validated with ``check_percentile_params`` and
    the validated pair is the one returned, so an invalid entry falls back
    to the defaults instead of being passed on. The filter settings are only
    read from the GUI when the restore mode is "filter"; otherwise they keep
    the values of ``normalize_default``.

    Returns:
        dict: ``normalize_default`` overridden by the values read from the GUI.

    Raises:
        ValueError: if a numeric field does not contain a valid float.
    """
    from cellpose.models import normalize_default

    percentile = [
        float(self.norm_edits[0].text()),
        float(self.norm_edits[1].text())
    ]
    # keep the validated pair; the check replaces invalid input by defaults
    percentile = self.check_percentile_params(percentile)
    norm3D = self.norm3D_cb.isChecked()
    normalize_params = {"percentile": percentile, "norm3D": norm3D}

    if self.restore == "filter":
        sharpen = float(self.filt_edits[0].text())
        smooth = float(self.filt_edits[1].text())
        tile_norm = float(self.filt_edits[2].text())
        smooth3D = float(self.filt_edits[3].text())
        invert = self.invert_cb.isChecked()
        sharpen, smooth, tile_norm, smooth3D, norm3D, invert = self.check_filter_params(
            sharpen, smooth, tile_norm, smooth3D, norm3D, invert)
        normalize_params["sharpen_radius"] = sharpen
        normalize_params["smooth_radius"] = smooth
        normalize_params["tile_norm_blocksize"] = tile_norm
        normalize_params["tile_norm_smooth3D"] = smooth3D
        normalize_params["invert"] = invert

    return {**normalize_default, **normalize_params}
def compute_saturation(self, return_img=False):
    """Recompute the display saturation levels, filtering the image if requested.

    Settings come from ``get_normalize_params``. When sharpening, smoothing
    or tile normalization is requested, a filtered copy of ``self.stack`` is
    computed, rescaled to 0-255, stored in ``self.stack_filtered`` and the
    last entry of the view dropdown is enabled and selected (the restore
    mode becomes "filter"). ``self.saturation`` is then rebuilt as
    ``saturation[channel][z] = [low, high]`` from the configured
    percentiles; channels with (almost) no dynamic range get ``[0, 255.]``.
    With ``invert`` set, the levels and the displayed image are mirrored
    around 255.

    Args:
        return_img: unused; kept so existing callers keep working.
    """
    norm = self.get_normalize_params()
    sharpen, smooth = norm["sharpen_radius"], norm["smooth_radius"]
    percentile = norm["percentile"]
    tile_norm = norm["tile_norm_blocksize"]
    smooth3D = norm["tile_norm_smooth3D"]
    norm3D = norm["norm3D"]
    invert = norm["invert"]

    if sharpen > 0 or smooth > 0 or tile_norm > 0:
        self.clear_restore()
        self.restore = "filter"
        print(
            "GUI_INFO: computing filtered image because sharpen > 0 or tile_norm > 0"
        )
        print(
            "GUI_WARNING: will use memory to create filtered image -- make sure to have RAM for this"
        )
        if sharpen > 0 or smooth > 0:
            img_norm = smooth_sharpen_img(self.stack, sharpen_radius=sharpen,
                                          smooth_radius=smooth)
        else:
            # only tile normalization requested: work on a copy so that the
            # in-place rescaling below cannot touch self.stack
            img_norm = self.stack.copy()
        if tile_norm > 0:
            img_norm = normalize99_tile(img_norm, blocksize=tile_norm,
                                        lower=percentile[0], upper=percentile[1],
                                        smooth3D=smooth3D, norm3D=norm3D)
        # convert to 0->255 using the global min/max of the filtered image
        img_norm_min = img_norm.min()
        img_norm_max = img_norm.max()
        for c in range(img_norm.shape[-1]):
            if np.ptp(img_norm[..., c]) > 1e-3:
                img_norm[..., c] -= img_norm_min
                img_norm[..., c] /= (img_norm_max - img_norm_min)
        img_norm *= 255
        self.stack_filtered = img_norm
        self.ViewDropDown.model().item(self.ViewDropDown.count() -
                                       1).setEnabled(True)
        self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1)
    elif invert or self.restore is None or self.restore == "filter":
        # nothing below modifies img_norm in place, so no copy is needed
        img_norm = self.stack
    else:
        img_norm = self.stack_filtered

    self.saturation = []
    for c in range(img_norm.shape[-1]):
        levels = []
        self.saturation.append(levels)
        if np.ptp(img_norm[..., c]) > 1e-3:
            if norm3D:
                # one pair of levels for the whole volume
                x01 = np.percentile(img_norm[..., c], percentile[0])
                x99 = np.percentile(img_norm[..., c], percentile[1])
                if invert:
                    x01, x99 = 255. - x99, 255. - x01
                for _ in range(self.NZ):
                    levels.append([x01, x99])
            else:
                # levels computed per plane (whole channel if single plane)
                for z in range(self.NZ):
                    if self.NZ > 1:
                        plane = img_norm[z, :, :, c]
                    else:
                        plane = img_norm[..., c]
                    x01 = np.percentile(plane, percentile[0])
                    x99 = np.percentile(plane, percentile[1])
                    if invert:
                        x01, x99 = 255. - x99, 255. - x01
                    levels.append([x01, x99])
        else:
            for _ in range(self.NZ):
                levels.append([0, 255.])

    # if fewer than 3 channels, pad with full-range levels
    while len(self.saturation) < 3:
        self.saturation.append([[0, 255.] for _ in range(self.NZ)])

    if invert:
        img_norm = 255. - img_norm
        self.stack_filtered = img_norm
        self.ViewDropDown.model().item(self.ViewDropDown.count() -
                                       1).setEnabled(True)
        self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1)

    if img_norm.shape[-1] == 1:
        # NOTE(review): the list was already padded to three channels above,
        # so these copies of channel 0 end up at indices 3 and 4 -- confirm
        # whether indices 1 and 2 were meant to mirror channel 0 instead.
        self.saturation.append(self.saturation[0])
        self.saturation.append(self.saturation[0])

    self.autobtn.setChecked(True)
    self.update_plot()
def chanchoose(self, image):
    """Reduce ``image`` to the channels selected in the channel dropdowns.

    Images with at most two dimensions, or GUIs with a single channel
    (``self.nchan <= 1``), are returned unchanged. Otherwise, index 0 of the
    first dropdown means "grayscale" and returns the mean over the last
    axis (kept as a singleton axis); any other index selects channel
    ``index - 1``, optionally followed by the second dropdown's channel when
    that dropdown is not at index 0.

    Channels are always taken from the last axis, consistent with the
    grayscale branch, so stacks with leading axes are handled as well.

    Args:
        image: array whose last axis holds the channels.

    Returns:
        The selected channels as an array, or ``image`` itself.
    """
    if image.ndim <= 2 or self.nchan <= 1:
        return image

    first = self.ChannelChoose[0].currentIndex()
    if first == 0:
        return image.mean(axis=-1, keepdims=True)

    chanid = [first - 1]
    second = self.ChannelChoose[1].currentIndex()
    if second > 0:
        chanid.append(second - 1)
    # Ellipsis keeps every leading axis; identical to [:, :, chanid] for 3-D
    return image[..., chanid]
def get_model_path(self, custom=False):
    """Set ``self.current_model`` and ``self.current_model_path`` from the GUI.

    Args:
        custom: if true, use the text of the ``ModelChooseC`` dropdown and
            resolve it inside ``models.MODEL_DIR``; otherwise use the entry
            of ``self.net_names`` selected in ``ModelChooseB`` (dropdown
            index minus one, clamped at 0) and resolve it with
            ``models.model_path``.
    """
    if not custom:
        builtin_index = max(0, self.ModelChooseB.currentIndex() - 1)
        self.current_model = self.net_names[builtin_index]
        self.current_model_path = models.model_path(self.current_model)
        return

    self.current_model = self.ModelChooseC.currentText()
    custom_path = models.MODEL_DIR.joinpath(self.current_model)
    self.current_model_path = os.fspath(custom_path)
def initialize_model(self, model_name=None, custom=False):
    """Create ``self.model`` for the requested model.

    Args:
        model_name: name of the model to load; when it is not a string the
            model is loaded from ``self.current_model_path``.
        custom: forwarded to ``get_model_path`` when the path is looked up
            from the GUI (``model_name`` is None or ``custom`` is true).

    Raises:
        ValueError: if the dropdown placeholder "dataset-specific models" is
            passed, or if the path looked up from the GUI does not exist.
    """
    if model_name == "dataset-specific models":
        raise ValueError("need to specify model (use dropdown)")
    if model_name is None or custom:
        self.get_model_path(custom=custom)
        if not os.path.exists(self.current_model_path):
            raise ValueError("need to specify model (use dropdown)")

    if not isinstance(model_name, str):
        # no usable name given: load the weights from the current path
        self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
                                          pretrained_model=self.current_model_path)
        return

    self.current_model = model_name
    if self.current_model in ("cyto", "nuclei"):
        self.current_model_path = models.model_path(self.current_model, 0)
    else:
        model_file = models.MODEL_DIR.joinpath(self.current_model)
        self.current_model_path = os.fspath(model_file)

    if self.current_model == "cyto3":
        self.model = models.Cellpose(gpu=self.useGPU.isChecked(),
                                     model_type=self.current_model)
    else:
        diam_mean = 17. if self.current_model == "nuclei" else 30.
        self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
                                          diam_mean=diam_mean,
                                          model_type=self.current_model)
def add_model(self):
    """Add a model to the GUI by delegating to ``io._add_model``."""
    io._add_model(self)
def remove_model(self):
    """Remove a model from the GUI by delegating to ``io._remove_model``."""
    io._remove_model(self)
def new_model(self):
    """Collect a training set, ask for training settings and start training.

    Training is refused (with a printed error) unless the loaded data has
    exactly one plane (``self.NZ == 1``). The training set is gathered with
    ``io._get_train_set`` from the first entry returned by ``get_files``,
    then a ``guiparts.TrainWindow`` dialog is shown; training only starts if
    the dialog returns a truthy result.
    """
    if self.NZ != 1:
        print("ERROR: cannot train model on 3D data")
        return

    image_names = self.get_files()[0]
    train_set = io._get_train_set(image_names)
    (self.train_data, self.train_labels, self.train_files, restore,
     normalize_params) = train_set

    dialog = guiparts.TrainWindow(self, models.MODEL_NAMES)
    accepted = dialog.exec_()
    if not accepted:
        print("GUI_INFO: training cancelled")
        return

    self.logger.info(
        f"training with {[os.path.split(f)[1] for f in self.train_files]}")
    self.train_model(restore=restore, normalize_params=normalize_params)
def train_model(self, restore=None, normalize_params=None):
    """Train a new segmentation model and run it on the next image.

    Uses ``self.training_params`` together with ``self.train_data`` and
    ``self.train_labels``. After training, the losses are saved next to the
    new model, the model is added to the GUI, the diameter field is set to
    the network's ``diam_labels`` and the next image is loaded and
    segmented with the new model.

    Args:
        restore: restore mode stored in ``self.restore`` after training.
        normalize_params: normalization settings used for training; a deep
            copy of ``normalize_default`` is used when None.
    """
    from cellpose.models import normalize_default
    if normalize_params is None:
        normalize_params = copy.deepcopy(normalize_default)

    params = self.training_params
    base_index = params["model_index"]
    if base_index >= len(models.MODEL_NAMES):
        # an index past the known model names means: no pretrained start
        model_type = None
        self.logger.info("training new model starting from scratch")
    else:
        model_type = models.MODEL_NAMES[base_index]
        self.logger.info(f"training new model starting at model {model_type}")
    self.current_model = model_type
    self.channels = params["channels"]

    chan_text = self.ChannelChoose[0].currentText()
    chan2_text = self.ChannelChoose[1].currentText()
    self.logger.info(f"training with chan = {chan_text}, chan2 = {chan2_text}")

    self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
                                      model_type=model_type)
    self.SizeButton.setEnabled(False)
    save_path = os.path.dirname(self.filename)

    print("GUI_INFO: name of new model: " + params["model_name"])
    print(f"GUI_INFO: SGD activated: {params['SGD']}")
    train_outputs = train.train_seg(
        self.model.net, train_data=self.train_data, train_labels=self.train_labels,
        channels=self.channels, normalize=normalize_params, min_train_masks=0,
        save_path=save_path, nimg_per_epoch=max(8, len(self.train_data)),
        learning_rate=params["learning_rate"],
        weight_decay=params["weight_decay"],
        n_epochs=params["n_epochs"],
        SGD=params["SGD"],
        model_name=params["model_name"])
    self.new_model_path, train_losses = train_outputs[:2]

    # keep the training losses next to the saved model
    np.save(str(self.new_model_path) + "_train_losses.npy", train_losses)

    # register the new model and prepare to run it on the next image
    io._add_model(self, self.new_model_path)
    diam_labels = self.model.net.diam_labels.item()
    self.new_model_ind = len(self.model_strings)
    self.autorun = True
    channels = self.channels.copy()
    self.clear_all()

    # restore the channel selection that was used for training
    self.ChannelChoose[0].setCurrentIndex(channels[0])
    self.ChannelChoose[1].setCurrentIndex(channels[1])
    self.diameter = diam_labels
    self.Diameter.setText("%0.2f" % self.diameter)
    self.logger.info(f">>>> diameter set to diam_labels ( = {diam_labels: 0.3f} )")
    self.restore = restore
    self.set_normalize_params(normalize_params)

    self.get_next_image(load_seg=False)
    self.compute_segmentation(custom=True)
    self.logger.info(
        f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!"
    )
def compute_restore(self):
    """Re-run the image restoration described by ``self.restore``.

    Nothing happens when ``self.restore`` is empty. The value "filter"
    recomputes the filtered image through ``compute_saturation``. Any other
    value is split on "_" into a model type and an optional dataset name:
    the dataset selects the ``DenoiseChoose`` entry, "upsample" modes
    rewrite the diameter field using ``self.ratio``, and the restoration
    model is run with ``compute_denoise_model``.
    """
    if not self.restore:
        return
    self.logger.info(f"running image restoration {self.restore}")

    if self.restore == "filter":
        self.compute_saturation()
        return

    parts = self.restore.split("_")
    model_type = parts[0]
    if len(parts) > 1:
        dataset_index = 0 if parts[1] == "cyto3" else 1
        self.DenoiseChoose.setCurrentIndex(dataset_index)

    if "upsample" in self.restore:
        choice = self.DenoiseChoose.currentIndex()
        diam_up = 30. if choice in (0, 1) else 17.
        print(diam_up, self.ratio)
        self.Diameter.setText(str(diam_up / self.ratio))

    self.compute_denoise_model(model_type=model_type)
def get_thresholds(self):
    """Read the flow and cell-probability thresholds from the GUI.

    If either field does not parse as a float, a message is printed, both
    fields are reset to their defaults ("0.4" and "0.0") and the defaults
    are used. The flow threshold is reported as None (disabled) when it is
    0 or when the data has more than one plane (``self.NZ > 1``); this rule
    is applied to the fallback defaults as well.

    Returns:
        tuple: ``(flow_threshold or None, cellprob_threshold)``.
    """
    try:
        flow_threshold = float(self.flow_threshold.text())
        cellprob_threshold = float(self.cellprob_threshold.text())
    except (ValueError, TypeError):
        print(
            "flow threshold or cellprob threshold not a valid number, setting to defaults"
        )
        self.flow_threshold.setText("0.4")
        self.cellprob_threshold.setText("0.0")
        flow_threshold, cellprob_threshold = 0.4, 0.0

    if flow_threshold == 0.0 or self.NZ > 1:
        flow_threshold = None
    return flow_threshold, cellprob_threshold
def compute_cprob(self):
    """Recompute the masks from the stored flows with the current thresholds.

    Does nothing unless ``self.recompute_masks`` is set. The thresholds come
    from ``get_thresholds``; the masks are computed with
    ``dynamics.resize_and_compute_masks`` from ``self.flows[4]`` (all but the
    last entry, and the last entry) and ``self.flows[3]``, resized to the
    last two dimensions of ``self.cellpix``, and loaded into the GUI.
    """
    if not self.recompute_masks:
        return

    flow_threshold, cellprob_threshold = self.get_thresholds()
    if flow_threshold is not None:
        self.logger.info(
            "computing masks with cell prob=%0.3f, flow error threshold=%0.3f" %
            (cellprob_threshold, flow_threshold))
    else:
        self.logger.info(
            "computing masks with cell prob=%0.3f, no flow error threshold" %
            (cellprob_threshold))

    network_out = self.flows[4]
    results = dynamics.resize_and_compute_masks(
        network_out[:-1], network_out[-1], p=self.flows[3].copy(),
        cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold,
        resize=self.cellpix.shape[-2:])
    maski = results[0]

    self.masksOn = True
    # only force the mask checkbox on when outlines are not being shown
    if not self.OCheckBox.isChecked():
        self.MCheckBox.setChecked(True)

    # the GUI expects a leading plane axis
    if maski.ndim < 3:
        maski = maski[np.newaxis, ...]
    self.logger.info("%d cells found" % (len(np.unique(maski)[1:])))
    io._masks_to_gui(self, maski, outlines=None)
    self.show()
def compute_denoise_model(self, model_type=None):
    """Run an image-restoration model on the current stack and show the result.

    The model name is ``"<model_type>_<dataset>"`` where the dataset is the
    text of the ``DenoiseChoose`` dropdown with hyphens removed. The restored
    image is rescaled to 0-255 and written into ``self.stack_filtered``,
    ``self.saturation`` is rebuilt from the configured percentiles, and the
    last entry of the view dropdown is enabled and selected. For "upsample"
    models the masks and outlines are additionally resized to the upsampled
    image size and ``self.resize`` is set.

    Any exception raised during the computation is caught and printed.

    Args:
        model_type: restoration kind used as the prefix of the model name
            (for example "upsample", which enables the resizing logic).
    """
    self.progress.setValue(0)
    try:
        tic = time.time()
        # str.replace returns a new string, so the result must be kept
        nstr = self.DenoiseChoose.currentText().replace("-", "")
        self.clear_restore()
        model_name = model_type + "_" + nstr

        # denoising model
        self.denoise_model = denoise.DenoiseModel(gpu=self.useGPU.isChecked(),
                                                  model_type=model_name)
        self.progress.setValue(10)
        diam_up = 30. if "cyto" in model_name else 17.

        # params
        channels = self.get_channels()
        self.diameter = float(self.Diameter.text())
        normalize_params = self.get_normalize_params()
        print("GUI_INFO: channels: ", channels)
        print("GUI_INFO: normalize_params: ", normalize_params)
        print("GUI_INFO: diameter (before upsampling): ", self.diameter)

        data = self.stack.copy()
        self.Ly, self.Lx = data.shape[-3:-1]
        if "upsample" in model_name:
            # get upsampling factor
            if self.diameter >= diam_up:
                print(
                    f"GUI_ERROR: cannot upsample, already set to pixel diameter >= {diam_up}"
                )
                self.progress.setValue(0)
                return
            self.ratio = diam_up / self.diameter
            print(
                "GUI_WARNING: upsampling image, this will also duplicate mask layer and resize it, will use more RAM"
            )
            print(
                f"GUI_INFO: upsampling image to {diam_up} pixel diameter ({self.ratio:0.2f} times)"
            )
            self.Lyr, self.Lxr = int(self.Ly * self.ratio), int(self.Lx *
                                                                self.ratio)
            self.Ly0, self.Lx0 = self.Ly, self.Lx
            # the resize itself happens inside denoise_model.eval
        else:
            self.Lyr, self.Lxr = self.Ly, self.Lx
            self.Ly0, self.Lx0 = self.Ly, self.Lx
            # no upsampling: the diameter shown afterwards stays unchanged
            diam_up = self.diameter

        img_norm = self.denoise_model.eval(data, channels=channels, z_axis=0,
                                           channel_axis=3, diameter=self.diameter,
                                           normalize=normalize_params)
        self.diameter = diam_up
        self.Diameter.setText(str(diam_up))

        # bring the result to (plane, Ly, Lx, channel)
        if img_norm.ndim == 2:
            img_norm = img_norm[:, :, np.newaxis]
        if img_norm.ndim == 3:
            img_norm = img_norm[np.newaxis, ...]

        self.progress.setValue(100)
        self.logger.info(f"{model_name} finished in %0.3f sec" %
                         (time.time() - tic))

        # compute saturation
        percentile = normalize_params["percentile"]
        img_norm_min = img_norm.min()
        img_norm_max = img_norm.max()
        # GUI channel indices are 1-based; 0 in the first dropdown is grayscale
        chan = [0] if channels[0] == 0 else [channels[0] - 1, channels[1] - 1]
        self.saturation = [[], [], []]
        for c in range(img_norm.shape[-1]):
            if np.ptp(img_norm[..., c]) > 1e-3:
                # rescale to 0-1 with the global min/max of the restored image
                img_norm[..., c] -= img_norm_min
                img_norm[..., c] /= (img_norm_max - img_norm_min)
            for z in range(self.NZ):
                x01 = np.percentile(img_norm[z, :, :, c], percentile[0]) * 255.
                x99 = np.percentile(img_norm[z, :, :, c], percentile[1]) * 255.
                self.saturation[chan[c]].append([x01, x99])
        # channels that were not restored get the full display range
        notchan = np.ones(3, "bool")
        notchan[np.array(chan)] = False
        notchan = np.nonzero(notchan)[0]
        for c in notchan:
            for z in range(self.NZ):
                self.saturation[c].append([0, 255.])

        img_norm *= 255.
        self.autobtn.setChecked(True)

        # assign to denoised channels
        self.stack_filtered = np.zeros(
            (self.NZ, self.Lyr, self.Lxr, self.stack.shape[-1]), "float32")
        for i, c in enumerate(chan[:img_norm.shape[-1]]):
            for z in range(self.NZ):
                self.stack_filtered[z, :, :, c] = img_norm[z, :, :, i]

        # make upsampled masks
        if model_type == "upsample":
            self.cellpix_orig = self.cellpix.copy()
            self.outpix_orig = self.outpix.copy()
            self.cellpix_resize = cv2.resize(
                self.cellpix_orig[0], (self.Lxr, self.Lyr),
                interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :]
            outlines = masks_to_outlines(self.cellpix_resize[0])[np.newaxis, :, :]
            self.outpix_resize = outlines * self.cellpix_resize

        self.restore = model_name

        # draw plot
        self.resize = model_type == "upsample"
        self.draw_layer()
        self.update_layer()
        self.update_scale()

        # if denoised in grayscale, show in grayscale
        if channels[0] == 0:
            self.RGBDropDown.setCurrentIndex(4)

        self.ViewDropDown.model().item(self.ViewDropDown.count() -
                                       1).setEnabled(True)
        self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1)

        self.update_plot()

    except Exception as e:
        print("ERROR: %s" % e)
def compute_segmentation(self, custom=False, model_name=None, load_model=True):
    """Run the segmentation model on the current image and load the masks.

    Settings are read from the GUI, ``self.model.eval`` is run on the
    (optionally restored) image stack, the returned flows are converted and
    resized for display, and the masks are loaded into the GUI with
    ``io._masks_to_gui``.

    Any exception is caught and printed. A failure inside
    ``self.model.eval`` is reported as "NET ERROR", resets the progress bar
    and returns early.

    Args:
        custom: forwarded to ``initialize_model`` when ``load_model`` is true.
        model_name: forwarded to ``initialize_model`` when ``load_model`` is
            true.
        load_model: if true, ``initialize_model`` is called before the
            evaluation; otherwise the existing ``self.model`` is used.
    """
    self.progress.setValue(0)
    try:
        tic = time.time()
        self.clear_all()
        self.flows = [[], [], []]
        if load_model:
            self.initialize_model(model_name=model_name, custom=custom)
        self.progress.setValue(10)
        do_3D = self.load_3D
        # each setting is read from its widget via .text()/.isChecked(),
        # unless the attribute already holds a plain value of that type
        stitch_threshold = float(self.stitch_threshold.text()) if not isinstance(
            self.stitch_threshold, float) else self.stitch_threshold
        anisotropy = float(self.anisotropy.text()) if not isinstance(
            self.anisotropy, float) else self.anisotropy
        flow3D_smooth = float(self.flow3D_smooth.text()) if not isinstance(
            self.flow3D_smooth, float) else self.flow3D_smooth
        min_size = int(self.min_size.text()) if not isinstance(
            self.min_size, int) else self.min_size
        resample = self.resample.isChecked() if not isinstance(
            self.resample, bool) else self.resample
        # a positive stitch threshold turns do_3D off
        do_3D = False if stitch_threshold > 0. else do_3D

        channels = self.get_channels()
        # use the restored stack when a restoration other than "filter" is active
        if self.restore is not None and self.restore != "filter":
            data = self.stack_filtered.copy().squeeze()
        else:
            data = self.stack.copy().squeeze()
        flow_threshold, cellprob_threshold = self.get_thresholds()
        self.diameter = float(self.Diameter.text())
        # an niter of 0 (or a negative entry) is passed to the model as None
        niter = max(0, int(self.niter.text()))
        niter = None if niter == 0 else niter
        normalize_params = self.get_normalize_params()
        print(normalize_params)
        try:
            # only the first two outputs of eval (masks, flows) are used
            masks, flows = self.model.eval(
                data, channels=channels, diameter=self.diameter,
                cellprob_threshold=cellprob_threshold,
                flow_threshold=flow_threshold, do_3D=do_3D, niter=niter,
                normalize=normalize_params, stitch_threshold=stitch_threshold,
                anisotropy=anisotropy, resample=resample, flow3D_smooth=flow3D_smooth,
                min_size=min_size,
                progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2]
        except Exception as e:
            print("NET ERROR: %s" % e)
            self.progress.setValue(0)
            return

        self.progress.setValue(75)

        # convert flows to uint8 and resize to original image size
        flows_new = []
        flows_new.append(flows[0].copy())  # RGB flow
        flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) *
                          255).astype("uint8"))  # cellprob
        if self.load_3D:
            # third display entry: flows[1][0] rescaled to uint8, or zeros
            # of the same shape when stitching was used
            if stitch_threshold == 0.:
                flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8"))
            else:
                flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8"))

        if not self.load_3D:
            # after an "upsample" restoration the GUI works at the upsampled size
            if self.restore and "upsample" in self.restore:
                self.Ly, self.Lx = self.Lyr, self.Lxr

            if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx):
                self.flows = []
                for j in range(len(flows_new)):
                    self.flows.append(
                        resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx,
                                     interpolation=cv2.INTER_NEAREST))
            else:
                self.flows = flows_new
        else:
            if not resample:
                self.flows = []
                Lz, Ly, Lx = self.NZ, self.Ly, self.Lx
                Lz0, Ly0, Lx0 = flows_new[0].shape[:3]
                print("GUI_INFO: resizing flows to original image size")
                for j in range(len(flows_new)):
                    flow0 = flows_new[j]
                    # in-plane resize (only triggered by a mismatch in Ly)
                    if Ly0 != Ly:
                        flow0 = resize_image(flow0, Ly=Ly, Lx=Lx,
                                             no_channels=flow0.ndim==3,
                                             interpolation=cv2.INTER_NEAREST)
                    # resize along the first axis by swapping it into the Ly position
                    if Lz0 != Lz:
                        flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1),
                                                         Ly=Lz, Lx=Lx,
                                                         no_channels=flow0.ndim==3,
                                                         interpolation=cv2.INTER_NEAREST), 0, 1)
                    self.flows.append(flow0)
            else:
                self.flows = flows_new

        # add first axis
        if self.NZ == 1:
            masks = masks[np.newaxis, ...]
            self.flows = [
                self.flows[n][np.newaxis, ...] for n in range(len(self.flows))
            ]

        # the first unique value is dropped from the count
        # (presumably the background label 0 -- TODO confirm)
        self.logger.info("%d cells found with model in %0.3f sec" %
                         (len(np.unique(masks)[1:]), time.time() - tic))
        self.progress.setValue(80)
        # NOTE(review): z is never read in this method
        z = 0

        io._masks_to_gui(self, masks, outlines=None)
        self.masksOn = True
        self.MCheckBox.setChecked(True)
        self.progress.setValue(100)
        # refresh the saturation levels when a restoration other than "filter" is active
        if self.restore != "filter" and self.restore is not None:
            self.compute_saturation()
        # recompute_masks is only set for runs without do_3D and without stitching
        if not do_3D and not stitch_threshold > 0:
            self.recompute_masks = True
        else:
            self.recompute_masks = False
    except Exception as e:
        print("ERROR: %s" % e)