How can I finetune mnv5 with my own LLM?

#38 · opened by lsxx

I'm trying to fine-tune mnv5 together with my 3B Llama-style LLM. I extracted the mnv5 weights from the original gemma3n checkpoint and copied the vision config from it. When training, a gradient NaN issue appears at the very first step; after printing the gradients, I found that a gradient explosion occurs inside the mnv5 vision encoder. Could you help me find where the problem is?
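For context, the log below comes from per-module backward hooks, roughly like the minimal sketch here (the helper name and exact print format are simplified; `model` stands for my combined LLM + vision model):

```python
import torch

def make_range_hook(name):
    """Print the value range of the first grad_input/grad_output tensor of a module."""
    def hook(module, grad_input, grad_output):
        for label, grads in (("grad_input", grad_input), ("grad_output", grad_output)):
            g = grads[0] if grads else None
            if g is None:
                continue
            if torch.isnan(g).any():
                print(f"NaN detected in {label}[0] of {name}")
            print(f"{name} - {g.shape} {label}[0] range: "
                  f"[{g.min().item():f}, {g.max().item():f}] dtype: {g.dtype}")
    return hook

# Register a full backward hook on every submodule of the vision encoder.
for name, module in model.vision_model.named_modules():
    module.register_full_backward_hook(make_range_hook(f"vision_model.{name}"))
```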

This is my gradient log:
vision_model.timm_model.blocks.1.2.dw_start.conv - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000353, 0.000106] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.2.dw_start - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-0.000763, 0.000732] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.2.dw_start - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000851, 0.000771] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.2 - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-0.000751, 0.000732] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.2 - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000763, 0.000504] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.drop_path - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-0.000751, 0.000732] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.drop_path - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000751, 0.000732] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.layer_scale - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-0.000759, 0.000881] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.layer_scale - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000751, 0.000732] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_end - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-0.000759, 0.000881] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_end - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000759, 0.000881] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_proj.bn.act - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-0.000759, 0.000881] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_proj.bn.act - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000759, 0.000881] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_proj.bn.drop - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-0.000759, 0.000881] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_proj.bn.drop - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000759, 0.000881] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_proj.bn - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-0.925781, 1.203125] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_proj.bn - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000759, 0.000881] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_proj.conv - torch.Size([1, 1024, 96, 96]) grad_input[0] range: [-17.875000, 20.875000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_proj.conv - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.925781, 1.203125] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_proj - torch.Size([1, 1024, 96, 96]) grad_input[0] range: [-17.875000, 20.875000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_proj - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000759, 0.000881] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.se - torch.Size([1, 1024, 96, 96]) grad_input[0] range: [-17.875000, 20.875000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.se - torch.Size([1, 1024, 96, 96]) grad_output[0] range: [-17.875000, 20.875000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_mid - torch.Size([1, 1024, 96, 96]) grad_input[0] range: [-17.875000, 20.875000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_mid - torch.Size([1, 1024, 96, 96]) grad_output[0] range: [-17.875000, 20.875000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_exp.bn.act - torch.Size([1, 1024, 96, 96]) grad_input[0] range: [-8.937500, 10.437500] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_exp.bn.act - torch.Size([1, 1024, 96, 96]) grad_output[0] range: [-17.875000, 20.875000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_exp.bn.drop - torch.Size([1, 1024, 96, 96]) grad_input[0] range: [-8.937500, 10.437500] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_exp.bn.drop - torch.Size([1, 1024, 96, 96]) grad_output[0] range: [-8.937500, 10.437500] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_exp.bn - torch.Size([1, 1024, 96, 96]) grad_input[0] range: [-69632.000000, 33792.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_exp.bn - torch.Size([1, 1024, 96, 96]) grad_output[0] range: [-17.875000, 20.875000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_exp.conv - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-1130496.000000, 1630208.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_exp.conv - torch.Size([1, 1024, 96, 96]) grad_output[0] range: [-69632.000000, 33792.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_exp - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-1130496.000000, 1630208.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.pw_exp - torch.Size([1, 1024, 96, 96]) grad_output[0] range: [-17.875000, 20.875000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_start.bn.act - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-1130496.000000, 1630208.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_start.bn.act - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-1130496.000000, 1630208.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_start.bn.drop - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-1130496.000000, 1630208.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_start.bn.drop - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-1130496.000000, 1630208.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_start.bn - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-3053453312.000000, 4898947072.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_start.bn - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-1130496.000000, 1630208.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_start.conv - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-28454158336.000000, 38923141120.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_start.conv - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-3053453312.000000, 4898947072.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_start - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-28454158336.000000, 38923141120.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1.dw_start - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-1130496.000000, 1630208.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1 - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-28454158336.000000, 38923141120.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.1 - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000751, 0.000732] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.layer_scale - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-51271172096.000000, 43754979328.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.layer_scale - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-28454158336.000000, 38923141120.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_end - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-51271172096.000000, 43754979328.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_end - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-51271172096.000000, 43754979328.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_proj.bn.act - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-51271172096.000000, 43754979328.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_proj.bn.act - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-51271172096.000000, 43754979328.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_proj.bn.drop - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-51271172096.000000, 43754979328.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_proj.bn.drop - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-51271172096.000000, 43754979328.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_proj.bn - torch.Size([1, 256, 96, 96]) grad_input[0] range: [-103903848824832.000000, 120396523241472.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_proj.bn - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-51271172096.000000, 43754979328.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_proj.conv - torch.Size([1, 768, 96, 96]) grad_input[0] range: [-2621235720617984.000000, 2797157581062144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_proj.conv - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-103903848824832.000000, 120396523241472.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_proj - torch.Size([1, 768, 96, 96]) grad_input[0] range: [-2621235720617984.000000, 2797157581062144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_proj - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-51271172096.000000, 43754979328.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.se - torch.Size([1, 768, 96, 96]) grad_input[0] range: [-2621235720617984.000000, 2797157581062144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.se - torch.Size([1, 768, 96, 96]) grad_output[0] range: [-2621235720617984.000000, 2797157581062144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_mid.bn.act - torch.Size([1, 768, 96, 96]) grad_input[0] range: [-1310617860308992.000000, 1398578790531072.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_mid.bn.act - torch.Size([1, 768, 96, 96]) grad_output[0] range: [-2621235720617984.000000, 2797157581062144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_mid.bn.drop - torch.Size([1, 768, 96, 96]) grad_input[0] range: [-1310617860308992.000000, 1398578790531072.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_mid.bn.drop - torch.Size([1, 768, 96, 96]) grad_output[0] range: [-1310617860308992.000000, 1398578790531072.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_mid.bn - torch.Size([1, 768, 96, 96]) grad_input[0] range: [-7962364141191036928.000000, 7313845794849685504.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_mid.bn - torch.Size([1, 768, 96, 96]) grad_output[0] range: [-2621235720617984.000000, 2797157581062144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_mid.conv - torch.Size([1, 768, 192, 192]) grad_input[0] range: [-50728546202701266944.000000, 42081634918149914624.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_mid.conv - torch.Size([1, 768, 96, 96]) grad_output[0] range: [-7962364141191036928.000000, 7313845794849685504.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_mid - torch.Size([1, 768, 192, 192]) grad_input[0] range: [-50728546202701266944.000000, 42081634918149914624.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_mid - torch.Size([1, 768, 96, 96]) grad_output[0] range: [-2621235720617984.000000, 2797157581062144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_exp.bn.act - torch.Size([1, 768, 192, 192]) grad_input[0] range: [-25364273101350633472.000000, 21040817459074957312.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_exp.bn.act - torch.Size([1, 768, 192, 192]) grad_output[0] range: [-50728546202701266944.000000, 42081634918149914624.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_exp.bn.drop - torch.Size([1, 768, 192, 192]) grad_input[0] range: [-25364273101350633472.000000, 21040817459074957312.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_exp.bn.drop - torch.Size([1, 768, 192, 192]) grad_output[0] range: [-25364273101350633472.000000, 21040817459074957312.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_exp.bn - torch.Size([1, 768, 192, 192]) grad_input[0] range: [-34532304905984280625152.000000, 18446744073709551616000.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_exp.bn - torch.Size([1, 768, 192, 192]) grad_output[0] range: [-50728546202701266944.000000, 42081634918149914624.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_exp.conv - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-694187872981837846413312.000000, 774468103190621815046144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_exp.conv - torch.Size([1, 768, 192, 192]) grad_output[0] range: [-34532304905984280625152.000000, 18446744073709551616000.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_exp - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-694187872981837846413312.000000, 774468103190621815046144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.pw_exp - torch.Size([1, 768, 192, 192]) grad_output[0] range: [-50728546202701266944.000000, 42081634918149914624.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_start.bn.act - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-694187872981837846413312.000000, 774468103190621815046144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_start.bn.act - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-694187872981837846413312.000000, 774468103190621815046144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_start.bn.drop - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-694187872981837846413312.000000, 774468103190621815046144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_start.bn.drop - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-694187872981837846413312.000000, 774468103190621815046144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_start.bn - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-1721510367131231944781594624.000000, 5145188288279861767549485056.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_start.bn - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-694187872981837846413312.000000, 774468103190621815046144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_start.conv - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_start.conv - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-1721510367131231944781594624.000000, 5145188288279861767549485056.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_start - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0.dw_start - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-694187872981837846413312.000000, 774468103190621815046144.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0 - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1.0 - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-28454158336.000000, 38923141120.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1 - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.1 - torch.Size([1, 256, 96, 96]) grad_output[0] range: [-0.000740, 0.000376] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.drop_path - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.drop_path - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn2.act - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn2.act - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn2.drop - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn2.drop - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn2 - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-19727812466051820060792443633664.000000, 20440865928680199099134339186688.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn2 - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.conv_pwl - torch.Size([1, 512, 192, 192]) grad_input[0] range: [-332124457259796103192136239808512.000000, 233247710441994209875393389789184.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.conv_pwl - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-19727812466051820060792443633664.000000, 20440865928680199099134339186688.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.se - torch.Size([1, 512, 192, 192]) grad_input[0] range: [-332124457259796103192136239808512.000000, 233247710441994209875393389789184.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.se - torch.Size([1, 512, 192, 192]) grad_output[0] range: [-332124457259796103192136239808512.000000, 233247710441994209875393389789184.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.aa - torch.Size([1, 512, 192, 192]) grad_input[0] range: [-332124457259796103192136239808512.000000, 233247710441994209875393389789184.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.aa - torch.Size([1, 512, 192, 192]) grad_output[0] range: [-332124457259796103192136239808512.000000, 233247710441994209875393389789184.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn1.act - torch.Size([1, 512, 192, 192]) grad_input[0] range: [-166062228629898051596068119904256.000000, 116623855220997104937696694894592.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn1.act - torch.Size([1, 512, 192, 192]) grad_output[0] range: [-332124457259796103192136239808512.000000, 233247710441994209875393389789184.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn1.drop - torch.Size([1, 512, 192, 192]) grad_input[0] range: [-166062228629898051596068119904256.000000, 116623855220997104937696694894592.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn1.drop - torch.Size([1, 512, 192, 192]) grad_output[0] range: [-166062228629898051596068119904256.000000, 116623855220997104937696694894592.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn1 - torch.Size([1, 512, 192, 192]) grad_input[0] range: [-212884171199927932769750349498023936.000000, 258316768712107674519392192378699776.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.bn1 - torch.Size([1, 512, 192, 192]) grad_output[0] range: [-332124457259796103192136239808512.000000, 233247710441994209875393389789184.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.conv_exp - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-22098415429924226387025792377160728576.000000, 35723002386719614084289814745034260480.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2.conv_exp - torch.Size([1, 512, 192, 192]) grad_output[0] range: [-212884171199927932769750349498023936.000000, 258316768712107674519392192378699776.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2 - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-22098415429924226387025792377160728576.000000, 35723002386719614084289814745034260480.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.2 - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-32650668536151904750464401408.000000, 17640645559816668917312520192.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.1.drop_path - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-22098415429924226387025792377160728576.000000, 35723002386719614084289814745034260480.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.1.drop_path - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-22098415429924226387025792377160728576.000000, 35723002386719614084289814745034260480.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.1.bn2.act - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-22098415429924226387025792377160728576.000000, 35723002386719614084289814745034260480.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.1.bn2.act - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-22098415429924226387025792377160728576.000000, 35723002386719614084289814745034260480.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.1.bn2.drop - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-22098415429924226387025792377160728576.000000, 35723002386719614084289814745034260480.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.1.bn2.drop - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-22098415429924226387025792377160728576.000000, 35723002386719614084289814745034260480.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.1.bn2 - torch.Size([1, 128, 192, 192]) grad_input[0] range: [-inf, inf] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.1.bn2 - torch.Size([1, 128, 192, 192]) grad_output[0] range: [-22098415429924226387025792377160728576.000000, 35723002386719614084289814745034260480.000000] dtype: torch.bfloat16
vision_model.timm_model.blocks.0.1.conv_pwl - torch.Size([1, 512, 192, 192]) grad_input[0] range: [nan, nan] dtype: torch.bfloat16
NaN detected in grad_input[0] of vision_model.timm_model.blocks.0.1.conv_pwl

My environment:
GPU: 4 × NVIDIA L40S
timm 1.0.16
transformers 4.53.0
torch 2.5
CUDA 12.2

Hi,

Thanks for reaching out to us. This is an excellent, detailed log of a gradient explosion issue, which is a common problem when combining and fine-tuning models, especially across different architectures and with mixed precision like bfloat16.

The log clearly points to where the explosion is originating: the early layers of the MobileNetV5 (MNV5) vision encoder's blocks, specifically within the Batch Normalization (BN) layers' backward pass.

  1. Crucial: Freeze the running statistics of the MobileNetV5 Batch Normalization layers.
  2. Essential: Use a low global learning rate and apply gradient clipping (e.g., max_norm=1.0); a code sketch for points 1 and 2 follows this list.
  3. Highly Recommended: Increase your effective batch size (batch size × accumulation steps) to at least 32, which is good practice for both BN stability and overall training quality.
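A minimal sketch of points 1 and 2, assuming the vision tower is reachable as `model.vision_model` (adjust the attribute path and the training-loop placeholders to your setup):

```python
import torch
import torch.nn as nn

def freeze_bn_stats(module):
    # Keep every BatchNorm layer in eval mode so its running mean/var stay fixed
    # during fine-tuning (optionally also freeze the affine weight/bias).
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()

freeze_bn_stats(model.vision_model)
# Note: model.train() switches BN back to training mode, so re-apply this after it.

# Training step: low learning rate plus gradient clipping before the optimizer step.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
loss = compute_loss(batch)  # placeholder for your forward pass / loss
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```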

Thanks.

I've checked every block and found that the gradient explosion can happen anywhere in the mnv5 model. Is this an AI response?
