Update engine.py
Browse files
engine.py
CHANGED
|
@@ -108,7 +108,9 @@ def predict_fn(data_loader, model, device, extract_features=False):
|
|
| 108 |
mask=mask,
|
| 109 |
token_type_ids=token_type_ids
|
| 110 |
).cpu().detach().numpy().tolist())
|
| 111 |
-
|
|
|
|
|
|
|
| 112 |
fin_outputs.extend(torch.argmax(
|
| 113 |
outputs, dim=1).cpu().detach().numpy().tolist())
|
| 114 |
|
|
|
|
| 108 |
mask=mask,
|
| 109 |
token_type_ids=token_type_ids
|
| 110 |
).cpu().detach().numpy().tolist())
|
| 111 |
+
print("1",torch.argmax(outputs, dim=1))
|
| 112 |
+
print("2",torch.argmax(outputs, dim=1).cpu())
|
| 113 |
+
print("3",torch.argmax(outputs, dim=1).cpu().numpy())
|
| 114 |
fin_outputs.extend(torch.argmax(
|
| 115 |
outputs, dim=1).cpu().detach().numpy().tolist())
|
| 116 |
|