Spaces:
Running
Running
Commit
·
b6bb35e
1
Parent(s):
26e8a2f
initial commit
Browse files- LICENSE +201 -0
- README.md +5 -5
- app.py +591 -0
- checkpoints/ScribblePrompt_unet_v1_nf192_res128.pt +3 -0
- network.py +123 -0
- predictor.py +242 -0
- requirements.txt +4 -0
- test_examples/COBRE.jpg +0 -0
- test_examples/SCR.jpg +0 -0
- test_examples/TotalSegmentator.jpg +0 -0
- test_examples/TotalSegmentator_2.jpg +0 -0
- val_od_examples/ACDC.jpg +0 -0
- val_od_examples/BTCV.jpg +0 -0
- val_od_examples/BUID.jpg +0 -0
- val_od_examples/DRIVE.jpg +0 -0
- val_od_examples/HipXRay.jpg +0 -0
- val_od_examples/PanDental.jpg +0 -0
- val_od_examples/SCD.jpg +0 -0
- val_od_examples/SpineWeb.jpg +0 -0
- val_od_examples/WBC.jpg +0 -0
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
-
pinned:
|
| 10 |
license: apache-2.0
|
| 11 |
---
|
| 12 |
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Scribbleprompt
|
| 3 |
+
emoji: 🩻
|
| 4 |
colorFrom: blue
|
| 5 |
+
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 3.41.0
|
| 8 |
app_file: app.py
|
| 9 |
+
pinned: true
|
| 10 |
license: apache-2.0
|
| 11 |
---
|
| 12 |
|
app.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import os
|
| 6 |
+
import cv2
|
| 7 |
+
import pathlib
|
| 8 |
+
|
| 9 |
+
# Torch device string used for all tensors created in this module:
# CUDA when a GPU is visible to torch, CPU otherwise.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
|
| 10 |
+
|
| 11 |
+
from predictor import Predictor
|
| 12 |
+
|
| 13 |
+
# NOTE(review): presumably the side length (in pixels) used for images in the
# demo -- confirm against the code that reads RES.
RES = 256

# Example images shipped with the app, listed in sorted filename order.
test_example_dir = pathlib.Path("./test_examples")
test_examples = [
    str(test_example_dir / fname)
    for fname in sorted(os.listdir(test_example_dir))
]

val_example_dir = pathlib.Path("./val_od_examples")
val_examples = [
    str(val_example_dir / fname)
    for fname in sorted(os.listdir(val_example_dir))
]

default_example = test_example_dir / "TotalSegmentator_2.jpg"

# Checkpoint directory, the model selected by default, and the mapping from
# model display name to checkpoint filename inside `exp_dir`.
exp_dir = pathlib.Path('./checkpoints')
default_model = 'ScribblePrompt-Unet'

model_dict = {
    'ScribblePrompt-Unet': 'ScribblePrompt_unet_v1_nf192_res128.pt'
}
|
| 28 |
+
|
| 29 |
+
# -----------------------------------------------------------------------------
|
| 30 |
+
# Model initialization functions
|
| 31 |
+
# -----------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
def load_model(exp_key: str = default_model):
    """
    Build a Predictor for the checkpoint registered under `exp_key`

    The checkpoint filename is looked up in `model_dict` and resolved
    relative to `exp_dir`.

    Returns a 2-tuple `(predictor, None)`.
    """
    checkpoint_name = model_dict.get(exp_key)
    return Predictor(exp_dir / checkpoint_name), None
|
| 37 |
+
|
| 38 |
+
# -----------------------------------------------------------------------------
|
| 39 |
+
# Visualization functions
|
| 40 |
+
# -----------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
def _get_overlay(img, lay, const_color="l_blue"):
    """
    Helper function for preparing overlay

    Returns a NEW array holding `img` with every pixel where `lay` is
    non-zero painted in a constant color. A 2D `img` is first replicated to
    3 channels. `img` itself is never modified, so the result can be blended
    with the original image afterwards.

    Args:
        img: 2D (H x W) or 3D (H x W x C) numpy array.
        lay: 2D numpy array; its non-zero entries select the pixels to paint.
        const_color: one of "blue", "green", "red", "l_blue", "orange".

    Raises:
        AssertionError: if `lay` is not 2D or `img` is not 2D/3D.
        NotImplementedError: if `const_color` is not a supported color name.
    """
    assert lay.ndim==2, "Overlay must be 2D, got shape: " + str(lay.shape)

    if img.ndim == 2:
        # np.repeat already allocates a fresh array
        painted = np.repeat(img[...,None], 3, axis=-1)
    else:
        assert img.ndim==3, "Image must be 3D, got shape: " + str(img.shape)
        # Copy so the caller's array is left untouched; painting in place made
        # the returned overlay alias its input, which turned any later
        # blend of "overlay vs. original" into a no-op.
        painted = img.copy()

    # RGB triplets for the supported color names
    palette = {
        "blue": (0, 0, 255),
        "green": (0, 255, 0),
        "red": (255, 0, 0),
        "l_blue": (31, 119, 180),
        "orange": (255, 127, 14),
    }
    rgb = palette.get(const_color) if isinstance(const_color, str) else None
    if rgb is None:
        raise NotImplementedError(f"Unsupported overlay color: {const_color!r}")

    rows, cols = np.nonzero(lay)
    for channel in range(painted.shape[-1]):
        painted[rows, cols, channel] = rgb[channel]

    return painted
