Spaces: Running on Zero
Update trellis/pipelines/trellis_image_to_3d.py
trellis/pipelines/trellis_image_to_3d.py
CHANGED
@@ -232,17 +232,14 @@ class TrellisImageTo3DPipeline(Pipeline):
             if scale < 1:
                 input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
 
-            #
-
 
-            #
-
-
-
-
-            if getattr(self, 'rembg_session', None) is None:
-                self.rembg_session = rembg.new_session('u2net')
-            output = rembg.remove(input, session=self.rembg_session)
 
         # Process the output image
         output_np = np.array(output)
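For reference, the removed lines follow rembg's session-caching pattern: a u2net session is created once on first use and reused, and rembg.remove returns an RGBA image whose alpha channel holds the predicted matte. A minimal standalone sketch of that pattern (the function name is illustrative):

import rembg
from PIL import Image

_rembg_session = None  # cache the u2net session across calls, as the removed code did

def remove_background_u2net(image: Image.Image) -> Image.Image:
    """Return an RGBA image whose alpha channel is the u2net matte."""
    global _rembg_session
    if _rembg_session is None:
        _rembg_session = rembg.new_session('u2net')
    return rembg.remove(image, session=_rembg_session)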
@@ -341,7 +338,7 @@ class TrellisImageTo3DPipeline(Pipeline):
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
         ])
 
-        input_images = transform_image(image).unsqueeze(0).
 
         with torch.no_grad():
             preds = self.birefnet_model(input_images)[-1].sigmoid().cpu()
@@ -793,11 +790,11 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline):
         del new_pipeline.VGGT_model.point_head
         new_pipeline.VGGT_model.eval()
 
-
-
-
-
-
 
         new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
         new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
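The sampler wiring at the end of this hunk resolves a class by name from the samplers module and instantiates it from a config dict. A minimal sketch of that pattern; the import path, sampler class name, and config values below are illustrative rather than the pipeline's real defaults:

from trellis.pipelines import samplers  # assumed import path for the samplers module

# Illustrative config in the same shape as the pipeline's args dict.
sampler_cfg = {
    'name': 'FlowEulerSampler',   # class name resolved on the samplers module
    'args': {'sigma_min': 1e-5},  # constructor kwargs
    'params': {'steps': 12},      # runtime parameters stored alongside the sampler
}

sampler_cls = getattr(samplers, sampler_cfg['name'])
sparse_structure_sampler = sampler_cls(**sampler_cfg['args'])
sparse_structure_sampler_params = sampler_cfg['params']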
@@ -232,17 +232,14 @@ class TrellisImageTo3DPipeline(Pipeline):
             if scale < 1:
                 input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
 
+            # Get mask using BiRefNet
+            mask = self._get_birefnet_mask(input)
 
+            # Convert input to RGBA and apply mask
+            input_rgba = input.convert('RGBA')
+            input_array = np.array(input_rgba)
+            input_array[:, :, 3] = mask * 255  # Apply mask to alpha channel
+            output = Image.fromarray(input_array)
 
         # Process the output image
         output_np = np.array(output)
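The added lines write the BiRefNet matte into the alpha channel of an RGBA copy of the input, so the downstream output_np = np.array(output) still sees a 4-channel array, just as it did with rembg's output. A standalone sketch of that step, assuming mask is a float array in [0, 1] with the same height and width as the image (the helper name is illustrative):

import numpy as np
from PIL import Image

def apply_alpha_mask(image: Image.Image, mask: np.ndarray) -> Image.Image:
    """Return an RGBA copy of image whose alpha channel encodes mask values in [0, 1]."""
    rgba = np.array(image.convert('RGBA'))         # H x W x 4, uint8
    rgba[:, :, 3] = (mask * 255).astype(np.uint8)  # scale the matte to 0-255 alpha
    return Image.fromarray(rgba)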
@@ -341,7 +338,7 @@ class TrellisImageTo3DPipeline(Pipeline):
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
         ])
 
+        input_images = transform_image(image).unsqueeze(0).to(self.device)
 
         with torch.no_grad():
             preds = self.birefnet_model(input_images)[-1].sigmoid().cpu()
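This hunk completes the previously truncated line so the batch is moved onto the pipeline's device before inference. Combined with the _get_birefnet_mask call added above, the helper plausibly looks like the sketch below; the 1024x1024 working resolution, the RGB conversion, and the resize back to the input size are assumptions not visible in the diff:

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

def get_birefnet_mask(model, image: Image.Image, device: str) -> np.ndarray:
    """Run BiRefNet and return a float mask in [0, 1] at the image's resolution."""
    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),  # assumed working resolution
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_images = transform_image(image.convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()  # last output, as in the diff
    mask = preds[0, 0].numpy()                           # 1024 x 1024 floats in [0, 1]
    mask_img = Image.fromarray((mask * 255).astype(np.uint8))
    mask_img = mask_img.resize(image.size, Image.Resampling.LANCZOS)
    return np.asarray(mask_img, dtype=np.float32) / 255.0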
@@ -793,11 +790,11 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline):
         del new_pipeline.VGGT_model.point_head
         new_pipeline.VGGT_model.eval()
 
+        new_pipeline.birefnet_model = AutoModelForImageSegmentation.from_pretrained(
+            'ZhengPeng7/BiRefNet',
+            trust_remote_code=True
+        ).to(new_pipeline.device)
+        new_pipeline.birefnet_model.eval()
 
         new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
         new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
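Taken together, the new path loads BiRefNet once when the pipeline is built and reuses it for every image, mirroring the removed rembg session caching so the weights are not reloaded per call. A hedged end-to-end sketch using the helper sketched above (device selection and file names are illustrative):

import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForImageSegmentation

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load BiRefNet once, as the diff does when constructing the pipeline.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    'ZhengPeng7/BiRefNet',
    trust_remote_code=True,
).to(device)
birefnet.eval()

image = Image.open('input.png').convert('RGB')
mask = get_birefnet_mask(birefnet, image, device)  # helper sketched above
rgba = np.array(image.convert('RGBA'))
rgba[:, :, 3] = (mask * 255).astype(np.uint8)
Image.fromarray(rgba).save('output_rgba.png')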