fix-adapter-masks (#32)
Browse files- fix: adapter masks (934939f54211c85cc0a5f9891937c4015377c102)
Co-authored-by: Jack Min Ong <[email protected]>
block.py
CHANGED
@@ -233,7 +233,7 @@ class Block(nn.Module):
|
|
233 |
is_rms_norm=isinstance(self.norm1, RMSNorm),
|
234 |
)
|
235 |
if not isinstance(self.mlp, nn.Identity):
|
236 |
-
mlp_out = self.mlp(hidden_states,
|
237 |
if self.return_residual: # mlp out is actually a pair here
|
238 |
mlp_out, hidden_states = mlp_out
|
239 |
if not self.fused_dropout_add_ln:
|
|
|
233 |
is_rms_norm=isinstance(self.norm1, RMSNorm),
|
234 |
)
|
235 |
if not isinstance(self.mlp, nn.Identity):
|
236 |
+
mlp_out = self.mlp(hidden_states, adapter_mask=mixer_kwargs.get('adapter_mask'))
|
237 |
if self.return_residual: # mlp out is actually a pair here
|
238 |
mlp_out, hidden_states = mlp_out
|
239 |
if not self.fused_dropout_add_ln:
|
mha.py
CHANGED
@@ -590,7 +590,7 @@ class MHA(nn.Module):
|
|
590 |
max_seqlen=None,
|
591 |
mixer_subset=None,
|
592 |
inference_params=None,
|
593 |
-
|
594 |
**kwargs,
|
595 |
):
|
596 |
"""
|
@@ -647,13 +647,13 @@ class MHA(nn.Module):
|
|
647 |
if not self.cross_attn and self.num_heads_kv == self.num_heads:
|
648 |
assert x_kv is None and mixer_subset is None
|
649 |
|
650 |
-
if
|
651 |
-
unique_tasks = torch.unique(
|
652 |
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
653 |
-
qkv = torch.empty(x.shape[
|
654 |
dtype=qkv_dtype, device=x.device)
|
655 |
for task_id in unique_tasks:
|
656 |
-
task_indices = (
|
657 |
task_tensor = x[task_indices]
|
658 |
if not self.return_residual:
|
659 |
task_qkv = self.Wqkv(task_tensor, task_id=task_id)
|
@@ -755,13 +755,13 @@ class MHA(nn.Module):
|
|
755 |
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
756 |
|
757 |
inp = rearrange(context, "... h d -> ... (h d)")
|
758 |
-
if
|
759 |
-
unique_tasks = torch.unique(
|
760 |
out_dtype = next(self.out_proj.parameters()).dtype
|
761 |
-
out = torch.empty(inp.shape[
|
762 |
dtype=out_dtype, device=inp.device)
|
763 |
for task_id in unique_tasks:
|
764 |
-
task_indices = (
|
765 |
task_tensor = inp[task_indices]
|
766 |
task_out = self.out_proj(task_tensor, task_id=task_id)
|
767 |
out[task_indices] = task_out
|
|
|
590 |
max_seqlen=None,
|
591 |
mixer_subset=None,
|
592 |
inference_params=None,
|
593 |
+
adapter_mask=None,
|
594 |
**kwargs,
|
595 |
):
|
596 |
"""
|
|
|
647 |
if not self.cross_attn and self.num_heads_kv == self.num_heads:
|
648 |
assert x_kv is None and mixer_subset is None
|
649 |
|
650 |
+
if adapter_mask is not None:
|
651 |
+
unique_tasks = torch.unique(adapter_mask)
|
652 |
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
653 |
+
qkv = torch.empty(*x.shape[:-1], self.Wqkv.out_features,
|
654 |
dtype=qkv_dtype, device=x.device)
|
655 |
for task_id in unique_tasks:
|
656 |
+
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
657 |
task_tensor = x[task_indices]
|
658 |
if not self.return_residual:
|
659 |
task_qkv = self.Wqkv(task_tensor, task_id=task_id)
|
|
|
755 |
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
756 |
|
757 |
inp = rearrange(context, "... h d -> ... (h d)")
|
758 |
+
if adapter_mask is not None:
|
759 |
+
unique_tasks = torch.unique(adapter_mask)
|
760 |
out_dtype = next(self.out_proj.parameters()).dtype
|
761 |
+
out = torch.empty(*inp.shape[:-1], self.out_proj.out_features,
|
762 |
dtype=out_dtype, device=inp.device)
|
763 |
for task_id in unique_tasks:
|
764 |
+
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
765 |
task_tensor = inp[task_indices]
|
766 |
task_out = self.out_proj(task_tensor, task_id=task_id)
|
767 |
out[task_indices] = task_out
|
mlp.py
CHANGED
@@ -47,14 +47,14 @@ class Mlp(nn.Module):
|
|
47 |
self.activation = activation
|
48 |
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
49 |
|
50 |
-
def forward(self, x,
|
51 |
-
if
|
52 |
-
unique_tasks = torch.unique(
|
53 |
fc1_dtype = next(self.fc1.parameters()).dtype
|
54 |
-
y = torch.empty(x.shape[
|
55 |
dtype=fc1_dtype, device=x.device)
|
56 |
for task_id in unique_tasks:
|
57 |
-
task_indices = (
|
58 |
task_tensor = x[task_indices]
|
59 |
task_y = self.fc1(task_tensor, task_id=task_id)
|
60 |
y[task_indices] = task_y
|
@@ -63,13 +63,13 @@ class Mlp(nn.Module):
|
|
63 |
|
64 |
y = self.activation(y)
|
65 |
|
66 |
-
if
|
67 |
-
unique_tasks = torch.unique(
|
68 |
fc2_dtype = next(self.fc2.parameters()).dtype
|
69 |
-
out = torch.empty(y.shape[
|
70 |
dtype=fc2_dtype, device=y.device)
|
71 |
for task_id in unique_tasks:
|
72 |
-
task_indices = (
|
73 |
task_tensor = y[task_indices]
|
74 |
task_out = self.fc2(task_tensor, task_id=task_id)
|
75 |
out[task_indices] = task_out
|
|
|
47 |
self.activation = activation
|
48 |
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
49 |
|
50 |
+
def forward(self, x, adapter_mask=None):
|
51 |
+
if adapter_mask is not None:
|
52 |
+
unique_tasks = torch.unique(adapter_mask)
|
53 |
fc1_dtype = next(self.fc1.parameters()).dtype
|
54 |
+
y = torch.empty(*x.shape[:-1], self.fc1.out_features,
|
55 |
dtype=fc1_dtype, device=x.device)
|
56 |
for task_id in unique_tasks:
|
57 |
+
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
58 |
task_tensor = x[task_indices]
|
59 |
task_y = self.fc1(task_tensor, task_id=task_id)
|
60 |
y[task_indices] = task_y
|
|
|
63 |
|
64 |
y = self.activation(y)
|
65 |
|
66 |
+
if adapter_mask is not None:
|
67 |
+
unique_tasks = torch.unique(adapter_mask)
|
68 |
fc2_dtype = next(self.fc2.parameters()).dtype
|
69 |
+
out = torch.empty(*y.shape[:-1], self.fc2.out_features,
|
70 |
dtype=fc2_dtype, device=y.device)
|
71 |
for task_id in unique_tasks:
|
72 |
+
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
73 |
task_tensor = y[task_indices]
|
74 |
task_out = self.fc2(task_tensor, task_id=task_id)
|
75 |
out[task_indices] = task_out
|
modeling_xlm_roberta.py
CHANGED
@@ -230,7 +230,7 @@ class XLMRobertaEncoder(nn.Module):
|
|
230 |
hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
|
231 |
hidden_states, key_padding_mask, adapter_mask
|
232 |
)
|
233 |
-
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "
|
234 |
|
235 |
if subset_mask is None:
|
236 |
for layer in self.layers:
|
|
|
230 |
hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
|
231 |
hidden_states, key_padding_mask, adapter_mask
|
232 |
)
|
233 |
+
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "adapter_mask": cu_adapter_mask}
|
234 |
|
235 |
if subset_mask is None:
|
236 |
for layer in self.layers:
|