fix-adapter-masks (#32)
fix: adapter masks (934939f54211c85cc0a5f9891937c4015377c102)
Co-authored-by: Jack Min Ong <[email protected]>
block.py CHANGED

@@ -233,7 +233,7 @@ class Block(nn.Module):
                     is_rms_norm=isinstance(self.norm1, RMSNorm),
                 )
             if not isinstance(self.mlp, nn.Identity):
-                mlp_out = self.mlp(hidden_states,
+                mlp_out = self.mlp(hidden_states, adapter_mask=mixer_kwargs.get('adapter_mask'))
                 if self.return_residual:  # mlp out is actually a pair here
                     mlp_out, hidden_states = mlp_out
                 if not self.fused_dropout_add_ln:
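In block.py the fix is just the plumbing: the MLP call now pulls the mask out of mixer_kwargs with .get('adapter_mask'), which returns None when the encoder never set the key, so callers without adapters keep the plain path. A minimal sketch of that fallback, using a hypothetical ToyMlp rather than the repo's Mlp class:

import torch
import torch.nn as nn

class ToyMlp(nn.Module):
    # Hypothetical stand-in for the repo's Mlp; only the optional argument matters here.
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x, adapter_mask=None):
        # adapter_mask is None whenever the caller left the key out of mixer_kwargs,
        # so the plain (non-adapter) path runs.
        return self.fc(x)

mlp = ToyMlp(8)
hidden_states = torch.randn(5, 8)
mixer_kwargs = {"cu_seqlens": None, "max_seqlen": None}      # no "adapter_mask" key
mlp_out = mlp(hidden_states, adapter_mask=mixer_kwargs.get("adapter_mask"))  # passes None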
mha.py CHANGED

@@ -590,7 +590,7 @@ class MHA(nn.Module):
         max_seqlen=None,
         mixer_subset=None,
         inference_params=None,
-
+        adapter_mask=None,
         **kwargs,
     ):
         """

@@ -647,13 +647,13 @@ class MHA(nn.Module):
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None

-            if
-                unique_tasks = torch.unique(
+            if adapter_mask is not None:
+                unique_tasks = torch.unique(adapter_mask)
                 qkv_dtype = next(self.Wqkv.parameters()).dtype
-                qkv = torch.empty(x.shape[
+                qkv = torch.empty(*x.shape[:-1], self.Wqkv.out_features,
                                   dtype=qkv_dtype, device=x.device)
                 for task_id in unique_tasks:
-                    task_indices = (
+                    task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                     task_tensor = x[task_indices]
                     if not self.return_residual:
                         task_qkv = self.Wqkv(task_tensor, task_id=task_id)

@@ -755,13 +755,13 @@ class MHA(nn.Module):
                 context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)

         inp = rearrange(context, "... h d -> ... (h d)")
-        if
-            unique_tasks = torch.unique(
+        if adapter_mask is not None:
+            unique_tasks = torch.unique(adapter_mask)
             out_dtype = next(self.out_proj.parameters()).dtype
-            out = torch.empty(inp.shape[
+            out = torch.empty(*inp.shape[:-1], self.out_proj.out_features,
                               dtype=out_dtype, device=inp.device)
             for task_id in unique_tasks:
-                task_indices = (
+                task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = inp[task_indices]
                 task_out = self.out_proj(task_tensor, task_id=task_id)
                 out[task_indices] = task_out
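Both mha.py hunks repeat one pattern: collect the unique task ids in adapter_mask, allocate a single output buffer whose last dimension is the projection's out_features, then route each task's token rows through its own projection and scatter the results back by index. A self-contained sketch of that dispatch follows; a list of ordinary per-task nn.Linear modules stands in for the repo's task-aware Wqkv/out_proj layers (an assumption for illustration only):

import torch
import torch.nn as nn

torch.manual_seed(0)
dim, n_tasks = 16, 3

# Ordinary per-task Linear layers stand in for the repo's task-aware
# Wqkv/out_proj (which take a task_id argument); illustration only.
proj_per_task = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_tasks)])

inp = torch.randn(10, dim)                        # unpadded tokens: (total_tokens, dim)
adapter_mask = torch.randint(0, n_tasks, (10,))   # one task id per token

# One output buffer; each task's rows are filled from its own projection.
out = torch.empty(*inp.shape[:-1], dim, dtype=inp.dtype, device=inp.device)
for task_id in torch.unique(adapter_mask):
    task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
    out[task_indices] = proj_per_task[int(task_id)](inp[task_indices])

Looping over torch.unique keeps the per-call work proportional to the tasks actually present in the batch rather than to the total number of registered adapters.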
mlp.py CHANGED

@@ -47,14 +47,14 @@ class Mlp(nn.Module):
         self.activation = activation
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

-    def forward(self, x,
-        if
-            unique_tasks = torch.unique(
+    def forward(self, x, adapter_mask=None):
+        if adapter_mask is not None:
+            unique_tasks = torch.unique(adapter_mask)
             fc1_dtype = next(self.fc1.parameters()).dtype
-            y = torch.empty(x.shape[
+            y = torch.empty(*x.shape[:-1], self.fc1.out_features,
                             dtype=fc1_dtype, device=x.device)
             for task_id in unique_tasks:
-                task_indices = (
+                task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = x[task_indices]
                 task_y = self.fc1(task_tensor, task_id=task_id)
                 y[task_indices] = task_y

@@ -63,13 +63,13 @@ class Mlp(nn.Module):

         y = self.activation(y)

-        if
-            unique_tasks = torch.unique(
+        if adapter_mask is not None:
+            unique_tasks = torch.unique(adapter_mask)
             fc2_dtype = next(self.fc2.parameters()).dtype
-            out = torch.empty(y.shape[
+            out = torch.empty(*y.shape[:-1], self.fc2.out_features,
                               dtype=fc2_dtype, device=y.device)
             for task_id in unique_tasks:
-                task_indices = (
+                task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = y[task_indices]
                 task_out = self.fc2(task_tensor, task_id=task_id)
                 out[task_indices] = task_out
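The corrected allocations in mlp.py (and mha.py) use torch.empty(*x.shape[:-1], out_features, ...): every leading dimension of the input is preserved and only the feature size changes, so the same code covers a padded (batch, seq, dim) tensor and an unpadded (total_tokens, dim) tensor alike. A quick check of that shape rule:

import torch

out_features = 32
for shape in [(4, 128, 16), (512, 16)]:   # padded (batch, seq, dim) vs. unpadded (tokens, dim)
    x = torch.randn(*shape)
    y = torch.empty(*x.shape[:-1], out_features, dtype=x.dtype, device=x.device)
    assert y.shape == (*shape[:-1], out_features)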
modeling_xlm_roberta.py CHANGED

@@ -230,7 +230,7 @@ class XLMRobertaEncoder(nn.Module):
             hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
                 hidden_states, key_padding_mask, adapter_mask
             )
-            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "
+            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "adapter_mask": cu_adapter_mask}

             if subset_mask is None:
                 for layer in self.layers:
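The encoder now forwards the unpadded mask under the "adapter_mask" key, matching the parameter name introduced in mha.py and mlp.py. The cu_adapter_mask returned by unpad_input is presumably the per-sequence adapter_mask expanded to one task id per kept (non-padding) token; a hedged sketch of that expansion with plain torch ops, not the repo's unpad_input:

import torch

key_padding_mask = torch.tensor([[1, 1, 1, 0],    # sequence 0 keeps 3 tokens
                                 [1, 1, 0, 0]])   # sequence 1 keeps 2 tokens
adapter_mask = torch.tensor([0, 2])               # one task id per sequence (assumed semantics)

seqlens = key_padding_mask.sum(dim=-1)            # real tokens per sequence
cu_adapter_mask = torch.repeat_interleave(adapter_mask, seqlens)
print(cu_adapter_mask)                            # tensor([0, 0, 0, 2, 2])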