Spaces:
Sleeping
Sleeping
TedYeh
commited on
Commit
·
bf809f1
1
Parent(s):
a38dfb6
update predictor
Browse files- predictor.py +2 -2
predictor.py
CHANGED
|
@@ -198,7 +198,7 @@ def evaluation(model, epoch, device, dataloaders):
|
|
| 198 |
print(preds)
|
| 199 |
|
| 200 |
def inference(inp_img, classes = ['big', 'small'], epoch = 6):
|
| 201 |
-
device = torch.device("cuda")
|
| 202 |
translator= Translator(to_lang="zh-TW")
|
| 203 |
|
| 204 |
model = CUPredictor()
|
|
@@ -218,7 +218,7 @@ def inference(inp_img, classes = ['big', 'small'], epoch = 6):
|
|
| 218 |
image_tensor = trans(inp_img)
|
| 219 |
image_tensor = image_tensor.unsqueeze(0)
|
| 220 |
with torch.no_grad():
|
| 221 |
-
inputs = image_tensor
|
| 222 |
outputs_c, outputs_h, outputs_b, outputs_w, outputs_hi = model(inputs)
|
| 223 |
_, preds = torch.max(outputs_c, 1)
|
| 224 |
idx = preds.numpy()[0]
|
|
|
|
| 198 |
print(preds)
|
| 199 |
|
| 200 |
def inference(inp_img, classes = ['big', 'small'], epoch = 6):
|
| 201 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 202 |
translator= Translator(to_lang="zh-TW")
|
| 203 |
|
| 204 |
model = CUPredictor()
|
|
|
|
| 218 |
image_tensor = trans(inp_img)
|
| 219 |
image_tensor = image_tensor.unsqueeze(0)
|
| 220 |
with torch.no_grad():
|
| 221 |
+
inputs = image_tensor.to(device)
|
| 222 |
outputs_c, outputs_h, outputs_b, outputs_w, outputs_hi = model(inputs)
|
| 223 |
_, preds = torch.max(outputs_c, 1)
|
| 224 |
idx = preds.numpy()[0]
|