|
| 71 |
+
|
| 72 |
+
def image_overlay(img, mask=None, scribbles=None, contour=False, alpha=0.5):
    """
    Render a 2D image as a 3-channel image with the mask and scribbles on top

    Args:
        img: 2D numpy array; replicated to 3 channels for the output.
        mask: optional 2D array. With `contour=True` the outline of
            `mask > 0.5` is drawn in green; otherwise the light-blue mask
            overlay is mixed in with per-pixel weight `0.5 * mask`.
        scribbles: optional array indexed as `scribbles[0]` (drawn in green)
            and `scribbles[1]` (drawn in red).
        contour: draw the mask outline instead of a filled overlay.
        alpha: weight given to each scribble layer in `cv2.addWeighted`.

    Returns the composed 3-channel image.
    """
    assert img.ndim == 2, "Image must be 2D, got shape: " + str(img.shape)
    canvas = np.repeat(img[...,None], 3, axis=-1)

    if mask is not None:
        assert mask.ndim == 2, "Mask must be 2D, got shape: " + str(mask.shape)

        if not contour:
            tinted = _get_overlay(img, mask)
            weight = 0.5*np.repeat(mask[...,None], 3, axis=-1)
            canvas = cv2.convertScaleAbs(tinted * weight + canvas * (1 - weight))
        else:
            binary_mask = (mask[...,None]>0.5).astype(np.uint8)
            found = cv2.findContours(binary_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(canvas, found[0], -1, (0, 255, 0), 1)

    if scribbles is not None:
        # Positive scribbles first, then negative; each blend writes back
        # into `canvas` (passed as the destination array).
        for channel, color in ((0, "green"), (1, "red")):
            layer = _get_overlay(canvas, scribbles[channel,...], const_color=color)
            cv2.addWeighted(layer, alpha, canvas, 1 - alpha, 0, canvas)

    return canvas
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def viz_pred_mask(img, mask=None, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=True):
    """
    Visualize image with clicks, scribbles, predicted mask overlaid

    Args:
        img: numpy array passed on to `image_overlay`.
        mask: optional numpy array or torch tensor (tensors are moved to CPU
            and converted to numpy). Thresholded at 0.5 when `binary` is true.
        point_coords: optional iterable of (col, row) click positions.
        point_labels: labels aligned with `point_coords`; label 1 is drawn
            in green, anything else in red.
        bbox_coords: optional list of corner points. Consecutive pairs are
            drawn as rectangles; a trailing unpaired point is drawn as a dot.
        seperate_scribble_masks: optional scribble masks forwarded to
            `image_overlay`.
        binary: whether to binarize the mask before drawing.

    Returns the annotated image produced by `image_overlay`.
    """
    assert isinstance(img, np.ndarray), "Image must be numpy array, got type: " + str(type(img))

    if isinstance(mask, torch.Tensor):
        mask = mask.cpu().numpy()

    if mask is not None and binary:
        mask = 1*(mask > 0.5)

    out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)

    if point_coords is not None:
        for idx, (col, row) in enumerate(point_coords):
            dot_color = (0,255,0) if point_labels[idx] == 1 else (255,0,0)
            cv2.circle(out, (col, row), 2, dot_color, -1)

    if bbox_coords is not None:
        # Step through complete (corner, opposite corner) pairs
        for start in range(0, len(bbox_coords) - 1, 2):
            cv2.rectangle(out, bbox_coords[start], bbox_coords[start + 1], (255,165,0), 1)
        # An odd count means the last box only has its first corner so far
        if len(bbox_coords) % 2 == 1:
            cv2.circle(out, tuple(bbox_coords[-1]), 2, (255,165,0), -1)

    return out
|
| 128 |
+
|
| 129 |
+
# -----------------------------------------------------------------------------
|
| 130 |
+
# Collect scribbles
|
| 131 |
+
# -----------------------------------------------------------------------------
|
| 132 |
+
|
| 133 |
+
def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img, label: int):
    """
    Record scribbles

    Compares the current drawing (channel 0 of `scribble_img['mask']`, scaled
    by 1/255) with `last_scribble_mask`. If they differ, pixels that are no
    longer drawn are cleared from both stored scribble masks and newly drawn
    pixels are added to the mask at index `label`.

    Args:
        seperate_scribble_masks: numpy array indexed by label on its first
            axis (two entries are stacked here).
        last_scribble_mask: the drawing seen on the previous call.
        scribble_img: mapping with a 'mask' entry, or None when there is no
            drawing. NOTE(review): presumably the Gradio sketch payload --
            confirm against the UI wiring.
        label: index of the scribble mask that receives new strokes.

    Returns `(seperate_scribble_masks, last_scribble_mask)`; the inputs are
    returned unchanged when there is no drawing or nothing changed.
    """
    assert isinstance(seperate_scribble_masks, np.ndarray), "seperate_scribble_masks must be numpy array, got type: " + str(type(seperate_scribble_masks))

    if scribble_img is None:
        return seperate_scribble_masks, last_scribble_mask

    scribble_mask = scribble_img.get('mask')[...,0]/255

    changed = (scribble_mask != last_scribble_mask)
    if not isinstance(changed, bool):
        changed = changed.any()
    if not changed:
        return seperate_scribble_masks, last_scribble_mask

    # In case any scribbles were removed: keep only what is still drawn
    still_drawn = scribble_mask > 0
    updated_masks = np.stack([still_drawn, still_drawn], axis=0)*seperate_scribble_masks
    previous_kept = last_scribble_mask*still_drawn

    # Strokes present now that were not part of the previous drawing
    added = scribble_mask * ((scribble_mask - previous_kept) > 0)
    updated_masks[label,...] = np.clip(updated_masks[label,...] + added, a_min=0, a_max=1)

    return updated_masks, scribble_mask
|
| 161 |
+
|
| 162 |
+
def get_predictions(predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, low_res_mask, img_features, multimask_mode):
    """
    Make predictions

    Packs the image and all user prompts into tensors on `device` and calls
    `predictor.predict`.

    Args:
        predictor: object exposing `predict(prompts, img_features, multimask_mode=...)`.
        input_img: numpy image; scaled by 1/255 and given two leading axes.
        click_coords: list of click positions (may be empty).
        click_labels: list of click labels (may be empty).
        bbox_coords: list of box corner points; a box prompt is only built
            when it holds exactly two points.
        seperate_scribble_masks: numpy scribble masks, or None.
        low_res_mask: previous low-resolution mask tensor, or None.
        img_features: cached image features handed back to the predictor.
        multimask_mode: forwarded to `predictor.predict`.

    Returns `(mask, img_features, low_res_mask)` as returned by the predictor.
    """
    box = None
    if len(bbox_coords) == 1:
        # gr.Error is an exception class: instantiating it without `raise`
        # shows nothing. gr.Warning displays the message and lets the
        # prediction continue without a box prompt.
        gr.Warning("Please click a second time to define the bounding box")
    elif len(bbox_coords) == 2:
        box = torch.Tensor(bbox_coords).flatten()[None,None,...].int().to(device) # B x n x 4

    scribble = None
    if seperate_scribble_masks is not None:
        scribble = torch.from_numpy(seperate_scribble_masks)[None,...].to(device)

    prompts = dict(
        img=torch.from_numpy(input_img)[None,None,...].to(device)/255,
        point_coords=torch.Tensor([click_coords]).int().to(device) if len(click_coords)>0 else None,
        point_labels=torch.Tensor([click_labels]).int().to(device) if len(click_labels)>0 else None,
        scribble=scribble,
        mask_input=low_res_mask.to(device) if low_res_mask is not None else None,
        box=box,
    )

    mask, img_features, low_res_mask = predictor.predict(prompts, img_features, multimask_mode=multimask_mode)

    return mask, img_features, low_res_mask
