yangzhitao commited on
Commit
1e0ed83
·
1 Parent(s): c94136a

refactor: improve benchmark handling in create_submit_tab function by restructuring input processing and enhancing data validation

Browse files
Files changed (1) hide show
  1. app.py +61 -30
app.py CHANGED
@@ -1,4 +1,10 @@
 
 
 
1
  import threading
 
 
 
2
 
3
  import gradio as gr
4
  import gradio.components as grc
@@ -139,8 +145,6 @@ def search_models_in_dataframe(search_text: str, df: pd.DataFrame) -> pd.DataFra
139
  return df
140
 
141
  # 分割逗号,去除空白并转换为小写用于匹配
142
- import re
143
-
144
  keywords = [keyword.strip().lower() for keyword in search_text.split(',') if keyword.strip()]
145
  if not keywords:
146
  return df
@@ -493,8 +497,6 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
493
  if file is None:
494
  return ""
495
  try:
496
- import json
497
-
498
  # file 是文件路径字符串(当 type="filepath" 时)
499
  file_path = file if isinstance(file, str) else file.name
500
  with open(file_path, encoding='utf-8') as f:
@@ -532,12 +534,10 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
532
  model_name: str,
533
  revision: str,
534
  precision: str,
535
- benchmark_checkbox_values: list,
536
- benchmark_result_values: list,
537
  ) -> str:
538
  """Build JSON from form inputs"""
539
- import json
540
-
541
  if not model_name or not model_name.strip():
542
  raise ValueError("Model name is required")
543
 
@@ -549,7 +549,7 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
549
  "model_name": model_name,
550
  "model_key": model_key,
551
  "model_dtype": f"torch.{precision}" if precision else None,
552
- "model_sha": revision or "main",
553
  "model_args": None,
554
  }
555
 
@@ -568,6 +568,21 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
568
 
569
  return json.dumps({"config": config, "results": results}, indent=2, ensure_ascii=False)
570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  def submit_with_form_or_json(
572
  model: str,
573
  base_model: str,
@@ -577,12 +592,11 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
577
  model_type: str,
578
  json_str: str,
579
  commit_message: str,
580
- oauth_profile: gr.OAuthProfile | None = None,
581
- *benchmark_values,
 
582
  ):
583
  """Submit with either form data or JSON"""
584
- import json
585
-
586
  # Check if user is logged in
587
  if oauth_profile is None:
588
  return styled_error("Please log in before submitting.")
@@ -604,12 +618,25 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
604
  # Build JSON from form
605
  # benchmark_values contains pairs of (checkbox_value, result_value) for each benchmark
606
  benchmarks_list = get_benchmarks()
607
- if len(benchmark_values) != len(benchmarks_list) * 2:
 
 
 
 
 
 
 
 
608
  return styled_error("Invalid benchmark form data. Please check your inputs.")
609
 
610
  # Split into checkbox values and result values
611
- benchmark_checkbox_values = [benchmark_values[i] for i in range(0, len(benchmark_values), 2)]
612
- benchmark_result_values = [benchmark_values[i] for i in range(1, len(benchmark_values), 2)]
 
 
 
 
 
613
 
614
  try:
615
  final_json = build_json_from_form(
@@ -664,24 +691,28 @@ def create_submit_tab(tab_id: int, demo: gr.Blocks):
664
  submission_result = gr.Markdown()
665
 
666
  # Collect all inputs for submission
667
- all_inputs = [
668
- model_name_textbox,
669
- base_model_name_textbox,
670
- revision_name_textbox,
671
- precision,
672
- weight_type,
673
- model_type,
674
- json_str,
675
- commit_textbox,
676
- login_button, # oauth_profile must be before *benchmark_values
677
- ]
678
  # Add benchmark form inputs (these will be captured by *benchmark_values)
 
679
  for _, checkbox, result_input in benchmark_results_form:
680
- all_inputs.extend([checkbox, result_input])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681
 
682
  submit_button.click(
683
- fn=submit_with_form_or_json,
684
- inputs=all_inputs,
685
  outputs=submission_result,
686
  )
687
 
 
1
+ import json
2
+ import re
3
+ import sys
4
  import threading
5
+ from collections import namedtuple
6
+ from functools import partial
7
+ from textwrap import dedent
8
 
9
  import gradio as gr
10
  import gradio.components as grc
 
145
  return df
146
 
147
  # 分割逗号,去除空白并转换为小写用于匹配
 
 
148
  keywords = [keyword.strip().lower() for keyword in search_text.split(',') if keyword.strip()]
149
  if not keywords:
150
  return df
 
497
  if file is None:
498
  return ""
499
  try:
 
 
500
  # file 是文件路径字符串(当 type="filepath" 时)
501
  file_path = file if isinstance(file, str) else file.name
502
  with open(file_path, encoding='utf-8') as f:
 
534
  model_name: str,
535
  revision: str,
536
  precision: str,
537
+ benchmark_checkbox_values: list[bool],
538
+ benchmark_result_values: list[float],
539
  ) -> str:
