Spaces:
Sleeping
Sleeping
| """ | |
| Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. | |
| """ | |
| from qtpy import QtGui, QtCore, QtWidgets | |
| from qtpy.QtGui import QPainter, QPixmap | |
| from qtpy.QtWidgets import QApplication, QRadioButton, QWidget, QDialog, QButtonGroup, QSlider, QStyle, QStyleOptionSlider, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox | |
| import pyqtgraph as pg | |
| from pyqtgraph import functions as fn | |
| from pyqtgraph import Point | |
| import numpy as np | |
| import pathlib, os | |
def stylesheet():
    """Return the Qt style sheet (QSS) used by the dark-themed GUI.

    Each selector is declared exactly once; identical repeated rules have no
    additional effect in QSS, so the tooltip rule appears a single time.

    Returns:
        str: the style sheet text, suitable for ``QWidget.setStyleSheet``.
    """
    return """
QToolTip {
    background-color: black;
    color: white;
    border: black solid 1px;
}
QComboBox {
    color: white;
    background-color: rgb(40,40,40);
}
QComboBox::item:enabled {
    color: white;
    background-color: rgb(40,40,40);
    selection-color: white;
    selection-background-color: rgb(50,100,50);
}
QComboBox::item:!enabled {
    background-color: rgb(40,40,40);
    color: rgb(100,100,100);
}
QScrollArea > QWidget > QWidget {
    background: transparent;
    border: none;
    margin: 0px 0px 0px 0px;
}
QGroupBox {
    border: 1px solid white;
    color: rgb(255,255,255);
    border-radius: 6px;
    margin-top: 8px;
    padding: 0px 0px;
}
QPushButton:pressed {
    Text-align: center;
    background-color: rgb(150,50,150);
    border-color: white;
    color: white;
}
QPushButton:!pressed {
    Text-align: center;
    background-color: rgb(50,50,50);
    border-color: white;
    color: white;
}
QPushButton:disabled {
    Text-align: center;
    background-color: rgb(30,30,30);
    border-color: white;
    color: rgb(80,80,80);
}
"""
class DarkPalette(QtGui.QPalette):
    """Dark colour palette for the application (adapted from pykilosort/kilosort4).

    The constructor calls ``setup``, which applies ``_COLORS`` to every colour
    group and then greys out ``_DISABLED_ROLES`` in the Disabled group.
    """

    # (role, (r, g, b)) applied to all colour groups.
    _COLORS = (
        (QtGui.QPalette.Window, (40, 40, 40)),
        (QtGui.QPalette.WindowText, (255, 255, 255)),
        (QtGui.QPalette.Base, (34, 27, 24)),
        (QtGui.QPalette.AlternateBase, (53, 50, 47)),
        # Tooltip text is white, so the tooltip background must be dark: a
        # white ToolTipBase made tooltips unreadable (white on white) wherever
        # the QToolTip style-sheet rule does not apply.
        (QtGui.QPalette.ToolTipBase, (0, 0, 0)),
        (QtGui.QPalette.ToolTipText, (255, 255, 255)),
        (QtGui.QPalette.Text, (255, 255, 255)),
        (QtGui.QPalette.Button, (53, 50, 47)),
        (QtGui.QPalette.ButtonText, (255, 255, 255)),
        (QtGui.QPalette.BrightText, (255, 0, 0)),
        (QtGui.QPalette.Link, (42, 130, 218)),
        (QtGui.QPalette.Highlight, (42, 130, 218)),
        (QtGui.QPalette.HighlightedText, (0, 0, 0)),
    )

    # Roles drawn in grey when their widget is disabled.
    _DISABLED_ROLES = (
        QtGui.QPalette.Text,
        QtGui.QPalette.ButtonText,
        QtGui.QPalette.WindowText,
    )
    _DISABLED_RGB = (128, 128, 128)

    def __init__(self):
        QtGui.QPalette.__init__(self)
        self.setup()

    def setup(self):
        """Apply the dark colours to this palette."""
        for role, rgb in self._COLORS:
            self.setColor(role, QtGui.QColor(*rgb))
        for role in self._DISABLED_ROLES:
            self.setColor(QtGui.QPalette.Disabled, role,
                          QtGui.QColor(*self._DISABLED_RGB))
