prismleong committed · commit 32a5366 (1 parent: fc38cab)

update

Changed files:
- app.py (+73 -34)
- infer.py (+256 -22)
- rift_svc/rf.py (+123 -33)
app.py
CHANGED
@@ -14,7 +14,8 @@ from infer import (
     load_models,
     load_audio,
     apply_fade,
-    process_segment
+    process_segment,
+    batch_process_segments
 )
 
 # Global variables for models

@@ -115,7 +116,9 @@ def process_with_progress(
     slicer_min_length=3000,
     slicer_min_interval=100,
     slicer_hop_size=10,
-    slicer_max_sil_kept=200
+    slicer_max_sil_kept=200,
+    # Batch processing
+    batch_size=1
 ):
     global svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg
 

@@ -182,41 +185,71 @@ def process_with_progress(
         progress(0.2, desc="处理中: 开始转换...")
 
         with torch.no_grad():
-            for i, (start_sample, chunk) in enumerate(segments_with_pos):
-                segment_progress = 0.2 + (0.7 * (i / len(segments_with_pos)))
-                progress(segment_progress, desc=f"处理中: 片段 {i+1}/{len(segments_with_pos)}")
-
-                # Process the segment
-                audio_out = process_segment(
-                    chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
+            if batch_size > 1:
+                # Use batch processing
+                progress_desc = f"处理中: 批次 {{0}}/{{1}}"
+                processed_segments = batch_process_segments(
+                    segments_with_pos, svc_model, vocoder, rmvpe, hubert, rms_extractor,
                     speaker_id, sample_rate, hop_length, device,
                     key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
                     skip_cfg_strength, cfg_skip_layers, cfg_rescale,
                     cvec_downsample_rate, target_loudness, restore_loudness,
-                    robust_f0, use_fp16
+                    robust_f0, use_fp16, batch_size, progress, progress_desc
                 )
 
-                # Ensure consistent length
-                expected_length = len(chunk)
-                if len(audio_out) > expected_length:
-                    audio_out = audio_out[:expected_length]
-                elif len(audio_out) < expected_length:
-                    audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')
-
-                # Apply fades
-                if i > 0:  # Not first segment
-                    audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
-                    result_audio[start_sample:start_sample + fade_samples] *= \
-                        np.linspace(1, 0, fade_samples)  # Fade out previous
-
-                if i < len(segments_with_pos) - 1:  # Not last segment
-                    audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
-
-                # Add to result
-                result_audio[start_sample:start_sample + len(audio_out)] += audio_out
-
-                # Clean up memory after each segment
-                torch.cuda.empty_cache()
+                for idx, (start_sample, audio_out, expected_length) in enumerate(processed_segments):
+                    # Apply fades
+                    if idx > 0:  # Not first segment
+                        audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
+                        result_audio[start_sample:start_sample + fade_samples] *= \
+                            np.linspace(1, 0, fade_samples)  # Fade out previous
+
+                    if idx < len(processed_segments) - 1:  # Not last segment
+                        audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
+
+                    # Add to result
+                    result_audio[start_sample:start_sample + len(audio_out)] += audio_out
+
+                    # Clean up memory after each segment
+                    if idx % 5 == 0:  # Clean up every 5 segments
+                        torch.cuda.empty_cache()
+            else:
+                # Use sequential processing
+                for i, (start_sample, chunk) in enumerate(segments_with_pos):
+                    segment_progress = 0.2 + (0.7 * (i / len(segments_with_pos)))
+                    progress(segment_progress, desc=f"处理中: 片段 {i+1}/{len(segments_with_pos)}")
+
+                    # Process the segment
+                    audio_out = process_segment(
+                        chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
+                        speaker_id, sample_rate, hop_length, device,
+                        key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
+                        skip_cfg_strength, cfg_skip_layers, cfg_rescale,
+                        cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
+                        robust_f0, use_fp16
+                    )
+
+                    # Ensure consistent length
+                    expected_length = len(chunk)
+                    if len(audio_out) > expected_length:
+                        audio_out = audio_out[:expected_length]
+                    elif len(audio_out) < expected_length:
+                        audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')
+
+                    # Apply fades
+                    if i > 0:  # Not first segment
+                        audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
+                        result_audio[start_sample:start_sample + fade_samples] *= \
+                            np.linspace(1, 0, fade_samples)  # Fade out previous
+
+                    if i < len(segments_with_pos) - 1:  # Not last segment
+                        audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
+
+                    # Add to result
+                    result_audio[start_sample:start_sample + len(audio_out)] += audio_out
+
+                    # Clean up memory after each segment
+                    torch.cuda.empty_cache()
 
         progress(0.9, desc="处理中: 完成音频...")
         # Trim any extra padding

@@ -230,7 +263,8 @@ def process_with_progress(
         torchaudio.save(output_path, torch.from_numpy(result_audio).unsqueeze(0).float(), sample_rate)
 
         progress(1.0, desc="处理完成!")
-        return (sample_rate, result_audio), f"✅ 转换完成! 已转换为 **{speaker}** 并调整 **{key_shift}** 个半音。"
+        batch_text = f"批处理大小 {batch_size}" if batch_size > 1 else "顺序处理"
+        return (sample_rate, result_audio), f"✅ 转换完成! 已转换为 **{speaker}** 并调整 **{key_shift}** 个半音。{batch_text}"
 
     except RuntimeError as e:
         # Handle CUDA out of memory errors

@@ -341,6 +375,9 @@ def create_ui():
                     robust_f0 = gr.Radio(choices=[0, 1, 2], value=1, label="音高滤波",
                                          info="0=无,1=轻度过滤,2=强力过滤(有助于解决断音/破音问题)",
                                          elem_id="robust_f0")
+                    batch_size = gr.Slider(minimum=1, maximum=64, step=1, value=4, label="批处理大小",
+                                           info="使用批处理可以加速转换,但需要更多VRAM。1=不使用批处理",
+                                           elem_id="batch_size")
 
                 with gr.Accordion("🔬 高级CFG参数", open=True):
                     ds_cfg_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2,

@@ -403,6 +440,7 @@ def create_ui():
                 <li><strong>音调调整:</strong> 以半音为单位上调或下调音高。</li>
                 <li><strong>推理步骤:</strong> 步骤越多 = 质量越好但速度越慢。</li>
                 <li><strong>音高滤波:</strong> 有助于提高具有挑战性的音频中的音高稳定性。</li>
+                <li><strong>批处理大小:</strong> 值越大 = 转换越快,但需要更多GPU内存。遇到内存不足时降低此值。</li>
                 <li><strong>CFG参数:</strong> 调整转换质量和音色。</li>
             </ul>
         </div>

@@ -426,7 +464,8 @@ def create_ui():
         inputs=[
             input_audio, speaker, key_shift, infer_steps, robust_f0,
             ds_cfg_strength, spk_cfg_strength, skip_cfg_strength, cfg_skip_layers, cfg_rescale, cvec_downsample_rate,
-            slicer_threshold, slicer_min_length, slicer_min_interval, slicer_hop_size, slicer_max_sil_kept
+            slicer_threshold, slicer_min_length, slicer_min_interval, slicer_hop_size, slicer_max_sil_kept,
+            batch_size
         ],
         outputs=[output_audio, output_message],
         show_progress_on=output_audio
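Note: both branches above stitch converted segments back together with the same linear crossfade. A minimal standalone sketch of that overlap-add logic, with a simplified signature and toy inputs (stitch_segments and its arguments are illustrative, not the repo's API):

import numpy as np

def stitch_segments(segments, total_len, fade_samples):
    """segments: list of (start_sample, mono ndarray) pairs in timeline order."""
    out = np.zeros(total_len)
    for idx, (start, seg) in enumerate(segments):
        seg = seg.copy()
        if idx > 0:  # not the first segment: crossfade with what is already written
            seg[:fade_samples] *= np.linspace(0, 1, fade_samples)               # fade in
            out[start:start + fade_samples] *= np.linspace(1, 0, fade_samples)  # fade out previous tail
        if idx < len(segments) - 1:  # not the last segment
            seg[-fade_samples:] *= np.linspace(1, 0, fade_samples)              # fade out own tail
        out[start:start + len(seg)] += seg
    return out

# Example: two 1-second segments at 44.1 kHz overlapping by a 20 ms fade window.
sr = 44100
fade = int(0.02 * sr)
mix = stitch_segments([(0, np.ones(sr)), (sr - fade, np.ones(sr))], 2 * sr, fade)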
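The new batch_size slider only takes effect because it is appended to the click handler's inputs list; Gradio passes component values to the callback positionally, so the slider must sit at the position that batch_size occupies in process_with_progress's signature. A reduced sketch of that wiring (the convert function below is a hypothetical stand-in, not the app's real handler):

import gradio as gr

def convert(audio_path, batch_size=1):  # hypothetical reduced signature
    return f"would convert {audio_path} with batch_size={batch_size}"

with gr.Blocks() as demo:
    input_audio = gr.Audio(type="filepath")
    batch_size = gr.Slider(minimum=1, maximum=64, step=1, value=4, label="Batch size")
    output_message = gr.Textbox()
    gr.Button("Convert").click(convert, inputs=[input_audio, batch_size], outputs=output_message)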
infer.py
CHANGED
@@ -184,12 +184,31 @@ def run_inference(
     model, mel, cvec, f0, rms, cvec_ds, spk_id,
     infer_steps, ds_cfg_strength, spk_cfg_strength,
     skip_cfg_strength, cfg_skip_layers, cfg_rescale,
-    sliced_inference=False, use_fp16=True
+    sliced_inference=False, use_fp16=True, frame_lengths=None
 ):
     """Run the actual inference through the model"""
     device_type = 'cuda' if mel.device.type == 'cuda' else 'cpu'
 
-    if sliced_inference:
+    if frame_lengths is not None:
+        # Use batch inference with frame lengths
+        with autocast(device_type=device_type, enabled=use_fp16):
+            mel_out, _ = model.sample(
+                src_mel=mel,
+                spk_id=spk_id,
+                f0=f0,
+                rms=rms,
+                cvec=cvec,
+                steps=infer_steps,
+                bad_cvec=cvec_ds,
+                ds_cfg_strength=ds_cfg_strength,
+                spk_cfg_strength=spk_cfg_strength,
+                skip_cfg_strength=skip_cfg_strength,
+                cfg_skip_layers=cfg_skip_layers,
+                cfg_rescale=cfg_rescale,
+                frame_len=frame_lengths,
+            )
+        return mel_out
+    elif sliced_inference:
         # Use sliced inference for long segments
         sliced_len = 256
         mel_crossfade_len = 8  # Number of frames to crossfade in mel domain

@@ -392,6 +411,191 @@ def process_segment(
     return audio_out
 
 
+def pad_tensor_to_length(tensor, length):
+    """Pad a tensor to the specified length along the sequence dimension (dim=1)"""
+    curr_len = tensor.shape[1]
+    if curr_len >= length:
+        return tensor
+
+    pad_len = length - curr_len
+
+    if tensor.dim() == 2:
+        padding = (0, pad_len)
+    elif tensor.dim() == 3:
+        padding = (0, 0, 0, pad_len)
+    else:
+        raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")
+
+    padded = torch.nn.functional.pad(tensor, padding, "constant", 0)
+    return padded
+
+
+def batch_process_segments(
+    segments_with_pos,
+    svc_model, vocoder, rmvpe, hubert, rms_extractor,
+    speaker_id, sample_rate, hop_length, device,
+    key_shift=0,
+    infer_steps=32,
+    ds_cfg_strength=0.0,
+    spk_cfg_strength=0.0,
+    skip_cfg_strength=0.0,
+    cfg_skip_layers=None,
+    cfg_rescale=0.7,
+    cvec_downsample_rate=2,
+    target_loudness=-18.0,
+    restore_loudness=True,
+    robust_f0=0,
+    use_fp16=True,
+    batch_size=1,
+    gr_progress=None,
+    progress_desc=None
+):
+    """Process audio segments in batches for faster inference"""
+    if batch_size <= 1:
+        results = []
+        for i, (start_sample, chunk) in enumerate(tqdm(segments_with_pos, desc="Processing segments")):
+            if gr_progress is not None:
+                gr_progress(0.2 + (0.7 * (i / len(segments_with_pos))), desc=progress_desc.format(i+1, len(segments_with_pos)))
+            audio_out = process_segment(
+                chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
+                speaker_id, sample_rate, hop_length, device,
+                key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
+                skip_cfg_strength, cfg_skip_layers, cfg_rescale,
+                cvec_downsample_rate, target_loudness, restore_loudness,
+                robust_f0, use_fp16
+            )
+            results.append((start_sample, audio_out, len(chunk)))
+        return results
+
+    sorted_with_idx = sorted(enumerate(segments_with_pos), key=lambda x: len(x[1][1]))
+    sorted_segments = []
+    original_indices = []
+
+    for orig_idx, (pos, chunk) in sorted_with_idx:
+        original_indices.append(orig_idx)
+        sorted_segments.append((pos, chunk))
+
+    batched_segments = [sorted_segments[i:i + batch_size] for i in range(0, len(sorted_segments), batch_size)]
+
+    all_results = []
+
+    for batch_idx, batch in enumerate(tqdm(batched_segments, desc="Processing batches")):
+        if gr_progress is not None:
+            gr_progress(
+                0.2 + (0.7 * (batch_idx / len(batched_segments))),
+                desc=progress_desc.format(batch_idx+1, len(batched_segments)))
+
+        batch_start_samples = [pos for pos, _ in batch]
+        batch_chunks = [chunk for _, chunk in batch]
+        batch_lengths = [len(chunk) for chunk in batch_chunks]
+
+        batch_features = []
+        for chunk in batch_chunks:
+            mel, cvec, cvec_ds, f0, rms, original_loudness = extract_features(
+                chunk, sample_rate, hop_length, rmvpe, hubert, rms_extractor,
+                device, key_shift, ds_cfg_strength, cvec_downsample_rate, target_loudness,
+                robust_f0, use_fp16
+            )
+            batch_features.append({
+                'mel': mel,
+                'cvec': cvec,
+                'cvec_ds': cvec_ds,
+                'f0': f0,
+                'rms': rms,
+                'original_loudness': original_loudness,
+                'length': mel.shape[1]
+            })
+
+        max_length = max(feat['length'] for feat in batch_features)
+
+        padded_mels = []
+        padded_cvecs = []
+        padded_f0s = []
+        padded_rmss = []
+        frame_lengths = []
+        original_loudness_values = []
+
+        if ds_cfg_strength > 0:
+            padded_cvec_ds = []
+
+        for feat in batch_features:
+            curr_len = feat['length']
+            frame_lengths.append(curr_len)
+
+            padded_mels.append(pad_tensor_to_length(feat['mel'], max_length))
+            padded_cvecs.append(pad_tensor_to_length(feat['cvec'], max_length))
+            padded_f0s.append(pad_tensor_to_length(feat['f0'], max_length))
+            padded_rmss.append(pad_tensor_to_length(feat['rms'], max_length))
+
+            if ds_cfg_strength > 0:
+                padded_cvec_ds.append(pad_tensor_to_length(feat['cvec_ds'], max_length))
+
+            original_loudness_values.append(feat['original_loudness'])
+
+        batched_mel = torch.cat(padded_mels, dim=0)
+        batched_cvec = torch.cat(padded_cvecs, dim=0)
+        batched_f0 = torch.cat(padded_f0s, dim=0)
+        batched_rms = torch.cat(padded_rmss, dim=0)
+
+        if ds_cfg_strength > 0:
+            batched_cvec_ds = torch.cat(padded_cvec_ds, dim=0)
+        else:
+            batched_cvec_ds = None
+
+        frame_lengths = torch.tensor(frame_lengths, device=device)
+
+        batch_spk_id = torch.LongTensor([speaker_id] * len(batch)).to(device)
+
+        with torch.no_grad():
+            mel_out = run_inference(
+                model=svc_model,
+                mel=batched_mel,
+                cvec=batched_cvec,
+                f0=batched_f0,
+                rms=batched_rms,
+                cvec_ds=batched_cvec_ds,
+                spk_id=batch_spk_id,
+                infer_steps=infer_steps,
+                ds_cfg_strength=ds_cfg_strength,
+                spk_cfg_strength=spk_cfg_strength,
+                skip_cfg_strength=skip_cfg_strength,
+                cfg_skip_layers=cfg_skip_layers,
+                cfg_rescale=cfg_rescale,
+                frame_lengths=frame_lengths,
+                use_fp16=use_fp16
+            )
+
+        with autocast(device_type='cuda' if device.type == 'cuda' else 'cpu', enabled=use_fp16):
+            audio_out = vocoder(mel_out.transpose(1, 2), batched_f0)
+
+        for i in range(len(batch)):
+            expected_audio_length = batch_lengths[i]
+
+            curr_audio = audio_out[i].squeeze().cpu().numpy()
+
+            if len(curr_audio) > expected_audio_length:
+                curr_audio = curr_audio[:expected_audio_length]
+            elif len(curr_audio) < expected_audio_length:
+                curr_audio = np.pad(curr_audio, (0, expected_audio_length - len(curr_audio)), 'constant')
+
+            if restore_loudness:
+                meter = pyln.Meter(44100, block_size=0.1)
+                curr_loudness = meter.integrated_loudness(curr_audio)
+                curr_audio = pyln.normalize.loudness(curr_audio, curr_loudness, original_loudness_values[i])
+
+            max_amp = np.max(np.abs(curr_audio))
+            if max_amp > 1.0:
+                curr_audio = curr_audio * (0.99 / max_amp)
+
+            expected_length = batch_lengths[i]
+
+            all_results.append((batch_idx, i, batch_start_samples[i], curr_audio, expected_length, original_indices[batch_size * batch_idx + i]))
+
+    all_results.sort(key=lambda x: x[5])
+
+    return [(pos, audio, length) for _, _, pos, audio, length, _ in all_results]
+
+
 @click.command()
 @click.option('--model', type=click.Path(exists=True), required=True, help='Path to model checkpoint')
 @click.option('--input', type=click.Path(exists=True), required=True, help='Input audio file')

@@ -417,6 +621,7 @@ def process_segment(
 @click.option('--slicer-hop-size', type=int, default=10, help='Hop size for audio slicing in milliseconds')
 @click.option('--slicer-max-sil-kept', type=int, default=200, help='Maximum silence kept in milliseconds')
 @click.option('--use-fp16', is_flag=True, default=True, help='Use float16 precision for faster inference')
+@click.option('--batch-size', type=int, default=1, help='Batch size for parallel inference')
 def main(
     model,
     input,

@@ -441,7 +646,8 @@ def main(
     slicer_min_interval,
     slicer_hop_size,
     slicer_max_sil_kept,
-    use_fp16
+    use_fp16,
+    batch_size
 ):
     """Convert the voice in an audio file to a target speaker."""
 

@@ -486,40 +692,68 @@ def main(
     fade_samples = int(fade_duration * sample_rate / 1000)
 
     # Process segments
-    click.echo(f"Processing {len(segments_with_pos)} segments...")
-    result_audio = np.zeros(len(audio) + fade_samples)  # Extra space for potential overlap
-
-    with torch.no_grad():
-        for idx, (start_sample, chunk) in enumerate(tqdm(segments_with_pos)):
-
-            # Process the segment
-            audio_out = process_segment(
-                chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
+    if batch_size > 1:
+        click.echo(f"Processing {len(segments_with_pos)} segments with batch size {batch_size}...")
+        result_audio = np.zeros(len(audio) + fade_samples)  # Extra space for potential overlap
+
+        with torch.no_grad():
+            processed_segments = batch_process_segments(
+                segments_with_pos, svc_model, vocoder, rmvpe, hubert, rms_extractor,
                 speaker_id, sample_rate, hop_length, device,
                 key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
                 skip_cfg_strength, cfg_skip_layers, cfg_rescale,
                 cvec_downsample_rate, target_loudness, restore_loudness,
-                robust_f0, use_fp16
+                robust_f0, use_fp16, batch_size
             )
 
-            # Ensure consistent length
-            expected_length = len(chunk)
-            if len(audio_out) > expected_length:
-                audio_out = audio_out[:expected_length]
-            elif len(audio_out) < expected_length:
-                audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')
-
+        for idx, (start_sample, audio_out, expected_length) in enumerate(processed_segments):
             # Apply fades
             if idx > 0:  # Not first segment
                 audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
                 result_audio[start_sample:start_sample + fade_samples] *= \
                     np.linspace(1, 0, fade_samples)  # Fade out previous
 
-            if idx < len(segments_with_pos) - 1:  # Not last segment
+            if idx < len(processed_segments) - 1:  # Not last segment
                 audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
 
             # Add to result
             result_audio[start_sample:start_sample + len(audio_out)] += audio_out
+    else:
+        # Original processing method using sliced_inference
+        click.echo(f"Processing {len(segments_with_pos)} segments...")
+        result_audio = np.zeros(len(audio) + fade_samples)  # Extra space for potential overlap
+
+        with torch.no_grad():
+            for idx, (start_sample, chunk) in enumerate(tqdm(segments_with_pos)):
+
+                # Process the segment
+                audio_out = process_segment(
+                    chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
+                    speaker_id, sample_rate, hop_length, device,
+                    key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
+                    skip_cfg_strength, cfg_skip_layers, cfg_rescale,
+                    cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
+                    robust_f0, use_fp16
+                )
+
+                # Ensure consistent length
+                expected_length = len(chunk)
+                if len(audio_out) > expected_length:
+                    audio_out = audio_out[:expected_length]
+                elif len(audio_out) < expected_length:
+                    audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')
+
+                # Apply fades
+                if idx > 0:  # Not first segment
+                    audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
+                    result_audio[start_sample:start_sample + fade_samples] *= \
+                        np.linspace(1, 0, fade_samples)  # Fade out previous
+
+                if idx < len(segments_with_pos) - 1:  # Not last segment
+                    audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
+
+                # Add to result
+                result_audio[start_sample:start_sample + len(audio_out)] += audio_out
 
     # Trim any extra padding
     result_audio = result_audio[:len(audio)]
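The core batching idea in batch_process_segments is length bucketing: segments are sorted by length so each batch pads to a similar maximum, every per-segment feature tensor is right-padded along the frame axis, and the true frame counts travel alongside so the model can mask the padding (the frame_len argument to model.sample). A self-contained sketch of the padding step, under assumed [1, T] / [1, T, C] shapes (toy sizes, not the repo's feature dimensions):

import torch
import torch.nn.functional as F

def pad_to(t: torch.Tensor, length: int) -> torch.Tensor:
    """Right-pad the frame axis (dim=1) of a [B, T] or [B, T, C] tensor with zeros."""
    pad = length - t.shape[1]
    if pad <= 0:
        return t
    return F.pad(t, (0, pad) if t.dim() == 2 else (0, 0, 0, pad))

mels = [torch.randn(1, n, 128) for n in (180, 210, 196)]   # per-segment mel features
frame_lengths = torch.tensor([m.shape[1] for m in mels])   # true lengths, kept for masking
batched = torch.cat([pad_to(m, int(frame_lengths.max())) for m in mels], dim=0)  # [3, 210, 128]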
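Because segments are processed in length-sorted order, batch_process_segments tags each result with its original index and re-sorts before returning, so the caller can stitch segments back in timeline order (on the CLI side, passing --batch-size greater than 1 routes through this path). A toy illustration of that sort/unsort bookkeeping, with strings standing in for audio chunks:

chunks = ["bb", "a", "dddd", "ccc"]  # stand-ins for audio segments

sorted_with_idx = sorted(enumerate(chunks), key=lambda p: len(p[1]))        # length order
results = [(orig_idx, seg.upper()) for orig_idx, seg in sorted_with_idx]    # "process" each
results.sort(key=lambda r: r[0])                                            # restore original order
outputs = [out for _, out in results]                                       # ['BB', 'A', 'DDDD', 'CCC']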
rift_svc/rf.py
CHANGED
@@ -84,34 +84,138 @@ class RF(nn.Module):
 
         # Define the ODE function
         def fn(t, x):
-            pred = self.transformer(
-                x=x,
-                spk=spk_id,
-                f0=f0,
-                rms=rms,
-                cvec=cvec,
-                time=t,
-                mask=mask
-            )
-            cfg_flag = (ds_cfg_strength > 1e-5) or (skip_cfg_strength > 1e-5) or (spk_cfg_strength > 1e-5)
-            if cfg_rescale > 1e-5 and cfg_flag:
-                std_pred = pred.std()
-
+            # Check if we need to do batched processing
+            need_batched = False
+            num_cond = 1  # Regular prediction
+
             if ds_cfg_strength > 1e-5:
                 assert exists(bad_cvec), "bad_cvec is required when cfg_strength is greater than 0"
-                bad_cvec_pred = self.transformer(
-                    x=x,
-                    spk=spk_id,
-                    f0=f0,
-                    rms=rms,
-                    cvec=bad_cvec,
-                    time=t,
-                    mask=mask,
-                    skip_layers=cfg_skip_layers
-                )
-
-                pred = pred + (pred - bad_cvec_pred) * ds_cfg_strength
-
+                need_batched = True
+                num_cond += 1
+
+            if spk_cfg_strength > 1e-5:
+                need_batched = True
+                num_cond += 1
+
+            if not need_batched:
+                # Standard case - just do the regular prediction
+                pred = self.transformer(
+                    x=x,
+                    spk=spk_id,
+                    f0=f0,
+                    rms=rms,
+                    cvec=cvec,
+                    time=t,
+                    mask=mask
+                )
+            else:
+                # Get original batch size
+                orig_batch = x.shape[0]
+                total_batch = orig_batch * num_cond
+
+                # Batched processing - prepare inputs by repeating interleaved
+                # For each input sample, we'll create num_cond versions in sequence
+
+                # Handle x: reshape as [total_batch, seq_len, feat_dim]
+                x_batched = x.repeat_interleave(num_cond, dim=0)
+
+                # Handle speaker ID: reshape as [total_batch]
+                spk_batched = spk_id.repeat_interleave(num_cond, dim=0)
+
+                # Handle f0 and rms: reshape as [total_batch, seq_len]
+                f0_batched = f0.repeat_interleave(num_cond, dim=0)
+                rms_batched = rms.repeat_interleave(num_cond, dim=0)
+
+                # Create batched cvec, handling bad_cvec if needed
+                if ds_cfg_strength > 1e-5 and spk_cfg_strength > 1e-5:
+                    # Need to create interleaved: [cvec, bad_cvec, cvec] for each original batch item
+                    cvec_expanded = []
+                    for i in range(orig_batch):
+                        cvec_expanded.append(cvec[i:i+1])  # Regular
+                        cvec_expanded.append(bad_cvec[i:i+1])  # Bad cvec
+                        cvec_expanded.append(cvec[i:i+1])  # Regular (for null spk)
+                    cvec_batched = torch.cat(cvec_expanded, dim=0)
+                elif ds_cfg_strength > 1e-5:
+                    # Interleave: [cvec, bad_cvec] for each original batch item
+                    cvec_list = []
+                    for i in range(orig_batch):
+                        cvec_list.append(cvec[i:i+1])
+                        cvec_list.append(bad_cvec[i:i+1])
+                    cvec_batched = torch.cat(cvec_list, dim=0)
+                elif spk_cfg_strength > 1e-5:
+                    # Interleave: [cvec, cvec] for each original batch item
+                    cvec_batched = cvec.repeat_interleave(num_cond, dim=0)
+
+                if isinstance(t, torch.Tensor) and t.ndim > 0:
+                    t_batched = t.repeat_interleave(num_cond, dim=0)
+                else:
+                    t_batched = t  # It's a scalar, handled by the transformer
+
+                # Handle mask if exists
+                mask_batched = mask.repeat_interleave(num_cond, dim=0) if exists(mask) else None
+
+                # Create drop_speaker flag tensor - only activate for the appropriate indices
+                drop_speaker_batched = torch.zeros(total_batch, dtype=torch.bool, device=x.device)
+
+                if spk_cfg_strength > 1e-5:
+                    # Set drop_speaker=True for the third condition of each original batch item
+                    if ds_cfg_strength > 1e-5:
+                        # Pattern is [False, False, True] repeated
+                        for i in range(orig_batch):
+                            drop_speaker_batched[i*num_cond + 2] = True
+                    else:
+                        # Pattern is [False, True] repeated
+                        for i in range(orig_batch):
+                            drop_speaker_batched[i*num_cond + 1] = True
+
+                # Single batched forward pass
+                preds_batched = self.transformer(
+                    x=x_batched,
+                    spk=spk_batched,
+                    f0=f0_batched,
+                    rms=rms_batched,
+                    cvec=cvec_batched,
+                    time=t_batched,
+                    mask=mask_batched,
+                    drop_speaker=drop_speaker_batched
+                )
+
+                # Reshape and extract the predictions for each condition
+                # First, reshape the predictions to [orig_batch, num_cond, seq_len, feat_dim]
+                predictions = []
+
+                # Extract predictions for each original batch item
+                for b in range(orig_batch):
+                    batch_predictions = []
+                    for c in range(num_cond):
+                        idx = b * num_cond + c
+                        batch_predictions.append(preds_batched[idx:idx+1])
+                    predictions.append(batch_predictions)
+
+                # Apply classifier-free guidance per original batch item
+                pred_results = []
+                for b in range(orig_batch):
+                    pred = predictions[b][0]  # Regular prediction
+
+                    cond_idx = 1
+                    if ds_cfg_strength > 1e-5:
+                        bad_cvec_pred = predictions[b][cond_idx]
+                        pred = pred + (pred - bad_cvec_pred) * ds_cfg_strength
+                        cond_idx += 1
+
+                    if spk_cfg_strength > 1e-5:
+                        null_spk_pred = predictions[b][cond_idx]
+                        pred = pred + (pred - null_spk_pred) * spk_cfg_strength
+
+                    pred_results.append(pred)
+
+                # Combine back to original batch dimension
+                pred = torch.cat(pred_results, dim=0)
+
+            cfg_flag = (ds_cfg_strength > 1e-5) or (skip_cfg_strength > 1e-5) or (spk_cfg_strength > 1e-5)
+            if cfg_rescale > 1e-5 and cfg_flag:
+                std_pred = pred.std()
 
             if skip_cfg_strength > 1e-5:
                 skip_pred = self.transformer(
                     x=x,

@@ -125,20 +229,6 @@ class RF(nn.Module):
                 )
 
                 pred = pred + (pred - skip_pred) * skip_cfg_strength
-
-            if spk_cfg_strength > 1e-5:
-                null_spk_pred = self.transformer(
-                    x=x,
-                    spk=spk_id,
-                    f0=f0,
-                    rms=rms,
-                    cvec=cvec,
-                    time=t,
-                    mask=mask,
-                    drop_speaker=True
-                )
-
-                pred = pred + (pred - null_spk_pred) * spk_cfg_strength
 
             if cfg_rescale > 1e-5 and cfg_flag:
                 std_cfg = pred.std()
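The rewrite above folds what used to be up to three transformer calls per ODE step (regular, degraded-cvec, speaker-dropped) into one forward pass by interleaving the conditions along the batch axis and de-interleaving the outputs. A minimal sketch of the trick for the two-condition case, with a toy net standing in for self.transformer:

import torch

def cfg_one_pass(net, x, cond, bad_cond, strength):
    num_cond = 2
    x2 = x.repeat_interleave(num_cond, dim=0)                # [2B, T, D]
    c2 = torch.stack([cond, bad_cond], dim=1).flatten(0, 1)  # interleave [cond, bad] per item
    pred2 = net(x2, c2)                                      # one batched forward pass
    pred, bad = pred2[0::num_cond], pred2[1::num_cond]       # de-interleave
    return pred + (pred - bad) * strength

net = lambda x, c: x + c          # toy stand-in for the transformer
x = torch.randn(2, 16, 8)
cond = torch.randn(2, 16, 8)
out = cfg_one_pass(net, x, cond, torch.zeros_like(cond), strength=0.2)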
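For context on std_pred and std_cfg in the hunks above: capturing the prediction's standard deviation before guidance and again after is the usual setup for CFG rescaling, which counteracts the variance blow-up that strong guidance causes. The combination step itself lies outside the shown hunks; a hedged sketch of the standard formulation (the actual code may differ in detail):

import torch

def rescale_cfg(pred_cfg: torch.Tensor, std_pred: torch.Tensor, cfg_rescale: float) -> torch.Tensor:
    std_cfg = pred_cfg.std()
    rescaled = pred_cfg * (std_pred / std_cfg)                    # match pre-guidance std
    return cfg_rescale * rescaled + (1 - cfg_rescale) * pred_cfg  # blend by cfg_rescale

pred = torch.randn(1, 16, 8)
std_pred = pred.std()
out = rescale_cfg(pred * 1.4, std_pred, cfg_rescale=0.7)  # 1.4x stands in for a CFG-boosted pred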