peterroh committed
Commit 38145a7 · verified · 1 Parent(s): 045081f

Update modeling.py

Files changed (1): modeling.py +8 -7
modeling.py CHANGED
@@ -1,4 +1,5 @@
 from functools import partial
+import inspect
 import logging
 import re
 from typing import Optional, Tuple, Union
@@ -229,6 +230,8 @@ class CustomQwen2VLVE(Qwen2VisionTransformerPretrainedModel):
 
         hidden_states = self.patch_embed(pixel_values)
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
 
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -238,20 +241,18 @@ class CustomQwen2VLVE(Qwen2VisionTransformerPretrainedModel):
         for blk in self.blocks:
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    blk.__call__,
+            if "position_embeddings" in inspect.signature(blk.forward).parameters:
+                hidden_states = blk(
                     hidden_states,
-                    cu_seqlens,
-                    rotary_pos_emb,
+                    cu_seqlens=cu_seqlens,
+                    position_embeddings=position_embeddings,
                 )
             else:
-                layer_outputs = blk(
+                hidden_states = blk(
                     hidden_states,
                     cu_seqlens=cu_seqlens,
                     rotary_pos_emb=rotary_pos_emb,
                 )
-            hidden_states = layer_outputs
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
 
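In short: the commit drops the gradient-checkpointing branch from the vision-block loop and replaces it with a runtime signature check, so the custom encoder keeps working across transformers releases. Newer versions of the Qwen2-VL vision block's forward accept a precomputed (cos, sin) tuple via position_embeddings, while older versions accept the raw rotary_pos_emb frequencies; the new hunk builds the tuple once before the loop and picks the matching call per block. A minimal sketch of the pattern, assuming a hypothetical call_vision_block helper that is not part of the file:

import inspect

import torch


def call_vision_block(blk, hidden_states, cu_seqlens, rotary_pos_emb):
    # Hypothetical helper illustrating the dispatch pattern used in the commit.
    if "position_embeddings" in inspect.signature(blk.forward).parameters:
        # Newer API: duplicate the rotary frequencies along the last dim and
        # hand the block the precomputed (cos, sin) pair.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())
        return blk(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
    # Older API: the block derives cos/sin internally from rotary_pos_emb.
    return blk(
        hidden_states,
        cu_seqlens=cu_seqlens,
        rotary_pos_emb=rotary_pos_emb,
    )

Note that the commit itself computes the (cos, sin) pair once before the block loop rather than per call, which avoids redoing the torch.cat and trig ops for every layer; the sketch folds it into the helper only to stay self-contained.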