prismleong committed
Commit 32a5366 · 1 Parent(s): fc38cab
Files changed (3):
  1. app.py +73 -34
  2. infer.py +256 -22
  3. rift_svc/rf.py +123 -33
app.py CHANGED
@@ -14,7 +14,8 @@ from infer import (
     load_models,
     load_audio,
     apply_fade,
-    process_segment
+    process_segment,
+    batch_process_segments
 )
 
 # Global variables for models
@@ -115,7 +116,9 @@ def process_with_progress(
     slicer_min_length=3000,
     slicer_min_interval=100,
     slicer_hop_size=10,
-    slicer_max_sil_kept=200
+    slicer_max_sil_kept=200,
+    # Batch processing
+    batch_size=1
 ):
     global svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg
 
@@ -182,41 +185,71 @@ def process_with_progress(
         progress(0.2, desc="处理中: 开始转换...")
 
         with torch.no_grad():
-            for i, (start_sample, chunk) in enumerate(segments_with_pos):
-                segment_progress = 0.2 + (0.7 * (i / len(segments_with_pos)))
-                progress(segment_progress, desc=f"处理中: 片段 {i+1}/{len(segments_with_pos)}")
-
-                # Process the segment
-                audio_out = process_segment(
-                    chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
-                    speaker_id, sample_rate, hop_length, device,
-                    key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
-                    skip_cfg_strength, cfg_skip_layers, cfg_rescale,
-                    cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
-                    robust_f0, use_fp16
-                )
-
-                # Ensure consistent length
-                expected_length = len(chunk)
-                if len(audio_out) > expected_length:
-                    audio_out = audio_out[:expected_length]
-                elif len(audio_out) < expected_length:
-                    audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')
-
-                # Apply fades
-                if i > 0:  # Not first segment
-                    audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
-                    result_audio[start_sample:start_sample + fade_samples] *= \
-                        np.linspace(1, 0, fade_samples)  # Fade out previous
-
-                if i < len(segments_with_pos) - 1:  # Not last segment
-                    audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
-
-                # Add to result
-                result_audio[start_sample:start_sample + len(audio_out)] += audio_out
-
-                # Clean up memory after each segment
-                torch.cuda.empty_cache()
+            if batch_size > 1:
+                # Use batch processing
+                progress_desc = f"处理中: 批次 {{0}}/{{1}}"
+                processed_segments = batch_process_segments(
+                    segments_with_pos, svc_model, vocoder, rmvpe, hubert, rms_extractor,
+                    speaker_id, sample_rate, hop_length, device,
+                    key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
+                    skip_cfg_strength, cfg_skip_layers, cfg_rescale,
+                    cvec_downsample_rate, target_loudness, restore_loudness,
+                    robust_f0, use_fp16, batch_size, progress, progress_desc
+                )
+
+                for idx, (start_sample, audio_out, expected_length) in enumerate(processed_segments):
+                    # Apply fades
+                    if idx > 0:  # Not first segment
+                        audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
+                        result_audio[start_sample:start_sample + fade_samples] *= \
+                            np.linspace(1, 0, fade_samples)  # Fade out previous
+
+                    if idx < len(processed_segments) - 1:  # Not last segment
+                        audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
+
+                    # Add to result
+                    result_audio[start_sample:start_sample + len(audio_out)] += audio_out
+
+                    # Clean up memory after each segment
+                    if idx % 5 == 0:  # Clean up every 5 segments
+                        torch.cuda.empty_cache()
+            else:
+                # Use sequential processing
+                for i, (start_sample, chunk) in enumerate(segments_with_pos):
+                    segment_progress = 0.2 + (0.7 * (i / len(segments_with_pos)))
+                    progress(segment_progress, desc=f"处理中: 片段 {i+1}/{len(segments_with_pos)}")
+
+                    # Process the segment
+                    audio_out = process_segment(
+                        chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
+                        speaker_id, sample_rate, hop_length, device,
+                        key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
+                        skip_cfg_strength, cfg_skip_layers, cfg_rescale,
+                        cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
+                        robust_f0, use_fp16
+                    )
+
+                    # Ensure consistent length
+                    expected_length = len(chunk)
+                    if len(audio_out) > expected_length:
+                        audio_out = audio_out[:expected_length]
+                    elif len(audio_out) < expected_length:
+                        audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')
+
+                    # Apply fades
+                    if i > 0:  # Not first segment
+                        audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
+                        result_audio[start_sample:start_sample + fade_samples] *= \
+                            np.linspace(1, 0, fade_samples)  # Fade out previous
+
+                    if i < len(segments_with_pos) - 1:  # Not last segment
+                        audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
+
+                    # Add to result
+                    result_audio[start_sample:start_sample + len(audio_out)] += audio_out
+
+                    # Clean up memory after each segment
+                    torch.cuda.empty_cache()
 
         progress(0.9, desc="处理中: 完成音频...")
         # Trim any extra padding
@@ -230,7 +263,8 @@ def process_with_progress(
         torchaudio.save(output_path, torch.from_numpy(result_audio).unsqueeze(0).float(), sample_rate)
 
         progress(1.0, desc="处理完成!")
-        return (sample_rate, result_audio), f" 转换完成! 已转换为 **{speaker}** 并调整 **{key_shift}** 个半音。"
+        batch_text = f"批处理大小 {batch_size}" if batch_size > 1 else "顺序处理"
+        return (sample_rate, result_audio), f"✅ 转换完成! 已转换为 **{speaker}** 并调整 **{key_shift}** 个半音。{batch_text}"
 
     except RuntimeError as e:
         # Handle CUDA out of memory errors
@@ -341,6 +375,9 @@ def create_ui():
                 robust_f0 = gr.Radio(choices=[0, 1, 2], value=1, label="音高滤波",
                                      info="0=无,1=轻度过滤,2=强力过滤(有助于解决断音/破音问题)",
                                      elem_id="robust_f0")
+                batch_size = gr.Slider(minimum=1, maximum=64, step=1, value=4, label="批处理大小",
+                                       info="使用批处理可以加速转换,但需要更多VRAM。1=不使用批处理",
+                                       elem_id="batch_size")
 
             with gr.Accordion("🔬 高级CFG参数", open=True):
                 ds_cfg_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2,
@@ -403,6 +440,7 @@ def create_ui():
                 <li><strong>音调调整:</strong> 以半音为单位上调或下调音高。</li>
                 <li><strong>推理步骤:</strong> 步骤越多 = 质量越好但速度越慢。</li>
                 <li><strong>音高滤波:</strong> 有助于提高具有挑战性的音频中的音高稳定性。</li>
+                <li><strong>批处理大小:</strong> 值越大 = 转换越快,但需要更多GPU内存。遇到内存不足时降低此值。</li>
                 <li><strong>CFG参数:</strong> 调整转换质量和音色。</li>
             </ul>
         </div>
@@ -426,7 +464,8 @@ def create_ui():
            inputs=[
                input_audio, speaker, key_shift, infer_steps, robust_f0,
                ds_cfg_strength, spk_cfg_strength, skip_cfg_strength, cfg_skip_layers, cfg_rescale, cvec_downsample_rate,
-               slicer_threshold, slicer_min_length, slicer_min_interval, slicer_hop_size, slicer_max_sil_kept
+               slicer_threshold, slicer_min_length, slicer_min_interval, slicer_hop_size, slicer_max_sil_kept,
+               batch_size
            ],
            outputs=[output_audio, output_message],
            show_progress_on=output_audio
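
Note: both branches of process_with_progress stitch the converted segments back together the same way: a linear fade-in on the new segment, a linear fade-out on the tail that was already written, then overlap-add into result_audio. The following is a minimal standalone sketch of that pattern, assuming NumPy only and a hypothetical stitch_segments helper (the app does this inline and uses apply_fade for the fade-in); fade_samples is assumed to be shorter than every segment, as in infer.py where fade_samples = int(fade_duration * sample_rate / 1000).

import numpy as np

def stitch_segments(segments, total_len, fade_samples):
    """Overlap-add (start_sample, audio) pairs with linear crossfades at the joins."""
    out = np.zeros(total_len + fade_samples)  # extra space for potential overlap, as in the app
    for i, (start, seg) in enumerate(segments):
        seg = seg.copy()
        if i > 0:
            # Fade the new segment in and the already-written tail out
            seg[:fade_samples] *= np.linspace(0, 1, fade_samples)
            out[start:start + fade_samples] *= np.linspace(1, 0, fade_samples)
        if i < len(segments) - 1:
            # Fade the end of every non-final segment out
            seg[-fade_samples:] *= np.linspace(1, 0, fade_samples)
        out[start:start + len(seg)] += seg
    return out[:total_len]  # trim the extra padding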
infer.py CHANGED
@@ -184,12 +184,31 @@ def run_inference(
     model, mel, cvec, f0, rms, cvec_ds, spk_id,
     infer_steps, ds_cfg_strength, spk_cfg_strength,
     skip_cfg_strength, cfg_skip_layers, cfg_rescale,
-    sliced_inference=False, use_fp16=True
+    sliced_inference=False, use_fp16=True, frame_lengths=None
 ):
     """Run the actual inference through the model"""
     device_type = 'cuda' if mel.device.type == 'cuda' else 'cpu'
 
-    if sliced_inference:
+    if frame_lengths is not None:
+        # Use batch inference with frame lengths
+        with autocast(device_type=device_type, enabled=use_fp16):
+            mel_out, _ = model.sample(
+                src_mel=mel,
+                spk_id=spk_id,
+                f0=f0,
+                rms=rms,
+                cvec=cvec,
+                steps=infer_steps,
+                bad_cvec=cvec_ds,
+                ds_cfg_strength=ds_cfg_strength,
+                spk_cfg_strength=spk_cfg_strength,
+                skip_cfg_strength=skip_cfg_strength,
+                cfg_skip_layers=cfg_skip_layers,
+                cfg_rescale=cfg_rescale,
+                frame_len=frame_lengths,
+            )
+        return mel_out
+    elif sliced_inference:
         # Use sliced inference for long segments
         sliced_len = 256
         mel_crossfade_len = 8  # Number of frames to crossfade in mel domain
@@ -392,6 +411,191 @@ def process_segment(
     return audio_out
 
 
+def pad_tensor_to_length(tensor, length):
+    """Pad a tensor to the specified length along the sequence dimension (dim=1)"""
+    curr_len = tensor.shape[1]
+    if curr_len >= length:
+        return tensor
+
+    pad_len = length - curr_len
+
+    if tensor.dim() == 2:
+        padding = (0, pad_len)
+    elif tensor.dim() == 3:
+        padding = (0, 0, 0, pad_len)
+    else:
+        raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")
+
+    padded = torch.nn.functional.pad(tensor, padding, "constant", 0)
+    return padded
+
+
+def batch_process_segments(
+    segments_with_pos,
+    svc_model, vocoder, rmvpe, hubert, rms_extractor,
+    speaker_id, sample_rate, hop_length, device,
+    key_shift=0,
+    infer_steps=32,
+    ds_cfg_strength=0.0,
+    spk_cfg_strength=0.0,
+    skip_cfg_strength=0.0,
+    cfg_skip_layers=None,
+    cfg_rescale=0.7,
+    cvec_downsample_rate=2,
+    target_loudness=-18.0,
+    restore_loudness=True,
+    robust_f0=0,
+    use_fp16=True,
+    batch_size=1,
+    gr_progress=None,
+    progress_desc=None
+):
+    """Process audio segments in batches for faster inference"""
+    if batch_size <= 1:
+        results = []
+        for i, (start_sample, chunk) in enumerate(tqdm(segments_with_pos, desc="Processing segments")):
+            if gr_progress is not None:
+                gr_progress(0.2 + (0.7 * (i / len(segments_with_pos))), desc=progress_desc.format(i+1, len(segments_with_pos)))
+            audio_out = process_segment(
+                chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
+                speaker_id, sample_rate, hop_length, device,
+                key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
+                skip_cfg_strength, cfg_skip_layers, cfg_rescale,
+                cvec_downsample_rate, target_loudness, restore_loudness,
+                robust_f0, use_fp16
+            )
+            results.append((start_sample, audio_out, len(chunk)))
+        return results
+
+    sorted_with_idx = sorted(enumerate(segments_with_pos), key=lambda x: len(x[1][1]))
+    sorted_segments = []
+    original_indices = []
+
+    for orig_idx, (pos, chunk) in sorted_with_idx:
+        original_indices.append(orig_idx)
+        sorted_segments.append((pos, chunk))
+
+    batched_segments = [sorted_segments[i:i + batch_size] for i in range(0, len(sorted_segments), batch_size)]
+
+    all_results = []
+
+    for batch_idx, batch in enumerate(tqdm(batched_segments, desc="Processing batches")):
+        if gr_progress is not None:
+            gr_progress(
+                0.2 + (0.7 * (batch_idx / len(batched_segments))),
+                desc=progress_desc.format(batch_idx+1, len(batched_segments)))
+
+        batch_start_samples = [pos for pos, _ in batch]
+        batch_chunks = [chunk for _, chunk in batch]
+        batch_lengths = [len(chunk) for chunk in batch_chunks]
+
+        batch_features = []
+        for chunk in batch_chunks:
+            mel, cvec, cvec_ds, f0, rms, original_loudness = extract_features(
+                chunk, sample_rate, hop_length, rmvpe, hubert, rms_extractor,
+                device, key_shift, ds_cfg_strength, cvec_downsample_rate, target_loudness,
+                robust_f0, use_fp16
+            )
+            batch_features.append({
+                'mel': mel,
+                'cvec': cvec,
+                'cvec_ds': cvec_ds,
+                'f0': f0,
+                'rms': rms,
+                'original_loudness': original_loudness,
+                'length': mel.shape[1]
+            })
+
+        max_length = max(feat['length'] for feat in batch_features)
+
+        padded_mels = []
+        padded_cvecs = []
+        padded_f0s = []
+        padded_rmss = []
+        frame_lengths = []
+        original_loudness_values = []
+
+        if ds_cfg_strength > 0:
+            padded_cvec_ds = []
+
+        for feat in batch_features:
+            curr_len = feat['length']
+            frame_lengths.append(curr_len)
+
+            padded_mels.append(pad_tensor_to_length(feat['mel'], max_length))
+            padded_cvecs.append(pad_tensor_to_length(feat['cvec'], max_length))
+            padded_f0s.append(pad_tensor_to_length(feat['f0'], max_length))
+            padded_rmss.append(pad_tensor_to_length(feat['rms'], max_length))
+
+            if ds_cfg_strength > 0:
+                padded_cvec_ds.append(pad_tensor_to_length(feat['cvec_ds'], max_length))
+
+            original_loudness_values.append(feat['original_loudness'])
+
+        batched_mel = torch.cat(padded_mels, dim=0)
+        batched_cvec = torch.cat(padded_cvecs, dim=0)
+        batched_f0 = torch.cat(padded_f0s, dim=0)
+        batched_rms = torch.cat(padded_rmss, dim=0)
+
+        if ds_cfg_strength > 0:
+            batched_cvec_ds = torch.cat(padded_cvec_ds, dim=0)
+        else:
+            batched_cvec_ds = None
+
+        frame_lengths = torch.tensor(frame_lengths, device=device)
+
+        batch_spk_id = torch.LongTensor([speaker_id] * len(batch)).to(device)
+
+        with torch.no_grad():
+            mel_out = run_inference(
+                model=svc_model,
+                mel=batched_mel,
+                cvec=batched_cvec,
+                f0=batched_f0,
+                rms=batched_rms,
+                cvec_ds=batched_cvec_ds,
+                spk_id=batch_spk_id,
+                infer_steps=infer_steps,
+                ds_cfg_strength=ds_cfg_strength,
+                spk_cfg_strength=spk_cfg_strength,
+                skip_cfg_strength=skip_cfg_strength,
+                cfg_skip_layers=cfg_skip_layers,
+                cfg_rescale=cfg_rescale,
+                frame_lengths=frame_lengths,
+                use_fp16=use_fp16
+            )
+
+        with autocast(device_type='cuda' if device.type == 'cuda' else 'cpu', enabled=use_fp16):
+            audio_out = vocoder(mel_out.transpose(1, 2), batched_f0)
+
+        for i in range(len(batch)):
+            expected_audio_length = batch_lengths[i]
+
+            curr_audio = audio_out[i].squeeze().cpu().numpy()
+
+            if len(curr_audio) > expected_audio_length:
+                curr_audio = curr_audio[:expected_audio_length]
+            elif len(curr_audio) < expected_audio_length:
+                curr_audio = np.pad(curr_audio, (0, expected_audio_length - len(curr_audio)), 'constant')
+
+            if restore_loudness:
+                meter = pyln.Meter(44100, block_size=0.1)
+                curr_loudness = meter.integrated_loudness(curr_audio)
+                curr_audio = pyln.normalize.loudness(curr_audio, curr_loudness, original_loudness_values[i])
+
+            max_amp = np.max(np.abs(curr_audio))
+            if max_amp > 1.0:
+                curr_audio = curr_audio * (0.99 / max_amp)
+
+            expected_length = batch_lengths[i]
+
+            all_results.append((batch_idx, i, batch_start_samples[i], curr_audio, expected_length, original_indices[batch_size * batch_idx + i]))
+
+    all_results.sort(key=lambda x: x[5])
+
+    return [(pos, audio, length) for _, _, pos, audio, length, _ in all_results]
+
+
 @click.command()
 @click.option('--model', type=click.Path(exists=True), required=True, help='Path to model checkpoint')
 @click.option('--input', type=click.Path(exists=True), required=True, help='Input audio file')
@@ -417,6 +621,7 @@ def process_segment(
 @click.option('--slicer-hop-size', type=int, default=10, help='Hop size for audio slicing in milliseconds')
 @click.option('--slicer-max-sil-kept', type=int, default=200, help='Maximum silence kept in milliseconds')
 @click.option('--use-fp16', is_flag=True, default=True, help='Use float16 precision for faster inference')
+@click.option('--batch-size', type=int, default=1, help='Batch size for parallel inference')
 def main(
     model,
     input,
@@ -441,7 +646,8 @@ def main(
     slicer_min_interval,
     slicer_hop_size,
     slicer_max_sil_kept,
-    use_fp16
+    use_fp16,
+    batch_size
 ):
     """Convert the voice in an audio file to a target speaker."""
 
@@ -486,40 +692,68 @@ def main(
     fade_samples = int(fade_duration * sample_rate / 1000)
 
     # Process segments
-    click.echo(f"Processing {len(segments_with_pos)} segments...")
-    result_audio = np.zeros(len(audio) + fade_samples)  # Extra space for potential overlap
-
-    with torch.no_grad():
-        for idx, (start_sample, chunk) in enumerate(tqdm(segments_with_pos)):
-
-            # Process the segment
-            audio_out = process_segment(
-                chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
-                speaker_id, sample_rate, hop_length, device,
-                key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
-                skip_cfg_strength, cfg_skip_layers, cfg_rescale,
-                cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
-                robust_f0, use_fp16
-            )
-
-            # Ensure consistent length
-            expected_length = len(chunk)
-            if len(audio_out) > expected_length:
-                audio_out = audio_out[:expected_length]
-            elif len(audio_out) < expected_length:
-                audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')
-
-            # Apply fades
-            if idx > 0:  # Not first segment
-                audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
-                result_audio[start_sample:start_sample + fade_samples] *= \
-                    np.linspace(1, 0, fade_samples)  # Fade out previous
-
-            if idx < len(segments_with_pos) - 1:  # Not last segment
-                audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
-
-            # Add to result
-            result_audio[start_sample:start_sample + len(audio_out)] += audio_out
+    if batch_size > 1:
+        click.echo(f"Processing {len(segments_with_pos)} segments with batch size {batch_size}...")
+        result_audio = np.zeros(len(audio) + fade_samples)  # Extra space for potential overlap
+
+        with torch.no_grad():
+            processed_segments = batch_process_segments(
+                segments_with_pos, svc_model, vocoder, rmvpe, hubert, rms_extractor,
+                speaker_id, sample_rate, hop_length, device,
+                key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
+                skip_cfg_strength, cfg_skip_layers, cfg_rescale,
+                cvec_downsample_rate, target_loudness, restore_loudness,
+                robust_f0, use_fp16, batch_size
+            )
+
+        for idx, (start_sample, audio_out, expected_length) in enumerate(processed_segments):
+            # Apply fades
+            if idx > 0:  # Not first segment
+                audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
+                result_audio[start_sample:start_sample + fade_samples] *= \
+                    np.linspace(1, 0, fade_samples)  # Fade out previous
+
+            if idx < len(processed_segments) - 1:  # Not last segment
+                audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
+
+            # Add to result
+            result_audio[start_sample:start_sample + len(audio_out)] += audio_out
+    else:
+        # Original processing method using sliced_inference
+        click.echo(f"Processing {len(segments_with_pos)} segments...")
+        result_audio = np.zeros(len(audio) + fade_samples)  # Extra space for potential overlap
+
+        with torch.no_grad():
+            for idx, (start_sample, chunk) in enumerate(tqdm(segments_with_pos)):
+
+                # Process the segment
+                audio_out = process_segment(
+                    chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
+                    speaker_id, sample_rate, hop_length, device,
+                    key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
+                    skip_cfg_strength, cfg_skip_layers, cfg_rescale,
+                    cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
+                    robust_f0, use_fp16
+                )
+
+                # Ensure consistent length
+                expected_length = len(chunk)
+                if len(audio_out) > expected_length:
+                    audio_out = audio_out[:expected_length]
+                elif len(audio_out) < expected_length:
+                    audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')
+
+                # Apply fades
+                if idx > 0:  # Not first segment
+                    audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
+                    result_audio[start_sample:start_sample + fade_samples] *= \
+                        np.linspace(1, 0, fade_samples)  # Fade out previous
+
+                if idx < len(segments_with_pos) - 1:  # Not last segment
+                    audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out
+
+                # Add to result
+                result_audio[start_sample:start_sample + len(audio_out)] += audio_out
 
     # Trim any extra padding
     result_audio = result_audio[:len(audio)]
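
Note: the core of the new batch path (exposed on the CLI as --batch-size) is in batch_process_segments: segments are sorted by length, grouped into batches of batch_size, and each batch's features are right-padded to that batch's own maximum frame count, with the true lengths passed on to run_inference as frame_lengths. The following is a small sketch of that bucketing idea, assuming each feature is a [1, T, D] tensor; make_length_sorted_batches is a hypothetical name, not part of the commit.

import torch

def make_length_sorted_batches(features, batch_size):
    """Group variable-length [1, T, D] tensors into length-sorted, right-padded batches."""
    # Sort indices by sequence length so each batch pads to a similar size
    order = sorted(range(len(features)), key=lambda i: features[i].shape[1])
    batches = []
    for s in range(0, len(order), batch_size):
        idx = order[s:s + batch_size]
        frame_lens = torch.tensor([features[i].shape[1] for i in idx])
        max_len = int(frame_lens.max())
        # Zero-pad every member along the time dimension to this batch's max length
        padded = torch.cat(
            [torch.nn.functional.pad(features[i], (0, 0, 0, max_len - features[i].shape[1])) for i in idx],
            dim=0,
        )  # [B, max_len, D]
        batches.append((padded, frame_lens, idx))
    return batches

Sorting before batching keeps each batch's maximum close to its members' lengths, so little compute is spent on padded frames; the returned idx list plays the role of original_indices, which the real function uses to restore the input order before returning.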
rift_svc/rf.py CHANGED
@@ -84,34 +84,138 @@ class RF(nn.Module):
 
         # Define the ODE function
         def fn(t, x):
-            pred = self.transformer(
-                x=x,
-                spk=spk_id,
-                f0=f0,
-                rms=rms,
-                cvec=cvec,
-                time=t,
-                mask=mask
-            )
-            cfg_flag = (ds_cfg_strength > 1e-5) or (skip_cfg_strength > 1e-5) or (spk_cfg_strength > 1e-5)
-            if cfg_rescale > 1e-5 and cfg_flag:
-                std_pred = pred.std()
-
-            if ds_cfg_strength > 1e-5:
-                assert exists(bad_cvec), "bad_cvec is required when cfg_strength is greater than 0"
-                bad_cvec_pred = self.transformer(
-                    x=x,
-                    spk=spk_id,
-                    f0=f0,
-                    rms=rms,
-                    cvec=bad_cvec,
-                    time=t,
-                    mask=mask,
-                    skip_layers=cfg_skip_layers
-                )
-
-                pred = pred + (pred - bad_cvec_pred) * ds_cfg_strength
-
+            # Check if we need to do batched processing
+            need_batched = False
+            num_cond = 1  # Regular prediction
+
+            if ds_cfg_strength > 1e-5:
+                assert exists(bad_cvec), "bad_cvec is required when cfg_strength is greater than 0"
+                need_batched = True
+                num_cond += 1
+
+            if spk_cfg_strength > 1e-5:
+                need_batched = True
+                num_cond += 1
+
+            if not need_batched:
+                # Standard case - just do the regular prediction
+                pred = self.transformer(
+                    x=x,
+                    spk=spk_id,
+                    f0=f0,
+                    rms=rms,
+                    cvec=cvec,
+                    time=t,
+                    mask=mask
+                )
+            else:
+                # Get original batch size
+                orig_batch = x.shape[0]
+                total_batch = orig_batch * num_cond
+
+                # Batched processing - prepare inputs by repeating interleaved
+                # For each input sample, we'll create num_cond versions in sequence
+
+                # Handle x: reshape as [total_batch, seq_len, feat_dim]
+                x_batched = x.repeat_interleave(num_cond, dim=0)
+
+                # Handle speaker ID: reshape as [total_batch]
+                spk_batched = spk_id.repeat_interleave(num_cond, dim=0)
+
+                # Handle f0 and rms: reshape as [total_batch, seq_len]
+                f0_batched = f0.repeat_interleave(num_cond, dim=0)
+                rms_batched = rms.repeat_interleave(num_cond, dim=0)
+
+                # Create batched cvec, handling bad_cvec if needed
+                if ds_cfg_strength > 1e-5 and spk_cfg_strength > 1e-5:
+                    # Need to create interleaved: [cvec, bad_cvec, cvec] for each original batch item
+                    cvec_expanded = []
+                    for i in range(orig_batch):
+                        cvec_expanded.append(cvec[i:i+1])      # Regular
+                        cvec_expanded.append(bad_cvec[i:i+1])  # Bad cvec
+                        cvec_expanded.append(cvec[i:i+1])      # Regular (for null spk)
+                    cvec_batched = torch.cat(cvec_expanded, dim=0)
+                elif ds_cfg_strength > 1e-5:
+                    # Interleave: [cvec, bad_cvec] for each original batch item
+                    cvec_list = []
+                    for i in range(orig_batch):
+                        cvec_list.append(cvec[i:i+1])
+                        cvec_list.append(bad_cvec[i:i+1])
+                    cvec_batched = torch.cat(cvec_list, dim=0)
+                elif spk_cfg_strength > 1e-5:
+                    # Interleave: [cvec, cvec] for each original batch item
+                    cvec_batched = cvec.repeat_interleave(num_cond, dim=0)
+
+                if isinstance(t, torch.Tensor) and t.ndim > 0:
+                    t_batched = t.repeat_interleave(num_cond, dim=0)
+                else:
+                    t_batched = t  # It's a scalar, handled by the transformer
+
+                # Handle mask if exists
+                mask_batched = mask.repeat_interleave(num_cond, dim=0) if exists(mask) else None
+
+                # Create drop_speaker flag tensor - only activate for the appropriate indices
+                drop_speaker_batched = torch.zeros(total_batch, dtype=torch.bool, device=x.device)
+
+                if spk_cfg_strength > 1e-5:
+                    # Set drop_speaker=True for the third condition of each original batch item
+                    if ds_cfg_strength > 1e-5:
+                        # Pattern is [False, False, True] repeated
+                        for i in range(orig_batch):
+                            drop_speaker_batched[i*num_cond + 2] = True
+                    else:
+                        # Pattern is [False, True] repeated
+                        for i in range(orig_batch):
+                            drop_speaker_batched[i*num_cond + 1] = True
+
+                # Single batched forward pass
+                preds_batched = self.transformer(
+                    x=x_batched,
+                    spk=spk_batched,
+                    f0=f0_batched,
+                    rms=rms_batched,
+                    cvec=cvec_batched,
+                    time=t_batched,
+                    mask=mask_batched,
+                    drop_speaker=drop_speaker_batched
+                )
+
+                # Reshape and extract the predictions for each condition
+                # First, reshape the predictions to [orig_batch, num_cond, seq_len, feat_dim]
+                predictions = []
+
+                # Extract predictions for each original batch item
+                for b in range(orig_batch):
+                    batch_predictions = []
+                    for c in range(num_cond):
+                        idx = b * num_cond + c
+                        batch_predictions.append(preds_batched[idx:idx+1])
+                    predictions.append(batch_predictions)
+
+                # Apply classifier-free guidance per original batch item
+                pred_results = []
+                for b in range(orig_batch):
+                    pred = predictions[b][0]  # Regular prediction
+
+                    cond_idx = 1
+                    if ds_cfg_strength > 1e-5:
+                        bad_cvec_pred = predictions[b][cond_idx]
+                        pred = pred + (pred - bad_cvec_pred) * ds_cfg_strength
+                        cond_idx += 1
+
+                    if spk_cfg_strength > 1e-5:
+                        null_spk_pred = predictions[b][cond_idx]
+                        pred = pred + (pred - null_spk_pred) * spk_cfg_strength
+
+                    pred_results.append(pred)
+
+                # Combine back to original batch dimension
+                pred = torch.cat(pred_results, dim=0)
+
+            cfg_flag = (ds_cfg_strength > 1e-5) or (skip_cfg_strength > 1e-5) or (spk_cfg_strength > 1e-5)
+            if cfg_rescale > 1e-5 and cfg_flag:
+                std_pred = pred.std()
 
             if skip_cfg_strength > 1e-5:
                 skip_pred = self.transformer(
                     x=x,
@@ -125,20 +229,6 @@ class RF(nn.Module):
                 )
 
                 pred = pred + (pred - skip_pred) * skip_cfg_strength
-
-            if spk_cfg_strength > 1e-5:
-                null_spk_pred = self.transformer(
-                    x=x,
-                    spk=spk_id,
-                    f0=f0,
-                    rms=rms,
-                    cvec=cvec,
-                    time=t,
-                    mask=mask,
-                    drop_speaker=True
-                )
-
-                pred = pred + (pred - null_spk_pred) * spk_cfg_strength
 
             if cfg_rescale > 1e-5 and cfg_flag:
                 std_cfg = pred.std()
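
Note: the rf.py change folds the extra CFG forward passes (degraded-cvec and speaker-dropped) into a single transformer call per ODE step by stacking the conditions along the batch dimension with repeat_interleave, then recombining per item. The committed code recombines with a Python loop over the batch; an equivalent vectorized restatement of just that recombination step (cfg_combine is a hypothetical helper, not part of the commit) is sketched below.

import torch

def cfg_combine(preds_batched, num_cond, ds_strength, spk_strength):
    """Split a [orig_batch * num_cond, T, D] stacked prediction and apply CFG.

    Index 0 of each group is the regular prediction, index 1 the degraded-cvec
    prediction (if used), and the last index the speaker-dropped prediction (if used).
    """
    orig_batch = preds_batched.shape[0] // num_cond
    # repeat_interleave lays the conditions out contiguously per item,
    # so this view recovers them without copying
    preds = preds_batched.view(orig_batch, num_cond, *preds_batched.shape[1:])
    pred = preds[:, 0]
    cond = 1
    if ds_strength > 1e-5:
        pred = pred + (pred - preds[:, cond]) * ds_strength
        cond += 1
    if spk_strength > 1e-5:
        pred = pred + (pred - preds[:, cond]) * spk_strength
    return pred

The guidance formula pred + (pred - cond_pred) * strength is the same one kept in the diff; only the per-item loop is replaced by tensor indexing here.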