Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -22,7 +22,7 @@ os.environ['USE_FLASH_ATTENTION'] = '1'
|
|
| 22 |
|
| 23 |
import torch
|
| 24 |
|
| 25 |
-
torch.set_float32_matmul_precision('
|
| 26 |
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
| 27 |
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
| 28 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
|
@@ -61,11 +61,9 @@ print('=' * 70)
|
|
| 61 |
|
| 62 |
MODEL_CHECKPOINT = 'Guided_Accompaniment_Transformer_Trained_Model_36457_steps_0.5384_loss_0.8417_acc.pth'
|
| 63 |
|
| 64 |
-
|
| 65 |
|
| 66 |
-
MAX_MELODY_NOTES =
|
| 67 |
-
|
| 68 |
-
MAX_GEN_TOKS = 3072
|
| 69 |
|
| 70 |
#==================================================================================
|
| 71 |
|
|
@@ -126,14 +124,14 @@ print('=' * 70)
|
|
| 126 |
|
| 127 |
#==================================================================================
|
| 128 |
|
| 129 |
-
def load_midi(input_midi, melody_patch=-1
|
| 130 |
|
| 131 |
raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
|
| 132 |
|
| 133 |
escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
|
| 134 |
escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
|
| 135 |
|
| 136 |
-
sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes
|
| 137 |
|
| 138 |
if melody_patch == -1:
|
| 139 |
zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
|
|
@@ -147,31 +145,22 @@ def load_midi(input_midi, melody_patch=-1, use_nth_note=1):
|
|
| 147 |
else:
|
| 148 |
zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
|
| 149 |
|
| 150 |
-
cscore = TMIDIX.chordify_score([1000, zscore])[:MAX_MELODY_NOTES
|
| 151 |
|
| 152 |
score = []
|
| 153 |
|
| 154 |
-
score_list = []
|
| 155 |
-
|
| 156 |
pc = cscore[0]
|
| 157 |
|
| 158 |
for c in cscore:
|
| 159 |
score.append(max(0, min(127, c[0][1]-pc[0][1])))
|
| 160 |
-
|
| 161 |
-
scl = [[max(0, min(127, c[0][1]-pc[0][1]))]]
|
| 162 |
-
|
| 163 |
n = c[0]
|
| 164 |
|
| 165 |
score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
|
| 166 |
-
scl.append([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
|
| 167 |
-
|
| 168 |
-
score_list.append(scl)
|
| 169 |
|
| 170 |
pc = c
|
| 171 |
-
|
| 172 |
-
score_list.append(scl)
|
| 173 |
|
| 174 |
-
return score
|
| 175 |
|
| 176 |
#==================================================================================
|
| 177 |
|
|
@@ -179,7 +168,6 @@ def load_midi(input_midi, melody_patch=-1, use_nth_note=1):
|
|
| 179 |
def Generate_Accompaniment(input_midi,
|
| 180 |
input_melody,
|
| 181 |
melody_patch,
|
| 182 |
-
use_nth_note,
|
| 183 |
model_temperature,
|
| 184 |
model_sampling_top_k
|
| 185 |
):
|
|
@@ -250,7 +238,6 @@ def Generate_Accompaniment(input_midi,
|
|
| 250 |
else:
|
| 251 |
print('Input sample melody:', input_melody)
|
| 252 |
print('Source melody patch:', melody_patch)
|
| 253 |
-
print('Use nth melody note:', use_nth_note)
|
| 254 |
print('Model temperature:', model_temperature)
|
| 255 |
print('Model top k:', model_sampling_top_k)
|
| 256 |
|
|
@@ -433,7 +420,6 @@ with gr.Blocks() as demo:
|
|
| 433 |
gr.Markdown("## Generation options")
|
| 434 |
|
| 435 |
melody_patch = gr.Slider(-1, 127, value=-1, step=1, label="Source melody MIDI patch")
|
| 436 |
-
use_nth_note = gr.Slider(1, 8, value=1, step=1, label="Use each nth melody note")
|
| 437 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
| 438 |
model_sampling_top_k = gr.Slider(1, 100, value=15, step=1, label="Model sampling top k value")
|
| 439 |
|
|
@@ -450,7 +436,6 @@ with gr.Blocks() as demo:
|
|
| 450 |
[input_midi,
|
| 451 |
input_melody,
|
| 452 |
melody_patch,
|
| 453 |
-
use_nth_note,
|
| 454 |
model_temperature,
|
| 455 |
model_sampling_top_k
|
| 456 |
],
|
|
@@ -462,13 +447,12 @@ with gr.Blocks() as demo:
|
|
| 462 |
)
|
| 463 |
|
| 464 |
gr.Examples(
|
| 465 |
-
[["USSR-National-Anthem-Seed-Melody.mid", "Custom MIDI", -1,
|
| 466 |
-
["Sparks-Fly-Seed-Melody.mid", "Custom MIDI", -1,
|
| 467 |
],
|
| 468 |
[input_midi,
|
| 469 |
input_melody,
|
| 470 |
melody_patch,
|
| 471 |
-
use_nth_note,
|
| 472 |
model_temperature,
|
| 473 |
model_sampling_top_k
|
| 474 |
],
|
|
|
|
| 22 |
|
| 23 |
import torch
|
| 24 |
|
| 25 |
+
torch.set_float32_matmul_precision('high')
|
| 26 |
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
| 27 |
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
| 28 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
|
|
|
| 61 |
|
| 62 |
MODEL_CHECKPOINT = 'Guided_Accompaniment_Transformer_Trained_Model_36457_steps_0.5384_loss_0.8417_acc.pth'
|
| 63 |
|
| 64 |
+
SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
|
| 65 |
|
| 66 |
+
MAX_MELODY_NOTES = 128
|
|
|
|
|
|
|
| 67 |
|
| 68 |
#==================================================================================
|
| 69 |
|
|
|
|
| 124 |
|
| 125 |
#==================================================================================
|
| 126 |
|
| 127 |
+
def load_midi(input_midi, melody_patch=-1):
|
| 128 |
|
| 129 |
raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
|
| 130 |
|
| 131 |
escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
|
| 132 |
escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
|
| 133 |
|
| 134 |
+
sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes)
|
| 135 |
|
| 136 |
if melody_patch == -1:
|
| 137 |
zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
|
|
|
|
| 145 |
else:
|
| 146 |
zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
|
| 147 |
|
| 148 |
+
cscore = TMIDIX.chordify_score([1000, zscore])[:MAX_MELODY_NOTES]
|
| 149 |
|
| 150 |
score = []
|
| 151 |
|
|
|
|
|
|
|
| 152 |
pc = cscore[0]
|
| 153 |
|
| 154 |
for c in cscore:
|
| 155 |
score.append(max(0, min(127, c[0][1]-pc[0][1])))
|
| 156 |
+
|
|
|
|
|
|
|
| 157 |
n = c[0]
|
| 158 |
|
| 159 |
score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
pc = c
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
return score
|
| 164 |
|
| 165 |
#==================================================================================
|
| 166 |
|
|
|
|
| 168 |
def Generate_Accompaniment(input_midi,
|
| 169 |
input_melody,
|
| 170 |
melody_patch,
|
|
|
|
| 171 |
model_temperature,
|
| 172 |
model_sampling_top_k
|
| 173 |
):
|
|
|
|
| 238 |
else:
|
| 239 |
print('Input sample melody:', input_melody)
|
| 240 |
print('Source melody patch:', melody_patch)
|
|
|
|
| 241 |
print('Model temperature:', model_temperature)
|
| 242 |
print('Model top k:', model_sampling_top_k)
|
| 243 |
|
|
|
|
| 420 |
gr.Markdown("## Generation options")
|
| 421 |
|
| 422 |
melody_patch = gr.Slider(-1, 127, value=-1, step=1, label="Source melody MIDI patch")
|
|
|
|
| 423 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
| 424 |
model_sampling_top_k = gr.Slider(1, 100, value=15, step=1, label="Model sampling top k value")
|
| 425 |
|
|
|
|
| 436 |
[input_midi,
|
| 437 |
input_melody,
|
| 438 |
melody_patch,
|
|
|
|
| 439 |
model_temperature,
|
| 440 |
model_sampling_top_k
|
| 441 |
],
|
|
|
|
| 447 |
)
|
| 448 |
|
| 449 |
gr.Examples(
|
| 450 |
+
[["USSR-National-Anthem-Seed-Melody.mid", "Custom MIDI", -1, 0.9, 15],
|
| 451 |
+
["Sparks-Fly-Seed-Melody.mid", "Custom MIDI", -1, 0.9, 15]
|
| 452 |
],
|
| 453 |
[input_midi,
|
| 454 |
input_melody,
|
| 455 |
melody_patch,
|
|
|
|
| 456 |
model_temperature,
|
| 457 |
model_sampling_top_k
|
| 458 |
],
|