def create_channel_choose():
    """Create the two channel-selection combo boxes and their labels.

    Returns:
        tuple[list[QComboBox], list[QLabel]]: ``(ChannelChoose, ChannelLabels)``.
        Index 0 is the channel to segment (items "gray", "red", "green",
        "blue"); index 1 is the optional second channel (items "none", "red",
        "green", "blue"). Each label shares its combo box's tooltip.
    """
    # Implicit string concatenation (rather than a backslash-newline inside
    # the literal) keeps the tooltip text independent of source indentation.
    tooltips = [
        "this is the channel in which the cytoplasm or nuclei exist "
        "that you want to segment",
        "if <em>cytoplasm</em> model is chosen, and you also have a "
        "nuclear channel, then choose the nuclear channel for this option",
    ]
    captions = ["chan to segment:", "chan2 (optional): "]

    ChannelChoose = [QComboBox(), QComboBox()]
    ChannelChoose[0].addItems(["gray", "red", "green", "blue"])
    ChannelChoose[1].addItems(["none", "red", "green", "blue"])

    ChannelLabels = []
    for combo, caption, tip in zip(ChannelChoose, captions, tooltips):
        label = QLabel(caption)
        label.setToolTip(tip)
        combo.setToolTip(tip)
        ChannelLabels.append(label)
    return ChannelChoose, ChannelLabels
class ModelButton(QPushButton):
    """Disabled-by-default button that segments with a fixed model on click.

    ``parent`` must provide ``boldfont`` and ``compute_segmentation``. Any
    model name containing "cyto3" is stored as plain "cyto3".
    """

    def __init__(self, parent, model_name, text):
        super().__init__()
        self.setEnabled(False)
        self.setText(text)
        self.setFont(parent.boldfont)
        self.clicked.connect(lambda: self.press(parent))
        if "cyto3" in model_name:
            model_name = "cyto3"
        self.model_name = model_name

    def press(self, parent):
        """Ask ``parent`` to run segmentation with this button's model."""
        parent.compute_segmentation(model_name=self.model_name)
class DenoiseButton(QPushButton):
    """Disabled-by-default button that applies an image-restoration mode.

    ``text`` is both the button label and the mode stored in ``model_type``:
    "filter" applies the custom normalisation filters, "none" clears any
    restoration, and any other value is passed to
    ``parent.compute_denoise_model`` as the model type.
    """

    def __init__(self, parent, text):
        super().__init__()
        self.setEnabled(False)
        self.model_type = text
        self.setText(text)
        self.setFont(parent.medfont)
        self.clicked.connect(lambda: self.press(parent))

    def press(self, parent):
        """Apply this button's mode to ``parent``, then refresh its restore button.

        For "filter", if sharpening, smoothing and tile normalisation are all
        zero, an error is printed, ``parent.restore`` is reset to None and
        nothing else happens.
        """
        mode = self.model_type
        if mode == "none":
            parent.clear_restore()
        elif mode != "filter":
            parent.compute_denoise_model(model_type=mode)
        else:
            parent.restore = "filter"
            params = parent.get_normalize_params()
            filter_keys = ("sharpen_radius", "smooth_radius", "tile_norm_blocksize")
            if all(params[key] == 0 for key in filter_keys):
                print(
                    "GUI_ERROR: no filtering settings on (use custom filter settings)")
                parent.restore = None
                return
            parent.restore = mode
            parent.compute_saturation()
        parent.set_restore_button()