|
| 190 |
+
|
| 191 |
+
def refresh_predictions(predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
                        scribble_img, seperate_scribble_masks, last_scribble_mask,
                        best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode):
    """
    Record new scribbles, run the predictor and rebuild every visualization

    Returns, in order: the click-input view (with scribbles), the
    scribble-input view (without scribbles), a two-item list of output images
    (mask over image, and the mask alone as 3 channels), then the updated
    `best_mask`, `low_res_mask`, `img_features`, `seperate_scribble_masks`
    and `last_scribble_mask`.
    """
    # Record any new scribbles under the current brush color
    brush_index = 0 if brush_label == "Positive (green)" else 1
    seperate_scribble_masks, last_scribble_mask = get_scribbles(
        seperate_scribble_masks, last_scribble_mask, scribble_img, label=brush_index
    )

    # Make prediction
    best_mask, img_features, low_res_mask = get_predictions(
        predictor, input_img, click_coords, click_labels, bbox_coords,
        seperate_scribble_masks, low_res_mask, img_features, multimask_mode
    )

    # Update input visualizations
    mask_to_viz = best_mask.numpy()
    click_input_viz = viz_pred_mask(
        input_img, mask_to_viz, click_coords, click_labels, bbox_coords,
        seperate_scribble_masks, binary_checkbox
    )
    scribble_input_viz = viz_pred_mask(
        input_img, mask_to_viz, click_coords, click_labels, bbox_coords,
        None, binary_checkbox
    )

    # Output views: mask over the bare image, and the mask by itself
    overlay_viz = viz_pred_mask(
        input_img, mask_to_viz, point_coords=None, point_labels=None,
        bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox
    )
    mask_rgb = mask_to_viz[...,None].repeat(axis=2, repeats=3)
    if binary_checkbox:
        mask_rgb = 255*(mask_rgb>0.5)
    out_viz = [overlay_viz, mask_rgb]

    return click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
                      click_coords, click_labels, bbox_coords,
                      seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                      output_img, binary_checkbox, multimask_mode, autopredict_checkbox, evt: gr.SelectData):
    """
    Record user click and update the prediction

    `evt.index` is appended to `bbox_coords` when `bbox_label` is truthy,
    otherwise to `click_coords` (with label 1 for the positive brush and 0
    for the negative brush). The lists are modified in place.

    A new prediction is made only when the number of box points is even
    (no box is waiting for its second corner) and `autopredict_checkbox` is
    set; otherwise only the input views are redrawn and `output_img` is
    passed through unchanged.

    Returns, in order: click-input view, scribble-input view, output
    image(s), `best_mask`, `low_res_mask`, `img_features`, `click_coords`,
    `click_labels`, `bbox_coords`, `seperate_scribble_masks`,
    `last_scribble_mask`.

    Raises:
        TypeError: if `bbox_label` is falsy and `brush_label` is not one of
            the two known brush names.
    """
    # Record click coordinates
    if bbox_label:
        bbox_coords.append(evt.index)
    elif brush_label in ['Positive (green)', 'Negative (red)']:
        click_coords.append(evt.index)
        click_labels.append(1 if brush_label=='Positive (green)' else 0)
    else:
        # f-string so the offending label actually appears in the message
        raise TypeError(f"Invalid brush label: {brush_label}")

    # Only make new prediction if not waiting for additional bounding box click
    if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
        (click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask,
         img_features, seperate_scribble_masks, last_scribble_mask) = refresh_predictions(
            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
            scribble_img, seperate_scribble_masks, last_scribble_mask,
            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
        )
    else:
        click_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
        )
        scribble_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
        )
        # Don't update output image if waiting for additional bounding box click
        output_viz = output_img

    return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def undo_click(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
               seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
               output_img, binary_checkbox, multimask_mode, autopredict_checkbox):
    """
    Remove the last click (or bounding box corner) and then update the prediction.

    When `bbox_label` is truthy the last entry of `bbox_coords` is dropped;
    otherwise the last entry of `click_coords` / `click_labels` is dropped.
    Undoing with nothing recorded is a no-op.

    Returns:
        An 11-tuple matching the Gradio outputs wired to this handler:
        (click canvas, scribble canvas, output gallery, best_mask, low_res_mask,
        img_features, click_coords, click_labels, bbox_coords,
        seperate_scribble_masks, last_scribble_mask).

    Raises:
        TypeError: if no bounding box is being edited and `brush_label` is not
            one of the two known brush labels.
    """
    if bbox_label:
        if bbox_coords:
            bbox_coords.pop()
    elif brush_label in ('Positive (green)', 'Negative (red)'):
        if click_coords:
            # Coordinates and labels are parallel lists; keep them in sync
            click_coords.pop()
            click_labels.pop()
    else:
        # f-string so the offending value actually appears in the message
        raise TypeError(f"Invalid brush label: {brush_label}")

    # Only make new prediction if not waiting for additional bounding box click.
    # Uses the same "even number of corners" rule as get_select_coords.
    if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:

        click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
            scribble_img, seperate_scribble_masks, last_scribble_mask,
            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
        )
        return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask

    # No new prediction: only redraw the two input canvases with the current prompts
    click_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
    )
    scribble_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
    )

    # Don't update output image if waiting for additional bounding box click
    return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# --------------------------------------------------
