kvaishnavi committed
Commit 3749258 · 1 Parent(s): 5c58d1e

Add WebGPU as an option

Files changed (1)
  onnx/builder.py  +14 -12
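In short, this commit adds "webgpu" as an --execution_provider choice and centralizes the provider-to-PyTorch-device mapping in a new args.device attribute, which the export code then uses everywhere it previously called args.execution_provider.replace("dml", "cuda"). The sketch below only restates that mapping for clarity; the helper name export_device is hypothetical and not part of the file. Since PyTorch has no native "dml" or "webgpu" device, both providers fall back to CUDA for the export pass.

def export_device(execution_provider: str) -> str:
    # Hypothetical helper restating the mapping added in get_args():
    # execution providers with no matching PyTorch device ("dml", "webgpu")
    # run the TorchScript export on CUDA instead.
    return execution_provider.replace("dml", "cuda").replace("webgpu", "cuda")

print(export_device("webgpu"))  # -> "cuda"
print(export_device("cpu"))     # -> "cpu"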
onnx/builder.py CHANGED

@@ -31,7 +31,7 @@ def build_vision(args):
     url = "https://wallpaper.dog/large/10809054.jpg"
     image_4 = Image.open(requests.get(url, stream=True).raw)
     images = [image_1, image_2, image_3, image_4]
-    inputs = processor(prompt, images=images, return_tensors="pt").to(args.execution_provider.replace("dml", "cuda"))
+    inputs = processor(prompt, images=images, return_tensors="pt").to(args.device)
     inputs["input_image_embeds"] = inputs["input_image_embeds"].to(args.precision)
     inputs["image_attention_mask"] = inputs["image_attention_mask"].to(args.precision)

@@ -110,7 +110,7 @@ def build_vision(args):
         "--output_model", fpath_4,
         "--block_size", str(32),
     ]
-    if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)])
+    if args.precision == torch.float32 or args.execution_provider == "webgpu": cmd.extend(["--accuracy_level", str(4)])
     subprocess.run(cmd)
     shutil.rmtree(temp_folder_3)

@@ -120,7 +120,7 @@ def build_speech(args):
     prompt = f"{user_prompt}<|audio_1|>\n<|audio_2|>\nWhat are the stories that these audios come from?{prompt_suffix}{assistant_prompt}"
     audio1 = soundfile.read(os.path.join(args.input, "examples", "what_is_the_traffic_sign_in_the_image.wav"))
     audio2 = soundfile.read(os.path.join(args.input, "examples", "what_is_shown_in_this_image.wav"))
-    inputs = processor(prompt, audios=[audio1, audio2], return_tensors="pt").to(args.execution_provider.replace("dml", "cuda"))
+    inputs = processor(prompt, audios=[audio1, audio2], return_tensors="pt").to(args.device)
     inputs["input_audio_embeds"] = inputs["input_audio_embeds"].to(args.precision)

     # TorchScript export

@@ -232,7 +232,7 @@ def build_speech(args):
         "--output_model", fpath_5,
         "--block_size", str(32),
     ]
-    if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)])
+    if args.precision == torch.float32 or args.execution_provider == "webgpu": cmd.extend(["--accuracy_level", str(4)])
     subprocess.run(cmd)
     shutil.rmtree(temp_folder_4)

@@ -241,9 +241,9 @@ def build_embedding(args):
     # TorchScript export
     batch_size, sequence_length, num_image_tokens, num_audio_tokens = 2, 8, 2, 2
     inputs = {
-        "input_ids": torch.randint(low=0, high=config.vocab_size, size=(batch_size, sequence_length), device=args.execution_provider.replace("dml", "cuda"), dtype=torch.int64),
-        "image_features": torch.randn(num_image_tokens, config.hidden_size, device=args.execution_provider.replace("dml", "cuda"), dtype=args.precision),
-        "audio_features": torch.randn(num_audio_tokens, config.hidden_size, device=args.execution_provider.replace("dml", "cuda"), dtype=args.precision),
+        "input_ids": torch.randint(low=0, high=config.vocab_size, size=(batch_size, sequence_length), device=args.device, dtype=torch.int64),
+        "image_features": torch.randn(num_image_tokens, config.hidden_size, device=args.device, dtype=args.precision),
+        "audio_features": torch.randn(num_audio_tokens, config.hidden_size, device=args.device, dtype=args.precision),
     }
     inputs["input_ids"][0][0] = -1
     inputs["input_ids"][0][1] = -1

@@ -302,8 +302,9 @@ def build_text(args):
     extra_options = {
         "exclude_embeds": "true",
         "filename": "phi-4-mm-text.onnx",
+        "int4_algo_config": "k_quant_last",
     }
-    if args.precision == torch.float32: extra_options["int4_accuracy_level"] = 4
+    if args.precision == torch.float32 or args.execution_provider == "webgpu": extra_options["int4_accuracy_level"] = 4
     create_model(model_name, args.input, args.output, precision, args.execution_provider, args.cache_dir, **extra_options)

@@ -533,7 +534,7 @@ def build_quantized_adapters(args):
         "--output_model", fpath_3,
         "--block_size", str(32),
     ]
-    if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)])
+    if args.precision == torch.float32 or args.execution_provider == "webgpu": cmd.extend(["--accuracy_level", str(4)])
     subprocess.run(cmd)

     filename = "phi-4-mm-qlora-speech.onnx"

@@ -544,7 +545,7 @@ def build_quantized_adapters(args):
         "--output_model", fpath_4,
         "--block_size", str(32),
     ]
-    if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)])
+    if args.precision == torch.float32 or args.execution_provider == "webgpu": cmd.extend(["--accuracy_level", str(4)])
     subprocess.run(cmd)

     os.remove(fpath_1)

@@ -594,7 +595,7 @@ def get_args():
         "-e",
         "--execution_provider",
         required=True,
-        choices=["cpu", "cuda", "dml"],
+        choices=["cpu", "cuda", "dml", "webgpu"],
         help="Execution provider for Phi-4 multimodal components",
     )

@@ -608,6 +609,7 @@ def get_args():

     args = parser.parse_args()
     args.precision = torch.float16 if args.precision == "fp16" else torch.float32
+    setattr(args, "device", args.execution_provider.replace("dml", "cuda").replace("webgpu", "cuda"))
     return args

 if __name__ == "__main__":

@@ -618,7 +620,7 @@ if __name__ == "__main__":
     args = get_args()
     config = AutoConfig.from_pretrained(args.input, trust_remote_code=True)
     processor = AutoProcessor.from_pretrained(args.input, trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(args.input, trust_remote_code=True, torch_dtype=args.precision).to(args.execution_provider.replace("dml", "cuda"))
+    model = AutoModelForCausalLM.from_pretrained(args.input, trust_remote_code=True, torch_dtype=args.precision).to(args.device)

     # Build model components
     build_vision(args)
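A note on the condition repeated above: the int4 quantization accuracy level is now forced to 4 not only for fp32 builds but also whenever the target execution provider is webgpu; as I understand ONNX Runtime's 4-bit MatMul quantizer options, accuracy level 4 selects int8 compute, which is presumably the intent for WebGPU targets. The sketch below only factors out that shared predicate; the helper name needs_accuracy_level_4 is hypothetical and does not exist in the file.

import torch

def needs_accuracy_level_4(precision: torch.dtype, execution_provider: str) -> bool:
    # Hypothetical helper capturing the predicate this commit repeats in
    # build_vision, build_speech, build_text, and build_quantized_adapters.
    return precision == torch.float32 or execution_provider == "webgpu"

# Example mirroring how the quantization command is extended in the diff:
cmd = ["--block_size", str(32)]
if needs_accuracy_level_4(torch.float16, "webgpu"):
    cmd.extend(["--accuracy_level", str(4)])
print(cmd)  # ['--block_size', '32', '--accuracy_level', '4']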