Update app_flash.py
app_flash.py (+5, -2)
@@ -23,14 +23,17 @@ print(f"🔧 Using device: {device} (CPU-only)")
 # 1️⃣ FlashPack model with better hidden layers
 # ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self
+    def __init__(self):
         super().__init__()
+        input_dim = 1536
+        hidden_dim = 1024
+        output_dim = 1536
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
         self.fc3 = nn.Linear(hidden_dim, output_dim)

-    def forward(self, x: torch.Tensor)
+    def forward(self, x: torch.Tensor)
+    def forward(self, x: torch.Tensor):
         x = self.fc1(x)
         x = self.relu(x)
         x = self.fc2(x)
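For reference, a minimal sketch of how the class behaves after this change and how it could be exercised. This is a sketch, not the full app_flash.py: the FlashPackMixin base and its import are omitted so the snippet runs with PyTorch alone, the tail of forward() (a second ReLU, fc3, and the return) is assumed because the hunk above is cut off, and the batch size in the dummy input is arbitrary.

import torch
import torch.nn as nn

class GemmaTrainerSketch(nn.Module):  # FlashPackMixin omitted in this sketch
    def __init__(self):
        super().__init__()
        # Hard-coded dimensions introduced by this commit
        input_dim = 1536
        hidden_dim = 1024
        output_dim = 1536
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        # Assumed continuation -- the hunk above ends at self.fc2(x)
        x = self.relu(x)
        return self.fc3(x)

# Dummy forward pass: a batch of 4 vectors of size 1536 (batch size is arbitrary)
model = GemmaTrainerSketch()
out = model(torch.randn(4, 1536))
print(out.shape)  # torch.Size([4, 1536])

With the dimensions now hard-coded in __init__, the constructor takes no arguments, so any caller that previously passed input_dim/hidden_dim/output_dim must switch to GemmaTrainer() with no parameters.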