Update modeling.py
Browse files- modeling.py +8 -7
modeling.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from functools import partial
|
|
|
|
| 2 |
import logging
|
| 3 |
import re
|
| 4 |
from typing import Optional, Tuple, Union
|
|
@@ -229,6 +230,8 @@ class CustomQwen2VLVE(Qwen2VisionTransformerPretrainedModel):
|
|
| 229 |
|
| 230 |
hidden_states = self.patch_embed(pixel_values)
|
| 231 |
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
|
|
|
|
|
|
| 232 |
|
| 233 |
cu_seqlens = torch.repeat_interleave(
|
| 234 |
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
|
@@ -238,20 +241,18 @@ class CustomQwen2VLVE(Qwen2VisionTransformerPretrainedModel):
|
|
| 238 |
for blk in self.blocks:
|
| 239 |
if output_hidden_states:
|
| 240 |
encoder_states = encoder_states + (hidden_states,)
|
| 241 |
-
if
|
| 242 |
-
|
| 243 |
-
blk.__call__,
|
| 244 |
hidden_states,
|
| 245 |
-
cu_seqlens,
|
| 246 |
-
|
| 247 |
)
|
| 248 |
else:
|
| 249 |
-
|
| 250 |
hidden_states,
|
| 251 |
cu_seqlens=cu_seqlens,
|
| 252 |
rotary_pos_emb=rotary_pos_emb,
|
| 253 |
)
|
| 254 |
-
hidden_states = layer_outputs
|
| 255 |
if output_hidden_states:
|
| 256 |
encoder_states = encoder_states + (hidden_states,)
|
| 257 |
|
|
|
|
| 1 |
from functools import partial
|
| 2 |
+
import inspect
|
| 3 |
import logging
|
| 4 |
import re
|
| 5 |
from typing import Optional, Tuple, Union
|
|
|
|
| 230 |
|
| 231 |
hidden_states = self.patch_embed(pixel_values)
|
| 232 |
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
| 233 |
+
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 234 |
+
position_embeddings = (emb.cos(), emb.sin())
|
| 235 |
|
| 236 |
cu_seqlens = torch.repeat_interleave(
|
| 237 |
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
|
|
|
| 241 |
for blk in self.blocks:
|
| 242 |
if output_hidden_states:
|
| 243 |
encoder_states = encoder_states + (hidden_states,)
|
| 244 |
+
if "position_embeddings" in inspect.signature(blk.forward).parameters:
|
| 245 |
+
hidden_states = blk(
|
|
|
|
| 246 |
hidden_states,
|
| 247 |
+
cu_seqlens=cu_seqlens,
|
| 248 |
+
position_embeddings=position_embeddings,
|
| 249 |
)
|
| 250 |
else:
|
| 251 |
+
hidden_states = blk(
|
| 252 |
hidden_states,
|
| 253 |
cu_seqlens=cu_seqlens,
|
| 254 |
rotary_pos_emb=rotary_pos_emb,
|
| 255 |
)
|
|
|
|
| 256 |
if output_hidden_states:
|
| 257 |
encoder_states = encoder_states + (hidden_states,)
|
| 258 |
|