asigalov61 commited on
Commit
dc3be67
·
verified ·
1 Parent(s): 7aa858c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -26
app.py CHANGED
@@ -22,7 +22,7 @@ os.environ['USE_FLASH_ATTENTION'] = '1'
22
 
23
  import torch
24
 
25
- torch.set_float32_matmul_precision('medium')
26
  torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
27
  torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
28
  torch.backends.cuda.enable_mem_efficient_sdp(True)
@@ -61,11 +61,9 @@ print('=' * 70)
61
 
62
  MODEL_CHECKPOINT = 'Guided_Accompaniment_Transformer_Trained_Model_36457_steps_0.5384_loss_0.8417_acc.pth'
63
 
64
- SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
65
 
66
- MAX_MELODY_NOTES = 64
67
-
68
- MAX_GEN_TOKS = 3072
69
 
70
  #==================================================================================
71
 
@@ -126,14 +124,14 @@ print('=' * 70)
126
 
127
  #==================================================================================
128
 
129
- def load_midi(input_midi, melody_patch=-1, use_nth_note=1):
130
 
131
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
132
 
133
  escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
134
  escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
135
 
136
- sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes, keep_drums=False)
137
 
138
  if melody_patch == -1:
139
  zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
@@ -147,31 +145,22 @@ def load_midi(input_midi, melody_patch=-1, use_nth_note=1):
147
  else:
148
  zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
149
 
150
- cscore = TMIDIX.chordify_score([1000, zscore])[:MAX_MELODY_NOTES:use_nth_note]
151
 
152
  score = []
153
 
154
- score_list = []
155
-
156
  pc = cscore[0]
157
 
158
  for c in cscore:
159
  score.append(max(0, min(127, c[0][1]-pc[0][1])))
160
-
161
- scl = [[max(0, min(127, c[0][1]-pc[0][1]))]]
162
-
163
  n = c[0]
164
 
165
  score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
166
- scl.append([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
167
-
168
- score_list.append(scl)
169
 
170
  pc = c
171
-
172
- score_list.append(scl)
173
 
174
- return score, score_list
175
 
176
  #==================================================================================
177
 
@@ -179,7 +168,6 @@ def load_midi(input_midi, melody_patch=-1, use_nth_note=1):
179
  def Generate_Accompaniment(input_midi,
180
  input_melody,
181
  melody_patch,
182
- use_nth_note,
183
  model_temperature,
184
  model_sampling_top_k
185
  ):
@@ -250,7 +238,6 @@ def Generate_Accompaniment(input_midi,
250
  else:
251
  print('Input sample melody:', input_melody)
252
  print('Source melody patch:', melody_patch)
253
- print('Use nth melody note:', use_nth_note)
254
  print('Model temperature:', model_temperature)
255
  print('Model top k:', model_sampling_top_k)
256
 
@@ -433,7 +420,6 @@ with gr.Blocks() as demo:
433
  gr.Markdown("## Generation options")
434
 
435
  melody_patch = gr.Slider(-1, 127, value=-1, step=1, label="Source melody MIDI patch")
436
- use_nth_note = gr.Slider(1, 8, value=1, step=1, label="Use each nth melody note")
437
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
438
  model_sampling_top_k = gr.Slider(1, 100, value=15, step=1, label="Model sampling top k value")
439
 
@@ -450,7 +436,6 @@ with gr.Blocks() as demo:
450
  [input_midi,
451
  input_melody,
452
  melody_patch,
453
- use_nth_note,
454
  model_temperature,
455
  model_sampling_top_k
456
  ],
@@ -462,13 +447,12 @@ with gr.Blocks() as demo:
462
  )
463
 
464
  gr.Examples(
465
- [["USSR-National-Anthem-Seed-Melody.mid", "Custom MIDI", -1, 1, 0.9, 15],
466
- ["Sparks-Fly-Seed-Melody.mid", "Custom MIDI", -1, 1, 0.9, 15]
467
  ],
468
  [input_midi,
469
  input_melody,
470
  melody_patch,
471
- use_nth_note,
472
  model_temperature,
473
  model_sampling_top_k
474
  ],
 
22
 
23
  import torch
24
 
25
+ torch.set_float32_matmul_precision('high')
26
  torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
27
  torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
28
  torch.backends.cuda.enable_mem_efficient_sdp(True)
 
61
 
62
  MODEL_CHECKPOINT = 'Guided_Accompaniment_Transformer_Trained_Model_36457_steps_0.5384_loss_0.8417_acc.pth'
63
 
64
+ SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
65
 
66
+ MAX_MELODY_NOTES = 128
 
 
67
 
68
  #==================================================================================
69
 
 
124
 
125
  #==================================================================================
126
 
127
+ def load_midi(input_midi, melody_patch=-1):
128
 
129
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
130
 
131
  escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
132
  escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
133
 
134
+ sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes)
135
 
136
  if melody_patch == -1:
137
  zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
 
145
  else:
146
  zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
147
 
148
+ cscore = TMIDIX.chordify_score([1000, zscore])[:MAX_MELODY_NOTES]
149
 
150
  score = []
151
 
 
 
152
  pc = cscore[0]
153
 
154
  for c in cscore:
155
  score.append(max(0, min(127, c[0][1]-pc[0][1])))
156
+
 
 
157
  n = c[0]
158
 
159
  score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
 
 
 
160
 
161
  pc = c
 
 
162
 
163
+ return score
164
 
165
  #==================================================================================
166
 
 
168
  def Generate_Accompaniment(input_midi,
169
  input_melody,
170
  melody_patch,
 
171
  model_temperature,
172
  model_sampling_top_k
173
  ):
 
238
  else:
239
  print('Input sample melody:', input_melody)
240
  print('Source melody patch:', melody_patch)
 
241
  print('Model temperature:', model_temperature)
242
  print('Model top k:', model_sampling_top_k)
243
 
 
420
  gr.Markdown("## Generation options")
421
 
422
  melody_patch = gr.Slider(-1, 127, value=-1, step=1, label="Source melody MIDI patch")
 
423
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
424
  model_sampling_top_k = gr.Slider(1, 100, value=15, step=1, label="Model sampling top k value")
425
 
 
436
  [input_midi,
437
  input_melody,
438
  melody_patch,
 
439
  model_temperature,
440
  model_sampling_top_k
441
  ],
 
447
  )
448
 
449
  gr.Examples(
450
+ [["USSR-National-Anthem-Seed-Melody.mid", "Custom MIDI", -1, 0.9, 15],
451
+ ["Sparks-Fly-Seed-Melody.mid", "Custom MIDI", -1, 0.9, 15]
452
  ],
453
  [input_midi,
454
  input_melody,
455
  melody_patch,
 
456
  model_temperature,
457
  model_sampling_top_k
458
  ],