class TrainWindow(QDialog):
    """Dialog for choosing the settings used to train a model on the images
    and ``_seg.npy`` files of the current folder.

    Widgets are pre-filled from ``parent.training_params`` and
    ``parent.ChannelChoose``. The right-hand columns list up to 10 entries of
    ``parent.train_files`` next to the maximum value of the matching
    ``parent.train_labels`` array. On OK the chosen settings are written to
    ``parent.training_params`` and the dialog closes with result code 1.
    """

    # Number of rows reserved for the training-file list; when there are more
    # files than this, the last row shows "..." instead.
    _MAX_FILE_ROWS = 10

    def __init__(self, parent, model_strings):
        super().__init__(parent)
        self.setGeometry(100, 100, 900, 550)
        self.setWindowTitle("train settings")
        self.win = QWidget(self)
        self.l0 = QGridLayout()
        self.win.setLayout(self.l0)

        # row 0: title
        yoff = 0
        qlabel = QLabel("train model w/ images + _seg.npy in current folder >>")
        qlabel.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
        qlabel.setAlignment(QtCore.Qt.AlignVCenter)
        self.l0.addWidget(qlabel, yoff, 0, 1, 2)

        # row 1: initial model (the given model names plus "scratch")
        yoff += 1
        self.ModelChoose = QComboBox()
        self.ModelChoose.addItems(model_strings)
        self.ModelChoose.addItems(["scratch"])
        self.ModelChoose.setFixedWidth(150)
        self.ModelChoose.setCurrentIndex(parent.training_params["model_index"])
        self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1)
        qlabel = QLabel("initial model: ")
        qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
        self.l0.addWidget(qlabel, yoff, 0, 1, 1)

        # rows 2-3: channels, initialised from the main window's selection
        self.ChannelChoose, self.ChannelLabels = create_channel_choose()
        for combo, label, parent_combo in zip(self.ChannelChoose,
                                              self.ChannelLabels,
                                              parent.ChannelChoose):
            yoff += 1
            combo.setFixedWidth(150)
            combo.setCurrentIndex(parent_combo.currentIndex())
            self.l0.addWidget(label, yoff, 0, 1, 1)
            self.l0.addWidget(combo, yoff, 1, 1, 1)

        # rows 4-7: one QLineEdit per parameter; accept() reads them by index
        labels = ["learning_rate", "weight_decay", "n_epochs", "model_name"]
        self.edits = []
        yoff += 1
        for i, label in enumerate(labels):
            qlabel = QLabel(label)
            qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
            self.l0.addWidget(qlabel, yoff + i, 0, 1, 1)
            edit = QLineEdit()
            edit.setText(str(parent.training_params[label]))
            edit.setFixedWidth(200)
            self.l0.addWidget(edit, yoff + i, 1, 1, 1)
            self.edits.append(edit)
        yoff += len(labels)

        # row 8: optimizer; restored from the previous settings when present
        self.useSGD = QCheckBox("SGD")
        self.useSGD.setToolTip("use SGD, if unchecked uses AdamW (recommended learning_rate then 0.001)")
        self.useSGD.setChecked(bool(parent.training_params.get("SGD", True)))
        self.l0.addWidget(self.useSGD, yoff, 1, 1, 1)

        # This checkbox is created but not added to the layout (rows 9-11
        # stay empty); the attribute is kept so existing references still work.
        self.use_norm = QCheckBox("use restored/filtered image")
        self.use_norm.setChecked(True)

        # rows 12-13: note on removing files
        yoff += 4
        qlabel = QLabel(
            "(to remove files, click cancel then remove \nfrom folder and reopen train window)"
        )
        self.l0.addWidget(qlabel, yoff, 0, 2, 4)

        # row 15: OK / Cancel
        yoff += 3
        QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
        self.buttonBox = QDialogButtonBox(QBtn)
        self.buttonBox.accepted.connect(lambda: self.accept(parent))
        self.buttonBox.rejected.connect(self.reject)
        self.l0.addWidget(self.buttonBox, yoff, 0, 1, 4)

        # columns 4-5: training files and the max label value of each
        qlabel = QLabel("filenames")
        qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
        self.l0.addWidget(qlabel, 0, 4, 1, 1)
        qlabel = QLabel("# of masks")
        qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
        self.l0.addWidget(qlabel, 0, 5, 1, 1)
        n_files = len(parent.train_files)
        last_row = self._MAX_FILE_ROWS - 1
        for i in range(min(n_files, self._MAX_FILE_ROWS)):
            if i == last_row and n_files > self._MAX_FILE_ROWS:
                # more files than rows: the final row is an ellipsis
                fname = nmasks = "..."
            else:
                fname = os.path.split(parent.train_files[i])[-1]
                nmasks = str(parent.train_labels[i].max())
            self.l0.addWidget(QLabel(fname), i + 1, 4, 1, 1)
            qlabel = QLabel(nmasks)
            qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
            self.l0.addWidget(qlabel, i + 1, 5, 1, 1)

    def accept(self, parent=None):
        """Store the chosen settings in ``parent.training_params`` and close.

        Args:
            parent: object whose ``training_params`` is replaced. Defaults to
                this dialog's Qt parent, so a plain ``accept()`` call (the
                ``QDialog`` slot signature) also works.

        If learning_rate, weight_decay or n_epochs cannot be parsed, a
        GUI_ERROR message is printed and the dialog stays open so the value
        can be corrected. (Letting the ValueError escape a Qt slot can abort
        the whole application under PyQt5 >= 5.5.)
        """
        if parent is None:
            parent = self.parent()
        try:
            learning_rate = float(self.edits[0].text())
            weight_decay = float(self.edits[1].text())
            n_epochs = int(self.edits[2].text())
        except ValueError as err:
            print(f"GUI_ERROR: invalid training parameter ({err})")
            return
        parent.training_params = {
            "model_index": self.ModelChoose.currentIndex(),
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "n_epochs": n_epochs,
            "model_name": self.edits[3].text(),
            "SGD": self.useSGD.isChecked(),
            "channels": [self.ChannelChoose[0].currentIndex(),
                         self.ChannelChoose[1].currentIndex()],
        }
        self.done(1)