|
| 296 |
+
|
| 297 |
+
with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:
|
| 298 |
+
|
| 299 |
+
# State variables
|
| 300 |
+
seperate_scribble_masks = gr.State(np.zeros((2,RES,RES), dtype=np.float32))
|
| 301 |
+
last_scribble_mask = gr.State(np.zeros((RES,RES), dtype=np.float32))
|
| 302 |
+
|
| 303 |
+
click_coords = gr.State([])
|
| 304 |
+
click_labels = gr.State([])
|
| 305 |
+
bbox_coords = gr.State([])
|
| 306 |
+
|
| 307 |
+
# Load default model
|
| 308 |
+
predictor = gr.State(load_model()[0])
|
| 309 |
+
img_features = gr.State(None) # For SAM models
|
| 310 |
+
best_mask = gr.State(None)
|
| 311 |
+
low_res_mask = gr.State(None)
|
| 312 |
+
|
| 313 |
+
gr.HTML("""\
|
| 314 |
+
<h1 style="text-align: center; font-size: 28pt;">ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Medical Image</h1>
|
| 315 |
+
<p style="text-align: center; font-size: large;"><a href="https://scribbleprompt.csail.mit.edu">ScribblePrompt</a> is an interactive segmentation tool designed to help users segment <b>new</b> structures in medical images using scribbles, clicks <b>and</b> bounding boxes.
|
| 316 |
+
</p>
|
| 317 |
+
|
| 318 |
+
""")
|
| 319 |
+
|
| 320 |
+
with gr.Accordion("Open for instructions!", open=False):
|
| 321 |
+
gr.Markdown(
|
| 322 |
+
"""
|
| 323 |
+
* Select an input image from the examples below or upload your own image through the <b>'Input Image'</b> tab.
|
| 324 |
+
* Use the <b>'Scribbles'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> scribbles.
|
| 325 |
+
- Use the buttons in the top right hand corner of the canvas to undo or adjust the brush size
|
| 326 |
+
- Note: the app cannot detect new scribbles drawn on top of previous scribbles in a different color. Please undo/erase the scribble before drawing on the same pixel in a different color.
|
| 327 |
+
* Use the <b>'Clicks/Boxes'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> clicks and <span style='color:orange'>bounding boxes</span> by placing two clicks.
|
| 328 |
+
* The <b>'Output'</b> tab will show the model's prediction based on your current inputs and the previous prediction.
|
| 329 |
+
* The <b>'Clear Input Mask'</b> button will clear the latest prediction (which is used as an input to the model).
|
| 330 |
+
* The <b>'Clear All Inputs'</b> button will clear all inputs (including scribbles, clicks, bounding boxes, and the last prediction).
|
| 331 |
+
"""
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
# Interface ------------------------------------
|
| 336 |
+
|
| 337 |
+
with gr.Row():
|
| 338 |
+
model_dropdown = gr.Dropdown(
|
| 339 |
+
label="Model",
|
| 340 |
+
choices = list(model_dict.keys()),
|
| 341 |
+
value=default_model,
|
| 342 |
+
multiselect=False,
|
| 343 |
+
interactive=False,
|
| 344 |
+
visible=False
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
with gr.Row():
|
| 348 |
+
with gr.Column(scale=1):
|
| 349 |
+
brush_label = gr.Radio(["Positive (green)", "Negative (red)"],
|
| 350 |
+
value="Positive (green)", label="Scribble/Click Label")
|
| 351 |
+
bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 clicks)")
|
| 352 |
+
with gr.Column(scale=1):
|
| 353 |
+
binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
|
| 354 |
+
autopredict_checkbox = gr.Checkbox(value=True, label="Auto-update prediction on clicks")
|
| 355 |
+
gr.Markdown("<span style='color:orange'>Troubleshooting:</span> If the image does not fully load in the Scribbles tab, click 'Clear Scribbles' or 'Clear All Inputs' to reload (it may take multiple tries). If you encounter an <span style='color:orange'>error</span>, try clicking 'Clear All Inputs'.")
|
| 356 |
+
multimask_mode = gr.Checkbox(value=True, label="Multi-mask mode", visible=False)
|
| 357 |
+
|
| 358 |
+
with gr.Row():
|
| 359 |
+
display_height = 500
|
| 360 |
+
|
| 361 |
+
with gr.Column(scale=1):
|
| 362 |
+
with gr.Tab("Scribbles"):
|
| 363 |
+
scribble_img = gr.Image(
|
| 364 |
+
label="Input",
|
| 365 |
+
brush_radius=3,
|
| 366 |
+
interactive=True,
|
| 367 |
+
brush_color="#00FF00",
|
| 368 |
+
tool="sketch",
|
| 369 |
+
height=display_height,
|
| 370 |
+
type='numpy',
|
| 371 |
+
value=default_example,
|
| 372 |
+
)
|
| 373 |
+
clear_scribble_button = gr.ClearButton([scribble_img], value="Clear Scribbles", variant="stop")
|
| 374 |
+
|
| 375 |
+
with gr.Tab("Clicks/Boxes") as click_tab:
|
| 376 |
+
click_img = gr.Image(
|
| 377 |
+
label="Input",
|
| 378 |
+
type='numpy',
|
| 379 |
+
value=default_example,
|
| 380 |
+
height=display_height
|
| 381 |
+
)
|
| 382 |
+
with gr.Row():
|
| 383 |
+
undo_click_button = gr.Button("Undo Last Click")
|
| 384 |
+
clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")
|
| 385 |
+
|
| 386 |
+
with gr.Tab("Input Image"):
|
| 387 |
+
input_img = gr.Image(
|
| 388 |
+
label="Input",
|
| 389 |
+
image_mode="L",
|
| 390 |
+
visible=True,
|
| 391 |
+
value=default_example,
|
| 392 |
+
height=display_height
|
| 393 |
+
)
|
| 394 |
+
gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")
|
| 395 |
+
|
| 396 |
+
with gr.Column(scale=1):
|
| 397 |
+
with gr.Tab("Output"):
|
| 398 |
+
output_img = gr.Gallery(
|
| 399 |
+
label='Outputs',
|
| 400 |
+
columns=1,
|
| 401 |
+
elem_id="gallery",
|
| 402 |
+
preview=True,
|
| 403 |
+
object_fit="scale-down",
|
| 404 |
+
height=display_height+50
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
submit_button = gr.Button("Refresh Prediction", variant='primary')
|
| 408 |
+
clear_all_button = gr.ClearButton([scribble_img], value="Clear All Inputs", variant="stop")
|
| 409 |
+
clear_mask_button = gr.Button("Clear Input Mask")
|
| 410 |
+
|
| 411 |
+
# ----------------------------------------------
|
| 412 |
+
# Loading Models
|
| 413 |
+
# ----------------------------------------------
|
| 414 |
+
|
| 415 |
+
model_dropdown.change(fn=load_model,
|
| 416 |
+
inputs=[model_dropdown],
|
| 417 |
+
outputs=[predictor, img_features]
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# ----------------------------------------------
|
| 421 |
+
# Loading Examples
|
| 422 |
+
# ----------------------------------------------
|
| 423 |
+
|
| 424 |
+
gr.Examples(examples=test_examples,
|
| 425 |
+
inputs=[input_img],
|
| 426 |
+
examples_per_page=10,
|
| 427 |
+
label='Unseen Examples from Test Datasets'
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
gr.Examples(examples=val_examples,
|
| 431 |
+
inputs=[input_img],
|
| 432 |
+
examples_per_page=10,
|
| 433 |
+
label='Unseen Examples from Validation Datasets'
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
# When clear scribble button is clicked
|
| 437 |
+
def clear_scribble_history(input_img):
    """
    Reset the scribble state when the 'Clear Scribbles' button is clicked.

    Returns the input image twice (for the click and scribble canvases), a
    zeroed (2, H, W) scribble stack, a zeroed (H, W) last-scribble mask, and
    None for both the best mask and the low-res mask. H and W come from the
    input image, or fall back to RES x RES when no image is loaded.
    """
    height_width = (RES, RES) if input_img is None else input_img.shape[:2]
    empty_scribbles = np.zeros((2,) + height_width, dtype=np.float32)
    empty_last_scribble = np.zeros(height_width, dtype=np.float32)
    return input_img, input_img, empty_scribbles, empty_last_scribble, None, None
|
| 443 |
+
|
| 444 |
+
clear_scribble_button.click(clear_scribble_history,
|
| 445 |
+
inputs=[input_img],
|
| 446 |
+
outputs=[click_img, scribble_img, seperate_scribble_masks, last_scribble_mask, best_mask, low_res_mask]
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
# When clear clicks button is clicked
|
| 450 |
+
def clear_click_history(input_img):
    """
    Reset click/box state when the 'Clear Clicks/Boxes' button is clicked.

    Returns the input image for both canvases, three fresh empty lists (click
    coordinates, click labels, box coordinates) and None for the best mask and
    the low-res mask.
    """
    click_canvas = scribble_canvas = input_img
    return click_canvas, scribble_canvas, [], [], [], None, None
|
| 452 |
+
|
| 453 |
+
clear_click_button.click(clear_click_history,
|
| 454 |
+
inputs=[input_img],
|
| 455 |
+
outputs=[click_img, scribble_img, click_coords, click_labels, bbox_coords, best_mask, low_res_mask])
|
| 456 |
+
|
| 457 |
+
# When clear all button is clicked
|
| 458 |
+
def clear_all_history(input_img):
    """
    Reset every piece of interaction state (used on image change and 'Clear All Inputs').

    Returns, in the order of the wired Gradio outputs: the input image for the
    click and scribble canvases, an empty gallery list, empty click
    coordinates / click labels / box coordinates, a zeroed (2, H, W) scribble
    stack, a zeroed (H, W) last-scribble mask, and None for the best mask, the
    low-res mask and the image features. H and W fall back to RES x RES when
    no image is loaded.
    """
    height_width = (RES, RES) if input_img is None else input_img.shape[:2]
    empty_scribbles = np.zeros((2,) + height_width, dtype=np.float32)
    empty_last_scribble = np.zeros(height_width, dtype=np.float32)
    return (
        input_img, input_img,
        [], [], [], [],
        empty_scribbles, empty_last_scribble,
        None, None, None,
    )
|
| 464 |
+
|
| 465 |
+
input_img.change(clear_all_history,
|
| 466 |
+
inputs=[input_img],
|
| 467 |
+
outputs=[click_img, scribble_img,
|
| 468 |
+
output_img, click_coords, click_labels, bbox_coords,
|
| 469 |
+
seperate_scribble_masks, last_scribble_mask,
|
| 470 |
+
best_mask, low_res_mask, img_features
|
| 471 |
+
])
|
| 472 |
+
|
| 473 |
+
clear_all_button.click(clear_all_history,
|
| 474 |
+
inputs=[input_img],
|
| 475 |
+
outputs=[click_img, scribble_img,
|
| 476 |
+
output_img, click_coords, click_labels, bbox_coords,
|
| 477 |
+
seperate_scribble_masks, last_scribble_mask,
|
| 478 |
+
best_mask, low_res_mask, img_features
|
| 479 |
+
])
|
| 480 |
+
|
| 481 |
+
# clear previous prediction mask
|
| 482 |
+
def clear_best_mask(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks):
    """
    Clear the previous prediction mask and redraw both input canvases without it.

    Returns None for the best mask and the low-res mask, followed by the
    re-rendered click canvas (drawn with the scribble masks) and scribble
    canvas (drawn without them).
    """
    prompt_args = (click_coords, click_labels, bbox_coords)

    # Passing None as the mask renders the prompts without any prediction
    click_input_viz = viz_pred_mask(input_img, None, *prompt_args, seperate_scribble_masks)
    scribble_input_viz = viz_pred_mask(input_img, None, *prompt_args, None)

    return None, None, click_input_viz, scribble_input_viz
|
| 492 |
+
|
| 493 |
+
clear_mask_button.click(
|
| 494 |
+
clear_best_mask,
|
| 495 |
+
inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks],
|
| 496 |
+
outputs=[best_mask, low_res_mask, click_img, scribble_img],
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
# ----------------------------------------------
|
| 500 |
+
# Clicks
|
| 501 |
+
# ----------------------------------------------
|
| 502 |
+
|
| 503 |
+
click_img.select(get_select_coords,
|
| 504 |
+
inputs=[
|
| 505 |
+
predictor,
|
| 506 |
+
input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
|
| 507 |
+
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
| 508 |
+
output_img, binary_checkbox, multimask_mode, autopredict_checkbox
|
| 509 |
+
],
|
| 510 |
+
outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
|
| 511 |
+
click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
|
| 512 |
+
api_name = "get_select_coords"
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
submit_button.click(fn=refresh_predictions,
|
| 516 |
+
inputs=[
|
| 517 |
+
predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
|
| 518 |
+
scribble_img, seperate_scribble_masks, last_scribble_mask,
|
| 519 |
+
best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
|
| 520 |
+
],
|
| 521 |
+
outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
|
| 522 |
+
seperate_scribble_masks, last_scribble_mask],
|
| 523 |
+
api_name="refresh_predictions"
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
undo_click_button.click(fn=undo_click,
|
| 527 |
+
inputs=[
|
| 528 |
+
predictor,
|
| 529 |
+
input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
|
| 530 |
+
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
| 531 |
+
output_img, binary_checkbox, multimask_mode, autopredict_checkbox
|
| 532 |
+
],
|
| 533 |
+
outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
|
| 534 |
+
click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
|
| 535 |
+
api_name="undo_click"
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
def update_click_img(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox,
                     last_scribble_mask, scribble_img, brush_label, best_mask):
    """
    Draw scribbles in the click canvas (runs when the 'Clicks/Boxes' tab is selected).

    Records the latest scribbles via get_scribbles, then re-renders the click
    canvas with the current prompts and the previous prediction.

    Returns:
        (click canvas image, seperate_scribble_masks, last_scribble_mask)
    """
    # 0 for the positive brush, 1 for the negative brush.
    # NOTE(review): the meaning of `label` is defined by get_scribbles (not shown here) - confirm
    scribble_label = 0 if brush_label == "Positive (green)" else 1
    seperate_scribble_masks, last_scribble_mask = get_scribbles(
        seperate_scribble_masks, last_scribble_mask, scribble_img,
        label=scribble_label
    )

    click_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
    )
    return click_input_viz, seperate_scribble_masks, last_scribble_mask
|
| 551 |
+
|
| 552 |
+
click_tab.select(fn=update_click_img,
|
| 553 |
+
inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
|
| 554 |
+
binary_checkbox, last_scribble_mask, scribble_img, brush_label, best_mask],
|
| 555 |
+
outputs=[click_img, seperate_scribble_masks, last_scribble_mask],
|
| 556 |
+
api_name="update_click_img"
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
# ----------------------------------------------
|
| 560 |
+
# Scribbles
|
| 561 |
+
# ----------------------------------------------
|
| 562 |
+
|
| 563 |
+
def change_brush_color(seperate_scribble_masks, last_scribble_mask, scribble_img, label):
    """
    Record new scribbles when changing brush color.

    `label` is the newly selected brush label. The sketch canvas is switched
    to the matching color and the scribbles drawn so far are recorded via
    get_scribbles.

    Returns:
        (seperate_scribble_masks, last_scribble_mask, canvas update)

    Raises:
        TypeError: if `label` is not one of the two known brush labels.
    """
    brush_hex_by_label = {
        "Negative (red)": "#FF0000",    # red
        "Positive (green)": "#00FF00",  # green
    }
    if label not in brush_hex_by_label:
        raise TypeError("Invalid brush color")
    brush_update = gr.Image.update(brush_color=brush_hex_by_label[label])

    # Record latest scribbles. The index is 1 when the NEW label is positive,
    # i.e. it refers to the brush that was active before this change.
    previous_brush = 1 if label == "Positive (green)" else 0
    seperate_scribble_masks, last_scribble_mask = get_scribbles(
        seperate_scribble_masks, last_scribble_mask, scribble_img,
        label=previous_brush
    )

    return seperate_scribble_masks, last_scribble_mask, brush_update
|
| 581 |
+
|
| 582 |
+
brush_label.change(fn=change_brush_color,
|
| 583 |
+
inputs=[seperate_scribble_masks, last_scribble_mask, scribble_img, brush_label],
|
| 584 |
+
outputs=[seperate_scribble_masks, last_scribble_mask, scribble_img],
|
| 585 |
+
api_name="change_brush_color"
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
if __name__ == "__main__":
|
| 590 |
+
|
| 591 |
+
demo.queue(api_open=False).launch(show_api=False)
|
checkpoints/ScribblePrompt_unet_v1_nf192_res128.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:43f57ee8fa8ec529c31be281e06749f9e629b30157bbbcc9baf200cddec1acbe
|
| 3 |
+
size 15977486
|
network.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Dict, Any, List
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
# -----------------------------------------------------------------------------
|
| 6 |
+
# Blocks
|
| 7 |
+
# -----------------------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
class Conv2d(nn.Module):
    """ 2D convolution, optionally followed by a PReLU activation

    inputs are [b, c, h, w] where
        b is the batch size
        c is the number of channels
        h is the height
        w is the width
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int,
                 do_activation: bool = True,
                 ):
        super().__init__()

        # Convolution first, then (optionally) the activation
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)]
        if do_activation:
            layers += [nn.PReLU()]

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # x is [B, C, H, W]
        out = self.conv(x)
        return out
|
| 38 |
+
|
| 39 |
+
# -----------------------------------------------------------------------------
|
| 40 |
+
# Network
|
| 41 |
+
# -----------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
class _UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 features: Optional[List[int]] = None,
                 conv_kernel_size: int = 3,
                 conv: Optional[nn.Module] = None,
                 conv_kwargs: Optional[Dict[str, Any]] = None,
                 ):
        """
        UNet (but can switch out the Conv)

        Args:
            in_channels: number of channels of the input tensor
            out_channels: number of channels of the output tensor
            features: output channels of each encoder level
                (defaults to five levels of 64 channels)
            conv_kernel_size: kernel size of every block except the final 1x1 one
            conv: block factory called as
                conv(in_ch, out_ch, kernel_size=..., padding=..., **conv_kwargs);
                the final block is additionally given do_activation=False
            conv_kwargs: extra keyword arguments forwarded to every conv block

        Raises:
            TypeError: if no conv block factory is given
        """
        super(_UNet, self).__init__()

        if conv is None:
            raise TypeError("conv must be a callable that builds a convolution block, got None")
        # Avoid mutable default arguments: build fresh containers per instance
        features = [64, 64, 64, 64, 64] if features is None else list(features)
        conv_kwargs = {} if conv_kwargs is None else dict(conv_kwargs)

        self.in_channels = in_channels

        # "same" padding for odd kernel sizes
        padding = (conv_kernel_size - 1) // 2

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of U-Net
        for feat in features:
            self.downs.append(
                conv(
                    in_channels, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
                )
            )
            in_channels = feat

        # Up part of U-Net
        # `incoming` is the channel count arriving from the level below; it is
        # concatenated with the skip connection (feat channels). For uniform
        # feature lists this equals the original feat * 2.
        incoming = features[-1]
        for feat in reversed(features):
            self.ups.append(nn.UpsamplingBilinear2d(scale_factor=2))
            self.ups.append(
                conv(
                    feat + incoming, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
                )
            )
            incoming = feat

        self.bottleneck = conv(
            features[-1], features[-1], kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
        )
        self.final_conv = conv(
            features[0], out_channels, kernel_size=1, padding=0, do_activation=False, **conv_kwargs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the encoder, bottleneck and decoder.

        Every encoder level is followed by a 2x max-pool and every decoder
        level starts with a 2x upsampling, so the skip connections only line
        up when H and W survive len(features) halvings without rounding.
        """
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Deepest skip connection is consumed first
        skip_connections = skip_connections[::-1]

        # self.ups alternates (upsample, conv) pairs
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class UNet(_UNet):
    """
    _UNet built from the plain Conv2d blocks defined in this module

    input shape: B x C x H x W
    output shape: B x C x H x W
    """
    def __init__(self, **kwargs) -> None:
        # All keyword arguments are forwarded; only the conv block type is fixed here
        _UNet.__init__(self, conv=Conv2d, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = _UNet.forward(self, x)
        return out
|
| 123 |
+
|
predictor.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from typing import Dict, Tuple, Optional
|
| 4 |
+
import network
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Predictor:
    """
    Wrapper for ScribblePrompt Unet model

    Builds the network, loads the checkpoint, switches to eval mode and moves
    the model to the GPU when one is available (CPU otherwise).
    """
    def __init__(self, path: str, verbose: bool = False):
        """
        Args:
            path: checkpoint location (a string or a pathlib.Path)
            verbose: print the prompts and loading information

        Raises:
            FileNotFoundError: if the checkpoint file does not exist
        """
        # Local import so that plain string paths can be accepted
        from pathlib import Path

        self.verbose = verbose

        path = Path(path)
        if not path.exists():
            # Explicit exception: an assert would be stripped under `python -O`
            raise FileNotFoundError(f"Checkpoint {path} does not exist")
        self.path = path

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_model()
        self.load()
        self.model.eval()
        self.to_device()

    def build_model(self):
        """
        Build the model
        """
        self.model = network.UNet(
            in_channels = 5,
            out_channels = 1,
            features = [192, 192, 192, 192],
        )

    def load(self):
        """
        Load the state of the model from a checkpoint file.
        """
        with (self.path).open("rb") as f:
            state = torch.load(f, map_location=self.device)
            # strict=True: every key of the checkpoint must match the network
            self.model.load_state_dict(state, strict=True)
            if self.verbose:
                print(
                    f"Loaded checkpoint from {self.path} to {self.device}"
                )

    def to_device(self):
        """
        Move the model to cpu or gpu
        """
        # Keep self.device a torch.device, as set in __init__
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def predict(self, prompts: Dict[str,any], img_features: Optional[torch.Tensor] = None, multimask_mode: bool = False):
        """
        Make predictions!

        Args:
            prompts: dict with an 'img' tensor and optional 'point_coords',
                'point_labels', 'box', 'scribble' and 'mask_input' entries
            img_features: unused by this model
            multimask_mode: unused by this model

        Returns:
            mask (torch.Tensor): sigmoid of the logits, resized to the input
                image size and squeezed (H x W for a single image)
            img_features: always None for this model
            low_res_mask (torch.Tensor): B x 1 x H x W logits at the model's
                working resolution
        """
        if self.verbose:
            print("point_coords", prompts.get("point_coords", None))
            print("point_labels", prompts.get("point_labels", None))
            print("box", prompts.get("box", None))
            print("img", prompts.get("img").shape, prompts.get("img").min(), prompts.get("img").max())
            if prompts.get("scribble") is not None:
                print("scribble", prompts.get("scribble", None).shape, prompts.get("scribble").min(), prompts.get("scribble").max())

        # Remember the size of the incoming image so the mask can be resized back
        original_shape = prompts.get('img').shape[-2:]

        # Rescale to 128 x 128
        prompts = rescale_inputs(prompts)

        # Prepare inputs for ScribblePrompt unet (1 x 5 x 128 x 128)
        x = prepare_inputs(prompts).float()

        with torch.no_grad():
            yhat = self.model(x.to(self.device)).cpu()

        mask = torch.sigmoid(yhat)

        # Resize for app resolution
        mask = F.interpolate(mask, size=original_shape, mode='bilinear').squeeze()

        # mask: H x W, yhat: 1 x 1 x H x W
        return mask, None, yhat
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# -----------------------------------------------------------------------------
|
| 91 |
+
# Prepare inputs
|
| 92 |
+
# -----------------------------------------------------------------------------
|
| 93 |
+
|
| 94 |
+
def rescale_inputs(inputs: Dict[str,any], res=128):
    """
    Rescale the inputs to res x res

    The image and scribbles are resized with bilinear interpolation (only when
    the image is not already res x res); box and point coordinates are scaled
    by (res / w, res / h) and converted to int. Other entries are passed
    through unchanged.

    Args:
        inputs: dict with an 'img' tensor (... x H x W) and optional
            'scribble', 'box' and 'point_coords' tensors
        res: target height and width

    Returns:
        A new dict with the rescaled entries; the caller's dict is not modified.
    """
    # Shallow copy so the caller's prompt dict keeps its original tensors
    inputs = dict(inputs)

    h,w = inputs['img'].shape[-2:]
    if h != res or w != res:

        inputs['img'] = F.interpolate(inputs['img'], size=(res,res), mode='bilinear')

        if inputs.get('scribble') is not None:
            inputs['scribble'] = F.interpolate(inputs['scribble'], size=(res,res), mode='bilinear')

    boxes = inputs.get("box")
    if boxes is not None:
        # View each box as two (x, y) corners and scale them in place on a copy
        corners = boxes.clone().reshape(-1, 2, 2)
        corners[..., 0] = corners[..., 0] * (res / w)
        corners[..., 1] = corners[..., 1] * (res / h)
        # Keep the batch layout of B x N x 4 inputs; anything else is treated as one batch
        box_shape = tuple(boxes.shape) if boxes.dim() == 3 else (1, -1, 4)
        inputs['box'] = corners.reshape(box_shape).int()

    points = inputs.get("point_coords")
    if points is not None:
        scaled_points = points.clone()
        scaled_points[..., 0] = scaled_points[..., 0] * (res / w)
        scaled_points[..., 1] = scaled_points[..., 1] * (res / h)
        inputs['point_coords'] = scaled_points.int()

    return inputs
|
| 124 |
+
|
| 125 |
+
def prepare_inputs(inputs: Dict[str,torch.Tensor], device = None) -> torch.Tensor:
    """
    Prepare inputs for ScribblePrompt Unet

    Channels are concatenated in this order: image, bounding box mask,
    positive/negative click+scribble masks (2 channels), previous mask.
    Missing prompts are replaced by all-zero channels.

    Args:
        inputs: dict with an 'img' tensor and optional 'box', 'point_coords'
            (with 'point_labels'), 'scribble' and 'mask_input' entries
        device: device on which the result is assembled
            (defaults to the image's device)

    Returns:
        x (torch.Tensor): B x 5 x H x W
    """
    img = inputs['img']
    if device is None:
        device = img.device

    img = img.to(device)
    shape = tuple(img.shape[-2:])

    box = inputs.get("box")
    if box is not None:
        # Embed bounding box
        # Input: B x 1 x 4
        # Output: B x 1 x H x W
        box_embed = bbox_shaded(box, shape=shape, device=device)
    else:
        box_embed = torch.zeros(img.shape, device=device)

    point_coords = inputs.get("point_coords")
    if point_coords is not None:
        # Encode points
        # B x 2 x H x W
        # (click_onehot builds on the coordinates' device, so move the result)
        scribble_click_embed = click_onehot(point_coords, inputs['point_labels'], shape=shape).to(device)
    else:
        scribble_click_embed = torch.zeros((img.shape[0], 2) + shape, device=device)

    scribble = inputs.get("scribble")
    if scribble is not None:
        # Combine scribbles with click encoding
        # B x 2 x H x W
        scribble_click_embed = torch.clamp(scribble_click_embed + scribble.to(device), min=0.0, max=1.0)

    mask_input = inputs.get('mask_input')
    if mask_input is not None:
        # Previous prediction
        mask_input = mask_input.to(device)
    else:
        # Initialize empty channel for mask input
        mask_input = torch.zeros(img.shape, device=device)

    # Every part now lives on `device`, so the concatenation cannot mix devices
    x = torch.cat((img, box_embed, scribble_click_embed, mask_input), dim=-3)
    # B x 5 x H x W

    return x
|
| 170 |
+
|
| 171 |
+
# -----------------------------------------------------------------------------
|
| 172 |
+
# Encode clicks and bounding boxes
|
| 173 |
+
# -----------------------------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
def click_onehot(point_coords, point_labels, shape: Tuple[int,int] = (128,128), indexing='xy'):
    """
    Represent clicks as two HxW binary masks (one for positive clicks and one for negative)
    with 1 at the click locations and 0 otherwise

    Args:
        point_coords (torch.Tensor): BxNx2 tensor of coordinates; integer or
            floating point (floats are truncated to integer pixel indices)
        point_labels (torch.Tensor): BxN tensor of labels (0 or 1)
        shape (tuple): output shape (H, W)
        indexing (str): 'xy' reads each coordinate as (column, row),
            'uv' reads it as (row, column)
    Returns:
        embed (torch.Tensor): Bx2xHxW tensor; channel 0 holds the labels
            (positive clicks), channel 1 holds 1 - label (negative clicks)
    """
    assert indexing in ['xy','uv'], f"Invalid indexing: {indexing}"
    assert len(point_coords.shape) == 3, "point_coords must be BxNx2"
    assert point_coords.shape[-1] == 2, "point_coords must be BxNx2"
    assert point_labels.shape[-1] == point_coords.shape[1], "point_labels must be BxN"
    assert len(shape)==2, f"shape must be 2D: {shape}"

    device = point_coords.device
    batch_size, n_points = point_coords.shape[:2]

    embed = torch.zeros((batch_size,2)+tuple(shape), device=device)
    labels = point_labels.flatten().float().to(device)

    # Index tensors must be integer typed; casting lets float coordinates through
    coords = point_coords.long()
    first = coords[..., 0].reshape(-1)
    second = coords[..., 1].reshape(-1)
    # Batch index of every click: 0,0,...,1,1,...
    batch_idx = torch.arange(batch_size, device=device).repeat_interleave(n_points)

    if indexing=='xy':
        # (x, y) -> row = y, column = x
        rows, cols = second, first
    else:
        rows, cols = first, second

    embed[ batch_idx, 0, rows, cols ] = labels
    embed[ batch_idx, 1, rows, cols ] = 1.0-labels

    return embed
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def bbox_shaded(boxes, shape: Tuple[int,int] = (128,128), device='cpu'):
    """
    Represent bounding boxes as a binary mask with 1 inside boxes and 0 otherwise

    The corners may be given in any order. The filled region is the half-open
    range [min, max) along both axes, clipped to the mask bounds. All boxes of
    one batch element are drawn into the same single-channel mask.

    Args:
        boxes (torch.Tensor or array with .shape): BxNx4 [x1, y1, x2, y2]
        shape (tuple): output (H, W)
        device: device of the returned mask
    Returns:
        bbox_embed (torch.Tensor): Bx1xHxW float32 mask according to shape
    """
    assert len(shape)==2, "shape must be 2D"
    if isinstance(boxes, torch.Tensor):
        boxes = boxes.int().cpu().numpy()

    height, width = int(shape[0]), int(shape[1])
    batch_size = boxes.shape[0]
    n_boxes = boxes.shape[1]
    bbox_embed = torch.zeros((batch_size,1)+tuple(shape), device=device, dtype=torch.float32)

    for i in range(batch_size):
        for j in range(n_boxes):
            x1, y1, x2, y2 = (int(v) for v in boxes[i,j,:])
            # Clip to the mask: a negative bound would otherwise be read as a
            # "count from the end" slice index and shade the wrong region
            x_min = max(0, min(x1, x2))
            x_max = min(width, max(0, x1, x2))
            y_min = max(0, min(y1, y2))
            y_max = min(height, max(0, y1, y2))
            bbox_embed[ i, 0, y_min:y_max, x_min:x_max ] = 1.0

    return bbox_embed
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy
|
| 2 |
+
torch
|
| 3 |
+
opencv-python
|
| 4 |
+
pathlib
|
test_examples/COBRE.jpg
ADDED
|
test_examples/SCR.jpg
ADDED
|
test_examples/TotalSegmentator.jpg
ADDED
|
test_examples/TotalSegmentator_2.jpg
ADDED
|
val_od_examples/ACDC.jpg
ADDED
|
val_od_examples/BTCV.jpg
ADDED
|
val_od_examples/BUID.jpg
ADDED
|
val_od_examples/DRIVE.jpg
ADDED
|
val_od_examples/HipXRay.jpg
ADDED
|
val_od_examples/PanDental.jpg
ADDED
|
val_od_examples/SCD.jpg
ADDED
|
val_od_examples/SpineWeb.jpg
ADDED
|
val_od_examples/WBC.jpg
ADDED
|