Commit 1c61b96 · Parent: 95b4916

support activation checkpointing

Changed file: modeling_xlm_roberta.py (+46 -5)
--- a/modeling_xlm_roberta.py
+++ b/modeling_xlm_roberta.py
@@ -17,6 +17,7 @@ from functools import partial
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.checkpoint
 from einops import rearrange
 from transformers import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
@@ -42,7 +43,6 @@ from .embedding import XLMRobertaEmbeddings
 from .mha import MHA
 from .mlp import FusedMLP, Mlp
 
-# from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 try:
     from flash_attn.ops.fused_dense import FusedDense
@@ -166,6 +166,15 @@ class XLMRobertaEncoder(nn.Module):
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
+        self._grad_checkpointing = False
+
+    @property
+    def gradient_checkpointing(self):
+        return self._grad_checkpointing
+
+    @gradient_checkpointing.setter
+    def gradient_checkpointing(self, value):
+        self._grad_checkpointing = value
 
     def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
         """If subset_mask is not None, we only want output for the subset of the sequence.
@@ -177,7 +186,15 @@ class XLMRobertaEncoder(nn.Module):
                 {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
             )
             for layer in self.layers:
-                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                if self._grad_checkpointing:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        layer,
+                        hidden_states,
+                        use_reentrant=False,
+                        mixer_kwargs=mixer_kwargs
+                    )
+                else:
+                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
             if subset_mask is not None:
                 hidden_states = hidden_states[subset_mask]
         else:
@@ -188,11 +205,27 @@ class XLMRobertaEncoder(nn.Module):
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
                 for layer in self.layers:
-                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                    if self._grad_checkpointing:
+                        hidden_states = torch.utils.checkpoint.checkpoint(
+                            layer,
+                            hidden_states,
+                            use_reentrant=False,
+                            mixer_kwargs=mixer_kwargs
+                        )
+                    else:
+                        hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
                 for layer in self.layers[:-1]:
-                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                    if self._grad_checkpointing:
+                        hidden_states = torch.utils.checkpoint.checkpoint(
+                            layer,
+                            hidden_states,
+                            use_reentrant=False,
+                            mixer_kwargs=mixer_kwargs
+                        )
+                    else:
+                        hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 if key_padding_mask is not None:
                     subset_idx = torch.nonzero(
                         subset_mask[key_padding_mask], as_tuple=False
@@ -218,7 +251,15 @@ class XLMRobertaEncoder(nn.Module):
                     "cu_seqlens_k": cu_seqlens,
                     "max_seqlen_k": max_seqlen_in_batch,
                 }
-                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                if self._grad_checkpointing:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        self.layers[-1],
+                        hidden_states_subset,
+                        use_reentrant=False,
+                        mixer_kwargs=mixer_kwargs
+                    )
+                else:
+                    hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
         return hidden_states
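
The change wraps each transformer block in PyTorch activation (gradient) checkpointing: torch.utils.checkpoint.checkpoint runs the layer without storing its intermediate activations and recomputes them during the backward pass, trading extra compute for lower memory. Passing use_reentrant=False selects the non-reentrant implementation, which also allows keyword arguments such as mixer_kwargs to be forwarded to the wrapped layer. The snippet below is a minimal, self-contained sketch of the same pattern on a toy encoder; ToyBlock and ToyEncoder are invented stand-ins for illustration, not classes from this repository, and only the gradient_checkpointing property and the checkpoint call mirror the diff above.

import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):
    # stand-in for a transformer block; takes a keyword argument the way the
    # real layers take mixer_kwargs
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states, scale=1.0):
        return torch.relu(self.linear(hidden_states)) * scale


class ToyEncoder(nn.Module):
    def __init__(self, dim, depth):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(depth))
        self._grad_checkpointing = False

    # same property/setter pair the commit adds to XLMRobertaEncoder
    @property
    def gradient_checkpointing(self):
        return self._grad_checkpointing

    @gradient_checkpointing.setter
    def gradient_checkpointing(self, value):
        self._grad_checkpointing = value

    def forward(self, hidden_states):
        for layer in self.layers:
            if self._grad_checkpointing:
                # recompute this layer's activations in backward instead of
                # storing them; use_reentrant=False lets the keyword argument
                # pass through to the layer
                hidden_states = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, use_reentrant=False, scale=1.0
                )
            else:
                hidden_states = layer(hidden_states, scale=1.0)
        return hidden_states


encoder = ToyEncoder(dim=16, depth=4)
encoder.gradient_checkpointing = True  # the switch this commit exposes
x = torch.randn(2, 8, 16, requires_grad=True)
encoder(x).sum().backward()  # gradients flow through the recomputed layers

On the real XLMRobertaEncoder the switch works the same way (set encoder.gradient_checkpointing = True before training); how the surrounding model classes expose this flag is outside the scope of this commit.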