class ExampleGUI(QDialog):
    """Dialog showing the annotated GUI-layout image
    ``~/.cellpose/cellpose_gui.png``.

    If the image cannot be loaded, an error is printed and the dialog is
    still created (with an empty label).
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setGeometry(100, 100, 1300, 900)
        self.setWindowTitle("GUI layout")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)
        guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png")
        guip_path = str(guip_path.resolve())
        pixmap = QPixmap(guip_path)
        if pixmap.isNull():
            # QPixmap does not raise on a missing/unreadable file
            print(f"GUI_ERROR: could not load GUI layout image {guip_path}")
        label = QLabel(self)
        label.setPixmap(pixmap)
        layout.addWidget(label, 0, 0, 1, 1)
class HelpWindow(QDialog):
    """Dialog showing ``guihelpwindowtext.html`` (located next to this module)
    as word-wrapped rich text. The dialog shows itself on construction.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setGeometry(100, 50, 700, 1000)
        self.setWindowTitle("cellpose help")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)
        text_file = pathlib.Path(__file__).parent.joinpath("guihelpwindowtext.html")
        # explicit encoding: the default text encoding is locale-dependent
        text = text_file.resolve().read_text(encoding="utf-8")
        label = QLabel(text)
        label.setFont(QtGui.QFont("Arial", 8))
        label.setWordWrap(True)
        layout.addWidget(label, 0, 0, 1, 1)
        self.show()
