Update engine.py
Browse files
engine.py
CHANGED
|
@@ -108,6 +108,7 @@ def predict_fn(data_loader, model, device, extract_features=False):
|
|
| 108 |
mask=mask,
|
| 109 |
token_type_ids=token_type_ids
|
| 110 |
).cpu().detach().numpy().tolist())
|
|
|
|
| 111 |
print("1",torch.argmax(outputs, dim=1))
|
| 112 |
print("2",torch.argmax(outputs, dim=1).cpu())
|
| 113 |
print("3",torch.argmax(outputs, dim=1).cpu().numpy())
|
|
|
|
| 108 |
mask=mask,
|
| 109 |
token_type_ids=token_type_ids
|
| 110 |
).cpu().detach().numpy().tolist())
|
| 111 |
+
print("0",outputs)
|
| 112 |
print("1",torch.argmax(outputs, dim=1))
|
| 113 |
print("2",torch.argmax(outputs, dim=1).cpu())
|
| 114 |
print("3",torch.argmax(outputs, dim=1).cpu().numpy())
|