Update modeling_intern_vit.py (#4)
Browse files- Update modeling_intern_vit.py (188f686e8863f48f045de9508e6035f37809e1fd)
Co-authored-by: Roy Hvaara <[email protected]>
- modeling_intern_vit.py +5 -4
modeling_intern_vit.py
CHANGED
|
@@ -15,17 +15,18 @@ from transformers.activations import ACT2FN
|
|
| 15 |
from transformers.modeling_outputs import (BaseModelOutput,
|
| 16 |
BaseModelOutputWithPooling)
|
| 17 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
|
| 18 |
from transformers.utils import logging
|
| 19 |
|
| 20 |
from .configuration_intern_vit import InternVisionConfig
|
| 21 |
|
| 22 |
try:
|
| 23 |
-
|
| 24 |
-
from flash_attn.flash_attn_interface import \
|
| 25 |
-
flash_attn_unpadded_qkvpacked_func
|
| 26 |
-
except: # v2
|
| 27 |
from flash_attn.flash_attn_interface import \
|
| 28 |
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
from flash_attn.bert_padding import pad_input, unpad_input
|
| 31 |
|
|
|
|
| 15 |
from transformers.modeling_outputs import (BaseModelOutput,
|
| 16 |
BaseModelOutputWithPooling)
|
| 17 |
from transformers.modeling_utils import PreTrainedModel
|
| 18 |
+
from transformers.utils.import_utils import is_flash_attn_greater_or_equal
|
| 19 |
from transformers.utils import logging
|
| 20 |
|
| 21 |
from .configuration_intern_vit import InternVisionConfig
|
| 22 |
|
| 23 |
try:
|
| 24 |
+
if is_flash_attn_greater_or_equal("2.0.0"):
|
|
|
|
|
|
|
|
|
|
| 25 |
from flash_attn.flash_attn_interface import \
|
| 26 |
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
| 27 |
+
else:
|
| 28 |
+
from flash_attn.flash_attn_interface import \
|
| 29 |
+
flash_attn_unpadded_qkvpacked_func
|
| 30 |
|
| 31 |
from flash_attn.bert_padding import pad_input, unpad_input
|
| 32 |
|