class TrainHelpWindow(QDialog):
    """Dialog showing ``guitrainhelpwindowtext.html`` (located next to this
    module) as word-wrapped rich text. The dialog shows itself on construction.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setGeometry(100, 50, 700, 300)
        self.setWindowTitle("training instructions")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)
        text_file = pathlib.Path(__file__).parent.joinpath(
            "guitrainhelpwindowtext.html")
        # explicit encoding: the default text encoding is locale-dependent
        text = text_file.resolve().read_text(encoding="utf-8")
        label = QLabel(text)
        label.setFont(QtGui.QFont("Arial", 8))
        label.setWordWrap(True)
        layout.addWidget(label, 0, 0, 1, 1)
        self.show()
class ViewBoxNoRightDrag(pg.ViewBox):
    """``pg.ViewBox`` that keeps a reference to the owning GUI and adds
    keyboard zoom shortcuts.

    ``parent`` is stored as ``self.parent`` and is NOT passed on to
    ``pg.ViewBox`` (which receives ``None`` as its parent); every other
    argument is forwarded unchanged.
    """

    # key text -> factor passed to scaleBy (>1 enlarges the visible range)
    _ZOOM_KEYS = {"-": 1.1, "+": 0.9, "=": 0.9}

    def __init__(self, parent=None, border=None, lockAspect=False, enableMouse=True,
                 invertY=False, enableMenu=True, name=None, invertX=False):
        super().__init__(None, border, lockAspect, enableMouse, invertY,
                         enableMenu, name, invertX)
        self.parent = parent
        self.axHistoryPointer = -1

    def keyPressEvent(self, ev):
        """Zoom the view from the keyboard.

        "-" scales the view range by 1.1 (zoom out); "+" or "=" scales it by
        0.9 (zoom in). Any other key is ignored so it can propagate.
        """
        factor = self._ZOOM_KEYS.get(ev.text())
        if factor is None:
            ev.ignore()
            return
        ev.accept()
        self.scaleBy([factor, factor])
class ImageDraw(pg.ImageItem):
    """Image item used to draw mask outlines with the mouse.

    While masks or outlines are shown, a right-click (or a shift-click that is
    not a double-click) starts a stroke and a second such click ends it. While
    a stroke is active, hovering paints it, and the stroke is closed
    automatically once the cursor comes back to its start (``is_at_start``).
    Left-clicks on an existing mask select it, delete it (Ctrl) or merge it
    (Alt); in multi-delete mode they toggle it in ``removing_cells_list``.

    Most state lives on ``self.parent`` (the GUI window), which this item reads
    and mutates directly: ``current_stroke``, ``in_stroke``, ``stroke_appended``,
    ``strokes``, ``current_point_set``, ``brush_size``, ``Ly``/``Lx``,
    ``cellpix``, ``currentZ`` and more.

    ``drawAt`` writes a 4-channel brush mask into ``self.image``, so the image
    is presumably an (Ly, Lx, 4) RGBA overlay -- TODO confirm against caller.
    """
    sigImageChanged = QtCore.Signal()
    def __init__(self, image=None, viewbox=None, parent=None, **kargs):
        """Create the drawing item.

        ``parent`` must be the GUI window; ``brush_size`` is read from it.
        NOTE(review): ``image``, ``viewbox`` and ``**kargs`` are accepted but
        unused.
        """
        super(ImageDraw, self).__init__()
        #self.image=None
        #self.viewbox=viewbox
        self.levels = np.array([0, 255])
        self.lut = None
        self.autoDownsample = False
        self.axisOrder = "row-major"
        self.removable = False
        # NOTE(review): this attribute shadows the inherited QObject.parent() method.
        self.parent = parent
        #kernel[1,1] = 1
        self.setDrawKernel(kernel_size=self.parent.brush_size)
        # reset any in-progress stroke on the GUI
        self.parent.current_stroke = []
        self.parent.in_stroke = False
    def mouseClickEvent(self, ev):
        """Start/end strokes, or select/delete/merge the mask under the cursor."""
        if (self.parent.masksOn or
                self.parent.outlinesOn) and not self.parent.removing_region:
            is_right_click = ev.button() == QtCore.Qt.RightButton
            # Precedence: `a or b and c` is `a or (b and c)`, i.e. a right-click,
            # or a shift-click that is not a double-click, toggles drawing.
            if self.parent.loaded \
                    and (is_right_click or ev.modifiers() & QtCore.Qt.ShiftModifier and not ev.double())\
                    and not self.parent.deleting_multiple:
                if not self.parent.in_stroke:
                    # begin a new stroke: mark its start and paint the first dab
                    ev.accept()
                    self.create_start(ev.pos())
                    self.parent.stroke_appended = False
                    self.parent.in_stroke = True
                    self.drawAt(ev.pos(), ev)
                else:
                    # second click finishes the current stroke
                    ev.accept()
                    self.end_stroke()
                    self.parent.in_stroke = False
            elif not self.parent.in_stroke:
                y, x = int(ev.pos().y()), int(ev.pos().x())
                if y >= 0 and y < self.parent.Ly and x >= 0 and x < self.parent.Lx:
                    if ev.button() == QtCore.Qt.LeftButton and not ev.double():
                        # mask label under the cursor on the current z-plane (0 = none)
                        idx = self.parent.cellpix[self.parent.currentZ][y, x]
                        if idx > 0:
                            if ev.modifiers() & QtCore.Qt.ControlModifier:
                                # delete mask selected
                                self.parent.remove_cell(idx)
                            elif ev.modifiers() & QtCore.Qt.AltModifier:
                                self.parent.merge_cells(idx)
                            elif self.parent.masksOn and not self.parent.deleting_multiple:
                                self.parent.unselect_cell()
                                self.parent.select_cell(idx)
                            elif self.parent.deleting_multiple:
                                # toggle membership in the multi-delete selection
                                if idx in self.parent.removing_cells_list:
                                    self.parent.unselect_cell_multi(idx)
                                    self.parent.removing_cells_list.remove(idx)
                                else:
                                    self.parent.select_cell_multi(idx)
                                    self.parent.removing_cells_list.append(idx)
                        elif self.parent.masksOn and not self.parent.deleting_multiple:
                            # clicked on background: clear the selection
                            self.parent.unselect_cell()
    def mouseDragEvent(self, ev):
        """Ignore drags so they propagate (e.g. to the view box)."""
        ev.ignore()
        return
    def hoverEvent(self, ev):
        """Extend the active stroke under the cursor and close it at its start;
        otherwise register interest in right-clicks."""
        #QtWidgets.QApplication.setOverrideCursor(QtCore.Qt.CrossCursor)
        if self.parent.in_stroke:
            # NOTE(review): this inner check repeats the outer one and is always true here.
            if self.parent.in_stroke:
                # continue stroke if not at start
                self.drawAt(ev.pos())
                if self.is_at_start(ev.pos()):
                    #self.parent.in_stroke = False
                    self.end_stroke()
        else:
            ev.acceptClicks(QtCore.Qt.RightButton)
            #ev.acceptClicks(QtCore.Qt.LeftButton)
    def create_start(self, pos):
        """Add a red ring marker (``self.scatter``) at the stroke start to ``parent.p0``.

        ``pxMode=False`` sizes the marker in data units:
        max(6, 3.6 * brush_size).
        """
        self.scatter = pg.ScatterPlotItem([pos.x()], [pos.y()], pxMode=False,
                                          pen=pg.mkPen(color=(255, 0, 0),
                                                       width=self.parent.brush_size),
                                          size=max(3 * 2,
                                                   self.parent.brush_size * 1.8 * 2),
                                          brush=None)
        self.parent.p0.addItem(self.scatter)
    def is_at_start(self, pos):
        """Return whether the current stroke has looped back to where it began.

        The stroke must first move farther than ``thresh_out`` from its first
        recorded entry, and a later entry (skipping the first 4 distances)
        must come back within ``thresh_in``. ``pos`` is not used; distances are
        computed from ``parent.current_stroke`` alone.

        Returns True/False, or None (falsy) when the stroke has <= 3 entries.

        NOTE(review): entries are [z, row, col, is_centre] (see ``drawAt``), so
        ``stroke[:, 1:]`` also includes the is_centre flag in the distance.
        The first entry is the first pixel of the first brush footprint, not
        necessarily the clicked pixel.
        """
        thresh_out = max(6, self.parent.brush_size * 3)
        thresh_in = max(3, self.parent.brush_size * 1.8)
        # first check if you ever left the start
        if len(self.parent.current_stroke) > 3:
            stroke = np.array(self.parent.current_stroke)
            # distance of every later entry from the first entry
            dist = (((stroke[1:, 1:] -
                      stroke[:1, 1:][np.newaxis, :, :])**2).sum(axis=-1))**0.5
            dist = dist.flatten()
            #print(dist)
            has_left = (dist > thresh_out).nonzero()[0]
            if len(has_left) > 0:
                first_left = np.sort(has_left)[0]
                has_returned = (dist[max(4, first_left + 1):] < thresh_in).sum()
                if has_returned > 0:
                    return True
                else:
                    return False
            else:
                return False
    def end_stroke(self):
        """Finish the current stroke.

        Removes the start marker. If the stroke has not been stored yet, it is
        appended to ``parent.strokes``, its centre pixels (column 3 == 1) are
        appended as one entry of ``parent.current_point_set``, and
        ``parent.current_stroke`` is reset.

        NOTE(review): with ``parent.autosave`` on, ``parent.add_set()`` can be
        called twice -- once inside the branch and once by the final check.
        Whether the second call is a no-op depends on ``add_set``, which is not
        visible here.
        """
        self.parent.p0.removeItem(self.scatter)
        if not self.parent.stroke_appended:
            # the list object is stored before current_stroke is rebound to an array
            self.parent.strokes.append(self.parent.current_stroke)
            self.parent.stroke_appended = True
            self.parent.current_stroke = np.array(self.parent.current_stroke)
            ioutline = self.parent.current_stroke[:, 3] == 1
            self.parent.current_point_set.append(
                list(self.parent.current_stroke[ioutline]))
            self.parent.current_stroke = []
            if self.parent.autosave:
                self.parent.add_set()
        if len(self.parent.current_point_set) and len(
                self.parent.current_point_set[0]) > 0 and self.parent.autosave:
            self.parent.add_set()
        self.parent.in_stroke = False
    def tabletEvent(self, ev):
        """Tablet input is currently ignored."""
        pass
        #print(ev.device())
        #print(ev.pointerType())
        #print(ev.pressure())
    def drawAt(self, pos, ev=None):
        """Paint one brush dab at ``pos`` and record it in ``parent.current_stroke``.

        ``self.strokemask`` is copied into ``self.image`` around ``pos``. Every
        covered pixel is appended as [currentZ, row, col, is_centre], where
        is_centre marks the kernel centre. ``ev`` is unused.

        Naming: ``pos`` becomes [row, col]; ``tx``/``sx`` are row ranges and
        ``ty``/``sy`` column ranges, so in the final loop ``x`` is a row and
        ``y`` a column index.

        NOTE(review): edge clipping is approximate.
        - Near the top/left edge only the first kc+1 kernel rows/cols are drawn,
          starting at 0, whatever the exact overlap.
        - ``tx = sx`` / ``ty = sy`` alias the lists, so a later bottom/right
          clip mutates both.
        """
        mask = self.strokemask
        stroke = self.parent.current_stroke
        pos = [int(pos.y()), int(pos.x())]
        dk = self.drawKernel
        kc = self.drawKernelCenter
        # source (kernel) ranges and target (image) ranges
        sx = [0, dk.shape[0]]
        sy = [0, dk.shape[1]]
        tx = [pos[0] - kc[0], pos[0] - kc[0] + dk.shape[0]]
        ty = [pos[1] - kc[1], pos[1] - kc[1] + dk.shape[1]]
        # position of the kernel centre inside the target range
        kcent = kc.copy()
        if tx[0] <= 0:
            sx[0] = 0
            sx[1] = kc[0] + 1
            tx = sx
            kcent[0] = 0
        if ty[0] <= 0:
            sy[0] = 0
            sy[1] = kc[1] + 1
            ty = sy
            kcent[1] = 0
        if tx[1] >= self.parent.Ly - 1:
            sx[0] = dk.shape[0] - kc[0] - 1
            sx[1] = dk.shape[0]
            tx[0] = self.parent.Ly - kc[0] - 1
            tx[1] = self.parent.Ly
            kcent[0] = tx[1] - tx[0] - 1
        if ty[1] >= self.parent.Lx - 1:
            sy[0] = dk.shape[1] - kc[1] - 1
            sy[1] = dk.shape[1]
            ty[0] = self.parent.Lx - kc[1] - 1
            ty[1] = self.parent.Lx
            kcent[1] = ty[1] - ty[0] - 1
        ts = (slice(tx[0], tx[1]), slice(ty[0], ty[1]))
        ss = (slice(sx[0], sx[1]), slice(sy[0], sy[1]))
        self.image[ts] = mask[ss]
        for ky, y in enumerate(np.arange(ty[0], ty[1], 1, int)):
            for kx, x in enumerate(np.arange(tx[0], tx[1], 1, int)):
                iscent = np.logical_and(kx == kcent[0], ky == kcent[1])
                stroke.append([self.parent.currentZ, x, y, iscent])
        self.updateImage()
    def setDrawKernel(self, kernel_size=3):
        """Build the square brush of side ``kernel_size`` and its RGBA masks.

        Sets ``drawKernel`` (ones, uint8) and ``drawKernelCenter`` (floor of
        half the size). Also sets ``redmask`` = (255, 0, 0, 255) and
        ``strokemask`` = (255, 0, 255, 100) per pixel. The masks come out as
        float64, because ``offmask`` uses ``np.zeros``' default dtype.
        """
        bs = kernel_size
        kernel = np.ones((bs, bs), np.uint8)
        self.drawKernel = kernel
        self.drawKernelCenter = [
            int(np.floor(kernel.shape[0] / 2)),
            int(np.floor(kernel.shape[1] / 2))
        ]
        onmask = 255 * kernel[:, :, np.newaxis]
        offmask = np.zeros((bs, bs, 1))
        opamask = 100 * kernel[:, :, np.newaxis]
        self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1)
        self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1)