540
  """Build JSON from form inputs"""
 
 
541
  if not model_name or not model_name.strip():
542
  raise ValueError("Model name is required")
543
 
 
549
  "model_name": model_name,
550
  "model_key": model_key,
551
  "model_dtype": f"torch.{precision}" if precision else None,
552
+ "model_sha": revision or None, # None means "main"
553
  "model_args": None,
554
  }
555
 
 
568
 
569
  return json.dumps({"config": config, "results": results}, indent=2, ensure_ascii=False)
570
 
571
+ SubmitWithFormOrJsonInputs = namedtuple(
572
+ "SubmitWithFormOrJsonInputs",
573
+ [
574
+ "model",
575
+ "base_model",
576
+ "revision",
577
+ "precision",
578
+ "weight_type",
579
+ "model_type",
580
+ "json_str",
581
+ "commit_message",
582
+ # "oauth_profile",
583
+ ],
584
+ )
585
+
586
  def submit_with_form_or_json(
587
  model: str,
588
  base_model: str,
 
592
  model_type: str,
593
  json_str: str,
594
  commit_message: str,
595
+ oauth_profile: gr.OAuthProfile,
596
+ *,
597
+ benchmark_values: list[bool | float],
598
  ):
599
  """Submit with either form data or JSON"""
 
 
600
  # Check if user is logged in
601
  if oauth_profile is None:
602
  return styled_error("Please log in before submitting.")
 
618
  # Build JSON from form
619
  # benchmark_values contains pairs of (checkbox_value, result_value) for each benchmark
620
  benchmarks_list = get_benchmarks()
621
+ if len(benchmark_values) != len(benchmarks_list):
622
+ print(
623
+ dedent(f"""
624
+ Invalid benchmark form data. Please check your inputs.
625
+ * benchmarks_list: {benchmarks_list!r}
626
+ * benchmark_values: {benchmark_values!r}
627
+ """),
628
+ file=sys.stderr,
629
+ )
630
  return styled_error("Invalid benchmark form data. Please check your inputs.")
631
 
632
  # Split into checkbox values and result values
633
+ benchmark_checkbox_values: list[bool] = []
634
+ benchmark_result_values: list[float] = []
635
+ for i, val in enumerate(benchmark_values):
636
+ if i % 2 == 0:
637
+ benchmark_checkbox_values.append(bool(val))
638
+ else:
639
+ benchmark_result_values.append(float(val)) # pyright: ignore[reportArgumentType]
640
 
641
  try:
642
  final_json = build_json_from_form(
 
691
  submission_result = gr.Markdown()
692
 
693
  # Collect all inputs for submission
 
 
 
 
 
 
 
 
 
 
 
694
  # Add benchmark form inputs (these will be captured by *benchmark_values)
695
+ benchmark_values = []
696
  for _, checkbox, result_input in benchmark_results_form:
697
+ benchmark_values.extend([checkbox.value, result_input.value])
698
+
699
+ all_inputs = list(
700
+ SubmitWithFormOrJsonInputs(
701
+ model=model_name_textbox,
702
+ base_model=base_model_name_textbox,
703
+ revision=revision_name_textbox,
704
+ precision=precision,
705
+ weight_type=weight_type,
706
+ model_type=model_type,
707
+ json_str=json_str,
708
+ commit_message=commit_textbox,
709
+ # oauth_profile=login_button,
710
+ )
711
+ )
712
 
713
  submit_button.click(
714
+ fn=partial(submit_with_form_or_json, benchmark_values=benchmark_values),
715
+ inputs=list(all_inputs),
716
  outputs=submission_result,
717
  )
718