rahul7star committed
Commit 2a7aa21 · verified · Parent(s): 2d5cdb3

Update app_flash.py

Files changed (1):
  app_flash.py +5 -2
app_flash.py CHANGED
@@ -23,14 +23,17 @@ print(f"🔧 Using device: {device} (CPU-only)")
 # 1️⃣ FlashPack model with better hidden layers
 # ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
+    def __init__(self):
         super().__init__()
+        input_dim = 1536
+        hidden_dim = 1024
+        output_dim = 1536
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
         self.fc3 = nn.Linear(hidden_dim, output_dim)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor):
         x = self.fc1(x)
         x = self.relu(x)
         x = self.fc2(x)
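
For quick verification, a minimal standalone sketch of the post-commit class with a forward-pass shape check. The hunk ends at self.fc2(x), so the tail of forward() below (second ReLU, fc3, return) is an assumption inferred from the layer stack in __init__, and FlashPackMixin is omitted so the snippet runs with plain PyTorch.

# Minimal sanity check for the updated GemmaTrainer (not part of the commit).
# Assumptions: the forward() tail beyond the diff hunk is inferred from the
# layer stack; FlashPackMixin is dropped to keep this runnable standalone.
import torch
import torch.nn as nn

class GemmaTrainer(nn.Module):
    def __init__(self):
        super().__init__()
        input_dim = 1536     # hard-coded by this commit (formerly ctor args)
        hidden_dim = 1024
        output_dim = 1536
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)         # assumed continuation beyond the diff hunk
        return self.fc3(x)       # assumed continuation beyond the diff hunk

model = GemmaTrainer()
out = model(torch.randn(4, 1536))   # batch of 4 vectors of input_dim
print(out.shape)                    # torch.Size([4, 1536])

Since the commit replaces the constructor arguments with hard-coded dimensions, GemmaTrainer() is now constructed with no arguments, which simplifies reloading the class without having to carry its config alongside the weights.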