""" Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. """ from qtpy import QtGui, QtCore, QtWidgets from qtpy.QtGui import QPainter, QPixmap from qtpy.QtWidgets import QApplication, QRadioButton, QWidget, QDialog, QButtonGroup, QSlider, QStyle, QStyleOptionSlider, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox import pyqtgraph as pg from pyqtgraph import functions as fn from pyqtgraph import Point import numpy as np import pathlib, os def stylesheet(): return """ QToolTip { background-color: black; color: white; border: black solid 1px } QComboBox {color: white; background-color: rgb(40,40,40);} QComboBox::item:enabled { color: white; background-color: rgb(40,40,40); selection-color: white; selection-background-color: rgb(50,100,50);} QComboBox::item:!enabled { background-color: rgb(40,40,40); color: rgb(100,100,100); } QScrollArea > QWidget > QWidget { background: transparent; border: none; margin: 0px 0px 0px 0px; } QGroupBox { border: 1px solid white; color: rgb(255,255,255); border-radius: 6px; margin-top: 8px; padding: 0px 0px;} QPushButton:pressed {Text-align: center; background-color: rgb(150,50,150); border-color: white; color:white;} QToolTip { background-color: black; color: white; border: black solid 1px } QPushButton:!pressed {Text-align: center; background-color: rgb(50,50,50); border-color: white; color:white;} QToolTip { background-color: black; color: white; border: black solid 1px } QPushButton:disabled {Text-align: center; background-color: rgb(30,30,30); border-color: white; color:rgb(80,80,80);} QToolTip { background-color: black; color: white; border: black solid 1px } """ class DarkPalette(QtGui.QPalette): """Class that inherits from pyqtgraph.QtGui.QPalette and renders dark colours for the application. (from pykilosort/kilosort4) """ def __init__(self): QtGui.QPalette.__init__(self) self.setup() def setup(self): self.setColor(QtGui.QPalette.Window, QtGui.QColor(40, 40, 40)) self.setColor(QtGui.QPalette.WindowText, QtGui.QColor(255, 255, 255)) self.setColor(QtGui.QPalette.Base, QtGui.QColor(34, 27, 24)) self.setColor(QtGui.QPalette.AlternateBase, QtGui.QColor(53, 50, 47)) self.setColor(QtGui.QPalette.ToolTipBase, QtGui.QColor(255, 255, 255)) self.setColor(QtGui.QPalette.ToolTipText, QtGui.QColor(255, 255, 255)) self.setColor(QtGui.QPalette.Text, QtGui.QColor(255, 255, 255)) self.setColor(QtGui.QPalette.Button, QtGui.QColor(53, 50, 47)) self.setColor(QtGui.QPalette.ButtonText, QtGui.QColor(255, 255, 255)) self.setColor(QtGui.QPalette.BrightText, QtGui.QColor(255, 0, 0)) self.setColor(QtGui.QPalette.Link, QtGui.QColor(42, 130, 218)) self.setColor(QtGui.QPalette.Highlight, QtGui.QColor(42, 130, 218)) self.setColor(QtGui.QPalette.HighlightedText, QtGui.QColor(0, 0, 0)) self.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.Text, QtGui.QColor(128, 128, 128)) self.setColor( QtGui.QPalette.Disabled, QtGui.QPalette.ButtonText, QtGui.QColor(128, 128, 128), ) self.setColor( QtGui.QPalette.Disabled, QtGui.QPalette.WindowText, QtGui.QColor(128, 128, 128), ) def create_channel_choose(): # choose channel ChannelChoose = [QComboBox(), QComboBox()] ChannelLabels = [] ChannelChoose[0].addItems(["gray", "red", "green", "blue"]) ChannelChoose[1].addItems(["none", "red", "green", "blue"]) cstr = ["chan to segment:", "chan2 (optional): "] for i in range(2): ChannelLabels.append(QLabel(cstr[i])) if i == 0: ChannelLabels[i].setToolTip( "this is the channel in which the cytoplasm or nuclei exist \ that you want to segment") ChannelChoose[i].setToolTip( "this is the channel in which the cytoplasm or nuclei exist \ that you want to segment") else: ChannelLabels[i].setToolTip( "if cytoplasm model is chosen, and you also have a \ nuclear channel, then choose the nuclear channel for this option") ChannelChoose[i].setToolTip( "if cytoplasm model is chosen, and you also have a \ nuclear channel, then choose the nuclear channel for this option") return ChannelChoose, ChannelLabels class ModelButton(QPushButton): def __init__(self, parent, model_name, text): super().__init__() self.setEnabled(False) self.setText(text) self.setFont(parent.boldfont) self.clicked.connect(lambda: self.press(parent)) self.model_name = model_name if "cyto3" not in model_name else "cyto3" def press(self, parent): parent.compute_segmentation(model_name=self.model_name) class DenoiseButton(QPushButton): def __init__(self, parent, text): super().__init__() self.setEnabled(False) self.model_type = text self.setText(text) self.setFont(parent.medfont) self.clicked.connect(lambda: self.press(parent)) def press(self, parent): if self.model_type == "filter": parent.restore = "filter" normalize_params = parent.get_normalize_params() if (normalize_params["sharpen_radius"] == 0 and normalize_params["smooth_radius"] == 0 and normalize_params["tile_norm_blocksize"] == 0): print( "GUI_ERROR: no filtering settings on (use custom filter settings)") parent.restore = None return parent.restore = self.model_type parent.compute_saturation() elif self.model_type != "none": parent.compute_denoise_model(model_type=self.model_type) else: parent.clear_restore() parent.set_restore_button() class TrainWindow(QDialog): def __init__(self, parent, model_strings): super().__init__(parent) self.setGeometry(100, 100, 900, 550) self.setWindowTitle("train settings") self.win = QWidget(self) self.l0 = QGridLayout() self.win.setLayout(self.l0) yoff = 0 qlabel = QLabel("train model w/ images + _seg.npy in current folder >>") qlabel.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold)) qlabel.setAlignment(QtCore.Qt.AlignVCenter) self.l0.addWidget(qlabel, yoff, 0, 1, 2) # choose initial model yoff += 1 self.ModelChoose = QComboBox() self.ModelChoose.addItems(model_strings) self.ModelChoose.addItems(["scratch"]) self.ModelChoose.setFixedWidth(150) self.ModelChoose.setCurrentIndex(parent.training_params["model_index"]) self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1) qlabel = QLabel("initial model: ") qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) self.l0.addWidget(qlabel, yoff, 0, 1, 1) # choose channels self.ChannelChoose, self.ChannelLabels = create_channel_choose() for i in range(2): yoff += 1 self.ChannelChoose[i].setFixedWidth(150) self.ChannelChoose[i].setCurrentIndex( parent.ChannelChoose[i].currentIndex()) self.l0.addWidget(self.ChannelLabels[i], yoff, 0, 1, 1) self.l0.addWidget(self.ChannelChoose[i], yoff, 1, 1, 1) # choose parameters labels = ["learning_rate", "weight_decay", "n_epochs", "model_name"] self.edits = [] yoff += 1 for i, label in enumerate(labels): qlabel = QLabel(label) qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) self.l0.addWidget(qlabel, i + yoff, 0, 1, 1) self.edits.append(QLineEdit()) self.edits[-1].setText(str(parent.training_params[label])) self.edits[-1].setFixedWidth(200) self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1) yoff += 1 use_SGD = "SGD" self.useSGD = QCheckBox(f"{use_SGD}") self.useSGD.setToolTip("use SGD, if unchecked uses AdamW (recommended learning_rate then 0.001)") self.useSGD.setChecked(True) self.l0.addWidget(self.useSGD, i+yoff, 1, 1, 1) yoff += len(labels) yoff += 1 self.use_norm = QCheckBox(f"use restored/filtered image") self.use_norm.setChecked(True) #self.l0.addWidget(self.use_norm, yoff, 0, 2, 4) yoff += 2 qlabel = QLabel( "(to remove files, click cancel then remove \nfrom folder and reopen train window)" ) self.l0.addWidget(qlabel, yoff, 0, 2, 4) # click button yoff += 3 QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel self.buttonBox = QDialogButtonBox(QBtn) self.buttonBox.accepted.connect(lambda: self.accept(parent)) self.buttonBox.rejected.connect(self.reject) self.l0.addWidget(self.buttonBox, yoff, 0, 1, 4) # list files in folder qlabel = QLabel("filenames") qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold)) self.l0.addWidget(qlabel, 0, 4, 1, 1) qlabel = QLabel("# of masks") qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold)) self.l0.addWidget(qlabel, 0, 5, 1, 1) for i in range(10): if i > len(parent.train_files) - 1: break elif i == 9 and len(parent.train_files) > 10: label = "..." nmasks = "..." else: label = os.path.split(parent.train_files[i])[-1] nmasks = str(parent.train_labels[i].max()) qlabel = QLabel(label) self.l0.addWidget(qlabel, i + 1, 4, 1, 1) qlabel = QLabel(nmasks) qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) self.l0.addWidget(qlabel, i + 1, 5, 1, 1) def accept(self, parent): # set training params parent.training_params = { "model_index": self.ModelChoose.currentIndex(), "learning_rate": float(self.edits[0].text()), "weight_decay": float(self.edits[1].text()), "n_epochs": int(self.edits[2].text()), "model_name": self.edits[3].text(), "SGD": True if self.useSGD.isChecked() else False, "channels": [self.ChannelChoose[0].currentIndex(), self.ChannelChoose[1].currentIndex()], #"use_norm": True if self.use_norm.isChecked() else False, } self.done(1) class ExampleGUI(QDialog): def __init__(self, parent=None): super(ExampleGUI, self).__init__(parent) self.setGeometry(100, 100, 1300, 900) self.setWindowTitle("GUI layout") self.win = QWidget(self) layout = QGridLayout() self.win.setLayout(layout) guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png") guip_path = str(guip_path.resolve()) pixmap = QPixmap(guip_path) label = QLabel(self) label.setPixmap(pixmap) pixmap.scaled layout.addWidget(label, 0, 0, 1, 1) class HelpWindow(QDialog): def __init__(self, parent=None): super(HelpWindow, self).__init__(parent) self.setGeometry(100, 50, 700, 1000) self.setWindowTitle("cellpose help") self.win = QWidget(self) layout = QGridLayout() self.win.setLayout(layout) text_file = pathlib.Path(__file__).parent.joinpath("guihelpwindowtext.html") with open(str(text_file.resolve()), "r") as f: text = f.read() label = QLabel(text) label.setFont(QtGui.QFont("Arial", 8)) label.setWordWrap(True) layout.addWidget(label, 0, 0, 1, 1) self.show() class TrainHelpWindow(QDialog): def __init__(self, parent=None): super(TrainHelpWindow, self).__init__(parent) self.setGeometry(100, 50, 700, 300) self.setWindowTitle("training instructions") self.win = QWidget(self) layout = QGridLayout() self.win.setLayout(layout) text_file = pathlib.Path(__file__).parent.joinpath( "guitrainhelpwindowtext.html") with open(str(text_file.resolve()), "r") as f: text = f.read() label = QLabel(text) label.setFont(QtGui.QFont("Arial", 8)) label.setWordWrap(True) layout.addWidget(label, 0, 0, 1, 1) self.show() class ViewBoxNoRightDrag(pg.ViewBox): def __init__(self, parent=None, border=None, lockAspect=False, enableMouse=True, invertY=False, enableMenu=True, name=None, invertX=False): pg.ViewBox.__init__(self, None, border, lockAspect, enableMouse, invertY, enableMenu, name, invertX) self.parent = parent self.axHistoryPointer = -1 def keyPressEvent(self, ev): """ This routine should capture key presses in the current view box. The following events are implemented: +/= : moves forward in the zooming stack (if it exists) - : moves backward in the zooming stack (if it exists) """ ev.accept() if ev.text() == "-": self.scaleBy([1.1, 1.1]) elif ev.text() in ["+", "="]: self.scaleBy([0.9, 0.9]) else: ev.ignore() class ImageDraw(pg.ImageItem): """ **Bases:** :class:`GraphicsObject ` GraphicsObject displaying an image. Optimized for rapid update (ie video display). This item displays either a 2D numpy array (height, width) or a 3D array (height, width, RGBa). This array is optionally scaled (see :func:`setLevels `) and/or colored with a lookup table (see :func:`setLookupTable `) before being displayed. ImageItem is frequently used in conjunction with :class:`HistogramLUTItem ` or :class:`HistogramLUTWidget ` to provide a GUI for controlling the levels and lookup table used to display the image. """ sigImageChanged = QtCore.Signal() def __init__(self, image=None, viewbox=None, parent=None, **kargs): super(ImageDraw, self).__init__() #self.image=None #self.viewbox=viewbox self.levels = np.array([0, 255]) self.lut = None self.autoDownsample = False self.axisOrder = "row-major" self.removable = False self.parent = parent #kernel[1,1] = 1 self.setDrawKernel(kernel_size=self.parent.brush_size) self.parent.current_stroke = [] self.parent.in_stroke = False def mouseClickEvent(self, ev): if (self.parent.masksOn or self.parent.outlinesOn) and not self.parent.removing_region: is_right_click = ev.button() == QtCore.Qt.RightButton if self.parent.loaded \ and (is_right_click or ev.modifiers() & QtCore.Qt.ShiftModifier and not ev.double())\ and not self.parent.deleting_multiple: if not self.parent.in_stroke: ev.accept() self.create_start(ev.pos()) self.parent.stroke_appended = False self.parent.in_stroke = True self.drawAt(ev.pos(), ev) else: ev.accept() self.end_stroke() self.parent.in_stroke = False elif not self.parent.in_stroke: y, x = int(ev.pos().y()), int(ev.pos().x()) if y >= 0 and y < self.parent.Ly and x >= 0 and x < self.parent.Lx: if ev.button() == QtCore.Qt.LeftButton and not ev.double(): idx = self.parent.cellpix[self.parent.currentZ][y, x] if idx > 0: if ev.modifiers() & QtCore.Qt.ControlModifier: # delete mask selected self.parent.remove_cell(idx) elif ev.modifiers() & QtCore.Qt.AltModifier: self.parent.merge_cells(idx) elif self.parent.masksOn and not self.parent.deleting_multiple: self.parent.unselect_cell() self.parent.select_cell(idx) elif self.parent.deleting_multiple: if idx in self.parent.removing_cells_list: self.parent.unselect_cell_multi(idx) self.parent.removing_cells_list.remove(idx) else: self.parent.select_cell_multi(idx) self.parent.removing_cells_list.append(idx) elif self.parent.masksOn and not self.parent.deleting_multiple: self.parent.unselect_cell() def mouseDragEvent(self, ev): ev.ignore() return def hoverEvent(self, ev): #QtWidgets.QApplication.setOverrideCursor(QtCore.Qt.CrossCursor) if self.parent.in_stroke: if self.parent.in_stroke: # continue stroke if not at start self.drawAt(ev.pos()) if self.is_at_start(ev.pos()): #self.parent.in_stroke = False self.end_stroke() else: ev.acceptClicks(QtCore.Qt.RightButton) #ev.acceptClicks(QtCore.Qt.LeftButton) def create_start(self, pos): self.scatter = pg.ScatterPlotItem([pos.x()], [pos.y()], pxMode=False, pen=pg.mkPen(color=(255, 0, 0), width=self.parent.brush_size), size=max(3 * 2, self.parent.brush_size * 1.8 * 2), brush=None) self.parent.p0.addItem(self.scatter) def is_at_start(self, pos): thresh_out = max(6, self.parent.brush_size * 3) thresh_in = max(3, self.parent.brush_size * 1.8) # first check if you ever left the start if len(self.parent.current_stroke) > 3: stroke = np.array(self.parent.current_stroke) dist = (((stroke[1:, 1:] - stroke[:1, 1:][np.newaxis, :, :])**2).sum(axis=-1))**0.5 dist = dist.flatten() #print(dist) has_left = (dist > thresh_out).nonzero()[0] if len(has_left) > 0: first_left = np.sort(has_left)[0] has_returned = (dist[max(4, first_left + 1):] < thresh_in).sum() if has_returned > 0: return True else: return False else: return False def end_stroke(self): self.parent.p0.removeItem(self.scatter) if not self.parent.stroke_appended: self.parent.strokes.append(self.parent.current_stroke) self.parent.stroke_appended = True self.parent.current_stroke = np.array(self.parent.current_stroke) ioutline = self.parent.current_stroke[:, 3] == 1 self.parent.current_point_set.append( list(self.parent.current_stroke[ioutline])) self.parent.current_stroke = [] if self.parent.autosave: self.parent.add_set() if len(self.parent.current_point_set) and len( self.parent.current_point_set[0]) > 0 and self.parent.autosave: self.parent.add_set() self.parent.in_stroke = False def tabletEvent(self, ev): pass #print(ev.device()) #print(ev.pointerType()) #print(ev.pressure()) def drawAt(self, pos, ev=None): mask = self.strokemask stroke = self.parent.current_stroke pos = [int(pos.y()), int(pos.x())] dk = self.drawKernel kc = self.drawKernelCenter sx = [0, dk.shape[0]] sy = [0, dk.shape[1]] tx = [pos[0] - kc[0], pos[0] - kc[0] + dk.shape[0]] ty = [pos[1] - kc[1], pos[1] - kc[1] + dk.shape[1]] kcent = kc.copy() if tx[0] <= 0: sx[0] = 0 sx[1] = kc[0] + 1 tx = sx kcent[0] = 0 if ty[0] <= 0: sy[0] = 0 sy[1] = kc[1] + 1 ty = sy kcent[1] = 0 if tx[1] >= self.parent.Ly - 1: sx[0] = dk.shape[0] - kc[0] - 1 sx[1] = dk.shape[0] tx[0] = self.parent.Ly - kc[0] - 1 tx[1] = self.parent.Ly kcent[0] = tx[1] - tx[0] - 1 if ty[1] >= self.parent.Lx - 1: sy[0] = dk.shape[1] - kc[1] - 1 sy[1] = dk.shape[1] ty[0] = self.parent.Lx - kc[1] - 1 ty[1] = self.parent.Lx kcent[1] = ty[1] - ty[0] - 1 ts = (slice(tx[0], tx[1]), slice(ty[0], ty[1])) ss = (slice(sx[0], sx[1]), slice(sy[0], sy[1])) self.image[ts] = mask[ss] for ky, y in enumerate(np.arange(ty[0], ty[1], 1, int)): for kx, x in enumerate(np.arange(tx[0], tx[1], 1, int)): iscent = np.logical_and(kx == kcent[0], ky == kcent[1]) stroke.append([self.parent.currentZ, x, y, iscent]) self.updateImage() def setDrawKernel(self, kernel_size=3): bs = kernel_size kernel = np.ones((bs, bs), np.uint8) self.drawKernel = kernel self.drawKernelCenter = [ int(np.floor(kernel.shape[0] / 2)), int(np.floor(kernel.shape[1] / 2)) ] onmask = 255 * kernel[:, :, np.newaxis] offmask = np.zeros((bs, bs, 1)) opamask = 100 * kernel[:, :, np.newaxis] self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1) self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1)