prismleong committed
Commit 898b100 · 1 Parent(s): 5c61d84
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
- title: RIFT SVC
- emoji: 😻
+ title: RIFT-SVC (七海Nanami demo)
+ emoji: 🎵
  colorFrom: red
  colorTo: yellow
  sdk: gradio
@@ -11,4 +11,4 @@ license: cc-by-nc-sa-4.0
  short_description: https://github.com/Pur1zumu/RIFT-SVC
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
app.py ADDED
@@ -0,0 +1,398 @@
1
+ import numpy as np
2
+ import torch
3
+ import torchaudio
4
+ import gradio as gr
5
+ import tempfile
6
+ import gc
7
+ import traceback
8
+ from slicer import Slicer
9
+
10
+ from infer import (
11
+ load_models,
12
+ load_audio,
13
+ apply_fade,
14
+ process_segment
15
+ )
16
+
17
+ # Global variables for models
18
+ global svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg, device
19
+ svc_model = vocoder = rmvpe = hubert = rms_extractor = spk2idx = dataset_cfg = None
20
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+
22
+ # Set default model path
23
+ DEFAULT_MODEL_PATH = "pretrained/dit-768-12_nanami.ckpt"
24
+
25
+ # Maximum audio duration in seconds to avoid memory issues
26
+ MAX_AUDIO_DURATION = 300 # 5 minutes
27
+
28
+ def initialize_models(model_path=DEFAULT_MODEL_PATH):
29
+ global svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg
30
+
31
+ # Clean up memory before loading models
32
+ if svc_model is not None:
33
+ del svc_model
34
+ del vocoder
35
+ del rmvpe
36
+ del hubert
37
+ del rms_extractor
38
+ torch.cuda.empty_cache()
39
+ gc.collect()
40
+
41
+ try:
42
+ svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg = load_models(model_path, device)
43
+ available_speakers = list(spk2idx.keys())
44
+ return available_speakers, f"✅ 模型加载成功!可用说话人: {', '.join(available_speakers)}"
45
+ except Exception as e:
46
+ error_trace = traceback.format_exc()
47
+ return [], f"❌ 加载模型出错: {str(e)}\n\n详细信息: {error_trace}"
48
+
49
+ def check_audio_length(audio_path, max_duration=MAX_AUDIO_DURATION):
50
+ """Check if audio file is too long to process safely"""
51
+ try:
52
+ info = torchaudio.info(audio_path)
53
+ duration = info.num_frames / info.sample_rate
54
+ return duration <= max_duration, duration
55
+ except Exception:
56
+ # If we can't determine the length, we'll try to process it anyway
57
+ return True, 0
58
+
59
+ def process_with_progress(
60
+ progress=gr.Progress(),
61
+ input_audio=None,
62
+ speaker=None,
63
+ key_shift=0,
64
+ infer_steps=32,
65
+ robust_f0=0,
66
+ # Advanced CFG parameters
67
+ ds_cfg_strength=0.05,
68
+ spk_cfg_strength=1.0,
69
+ skip_cfg_strength=0.0,
70
+ cfg_skip_layers=6,
71
+ cfg_rescale=0.7,
72
+ cvec_downsample_rate=2,
73
+ # Slicer parameters
74
+ slicer_threshold=-30.0,
75
+ slicer_min_length=3000,
76
+ slicer_min_interval=100,
77
+ slicer_hop_size=10,
78
+ slicer_max_sil_kept=200
79
+ ):
80
+ global svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg
81
+
82
+ # Fixed target loudness value
83
+ target_loudness = -18.0
84
+
85
+ # Fixed audio parameters
86
+ restore_loudness = True
87
+ fade_duration = 20.0
88
+ sliced_inference = False
89
+
90
+ # Input validation
91
+ if input_audio is None:
92
+ return None, "❌ 错误: 未提供输入音频。"
93
+
94
+ if svc_model is None:
95
+ return None, "❌ 错误: 模型未加载。请重新加载页面或检查模型路径。"
96
+
97
+ if speaker is None or speaker not in spk2idx:
98
+ return None, f"❌ 错误: 无效的说话人选择。可用说话人: {', '.join(spk2idx.keys())}"
99
+
100
+ # Check audio length to avoid memory issues
101
+ is_safe_length, duration = check_audio_length(input_audio)
102
+ if not is_safe_length:
103
+ return None, f"❌ 错误: 音频过长 ({duration:.1f} 秒)。允许的最大时长为 {MAX_AUDIO_DURATION} 秒。"
104
+
105
+ # Process the audio
106
+ try:
107
+ # Update status message
108
+ progress(0, desc="处理中: 加载音频...")
109
+
110
+ # Convert speaker name to ID
111
+ speaker_id = spk2idx[speaker]
112
+
113
+ # Get config from loaded model
114
+ hop_length = 512
115
+ sample_rate = 44100
116
+
117
+ # Load audio
118
+ audio = load_audio(input_audio, sample_rate)
119
+
120
+ # Initialize Slicer
121
+ slicer = Slicer(
122
+ sr=sample_rate,
123
+ threshold=slicer_threshold,
124
+ min_length=slicer_min_length,
125
+ min_interval=slicer_min_interval,
126
+ hop_size=slicer_hop_size,
127
+ max_sil_kept=slicer_max_sil_kept
128
+ )
129
+
130
+ progress(0.1, desc="处理中: 切分音频...")
131
+ # Slice the input audio
132
+ segments_with_pos = slicer.slice(audio)
133
+
134
+ if not segments_with_pos:
135
+ return None, "❌ 错误: 在输入文件中未找到有效的音频片段。"
136
+
137
+ # Calculate fade size in samples
138
+ fade_samples = int(fade_duration * sample_rate / 1000)
139
+
140
+ # Process segments
141
+ result_audio = np.zeros(len(audio) + fade_samples) # Extra space for potential overlap
142
+
143
+ progress(0.2, desc="处理中: 开始转换...")
144
+
145
+ with torch.no_grad():
146
+ for i, (start_sample, chunk) in enumerate(segments_with_pos):
147
+ segment_progress = 0.2 + (0.7 * (i / len(segments_with_pos)))
148
+ progress(segment_progress, desc=f"处理中: 片段 {i+1}/{len(segments_with_pos)}")
149
+
150
+ # Process the segment
151
+ audio_out = process_segment(
152
+ chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
153
+ speaker_id, sample_rate, hop_length, device,
154
+ key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
155
+ skip_cfg_strength, cfg_skip_layers, cfg_rescale,
156
+ cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
157
+ robust_f0
158
+ )
159
+
160
+ # Ensure consistent length
161
+ expected_length = len(chunk)
162
+ if len(audio_out) > expected_length:
163
+ audio_out = audio_out[:expected_length]
164
+ elif len(audio_out) < expected_length:
165
+ audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')
166
+
167
+ # Apply fades
168
+ if i > 0: # Not first segment
169
+ audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
170
+ result_audio[start_sample:start_sample + fade_samples] *= \
171
+ np.linspace(1, 0, fade_samples) # Fade out previous
172
+
173
+ if i < len(segments_with_pos) - 1: # Not last segment
174
+ audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples) # Fade out
175
+
176
+ # Add to result
177
+ result_audio[start_sample:start_sample + len(audio_out)] += audio_out
178
+
179
+ # Clean up memory after each segment
180
+ torch.cuda.empty_cache()
181
+
182
+ progress(0.9, desc="处理中: 完成音频...")
183
+ # Trim any extra padding
184
+ result_audio = result_audio[:len(audio)]
185
+
186
+ # Create a temporary file to save the result
187
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
188
+ output_path = temp_file.name
189
+
190
+ # Save output
191
+ torchaudio.save(output_path, torch.from_numpy(result_audio).unsqueeze(0).float(), sample_rate)
192
+
193
+ progress(1.0, desc="处理完成!")
194
+ return (sample_rate, result_audio), f"✅ 转换完成! 已转换为 **{speaker}** 并调整 **{key_shift}** 个半音。"
195
+
196
+ except RuntimeError as e:
197
+ # Handle CUDA out of memory errors
198
+ if "CUDA out of memory" in str(e):
199
+ # Clean up memory
200
+ torch.cuda.empty_cache()
201
+ gc.collect()
202
+
203
+ return None, f"❌ 错误: 内存不足。请尝试更短的音频文件或减少推理步骤。"
204
+ else:
205
+ return None, f"❌ 转换过程中出错: {str(e)}"
206
+ except Exception as e:
207
+ error_trace = traceback.format_exc()
208
+ return None, f"❌ 转换过程中出错: {str(e)}\n\n详细信息: {error_trace}"
209
+ finally:
210
+ # Clean up memory
211
+ torch.cuda.empty_cache()
212
+ gc.collect()
213
+
214
+ def create_ui():
215
+ # CSS for better styling
216
+ css = """
217
+ .gradio-container {
218
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
219
+ }
220
+ .container {
221
+ max-width: 1200px;
222
+ margin: auto;
223
+ }
224
+ .footer {
225
+ margin-top: 20px;
226
+ text-align: center;
227
+ font-size: 0.9em;
228
+ color: #666;
229
+ }
230
+ .title {
231
+ text-align: center;
232
+ margin-bottom: 10px;
233
+ }
234
+ .subtitle {
235
+ text-align: center;
236
+ margin-bottom: 20px;
237
+ color: #666;
238
+ }
239
+ .button-primary {
240
+ background-color: #5460DE !important;
241
+ }
242
+ .output-message {
243
+ margin-top: 10px;
244
+ padding: 10px;
245
+ border-radius: 4px;
246
+ background-color: #f8f9fa;
247
+ border-left: 4px solid #5460DE;
248
+ }
249
+ .error-message {
250
+ color: #d62828;
251
+ font-weight: bold;
252
+ }
253
+ .success-message {
254
+ color: #588157;
255
+ font-weight: bold;
256
+ }
257
+ .info-box {
258
+ background-color: #f8f9fa;
259
+ border-left: 4px solid #5460DE;
260
+ padding: 10px;
261
+ margin: 10px 0;
262
+ border-radius: 4px;
263
+ }
264
+ """
265
+
266
+ # Initialize models
267
+ available_speakers, init_message = initialize_models()
268
+
269
+ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="RIFT-SVC 声音转换") as app:
270
+ gr.HTML("""
271
+ <div class="title">
272
+ <h1>🎤 RIFT-SVC 歌声音色转换 (七海Nanami demo)</h1>
273
+ </div>
274
+ <div class="subtitle">
275
+ <h3>使用 RIFT-SVC 模型将歌声或语音转换为七海Nanami的音色</h3>
276
+ </div>
277
+ <div class="info-box">
278
+ <p>📝 <strong>注意:</strong> 为获得最佳效果,请使用背景噪音较少的干净音频。最大音频长度为5分钟。</p>
279
+ </div>
280
+ <div class="info-box">
281
+ <p>🔗 <strong>想要微调自己的说话人?</strong> 请访问 <a href="https://github.com/Pur1zumu/RIFT-SVC" target="_blank">RIFT-SVC GitHub 仓库</a> 获取完整的训练和微调指南。</p>
282
+ </div>
283
+ """)
284
+
285
+ with gr.Row():
286
+ # Left column (input parameters)
287
+ with gr.Column(scale=1):
288
+ with gr.Group():
289
+ gr.Markdown("### 📥 输入")
290
+ model_path = gr.Textbox(label="模型路径", value=DEFAULT_MODEL_PATH, interactive=True)
291
+ input_audio = gr.Audio(label="输入音频文件", type="filepath", elem_id="input_audio")
292
+ reload_btn = gr.Button("🔄 重新加载模型", elem_id="reload_btn")
293
+
294
+ with gr.Accordion("⚙️ 基本参数", open=True):
295
+ speaker = gr.Dropdown(choices=available_speakers, label="目标说话人", interactive=True, elem_id="speaker")
296
+ key_shift = gr.Slider(minimum=-12, maximum=12, step=1, value=0, label="音调调整(半音)", elem_id="key_shift")
297
+ infer_steps = gr.Slider(minimum=8, maximum=64, step=1, value=32, label="推理步数", elem_id="infer_steps",
298
+ info="更低的值 = 更快但质量较低,更高的值 = 更慢但质量更好")
299
+ robust_f0 = gr.Radio(choices=[0, 1, 2], value=0, label="音高滤波",
300
+ info="0=无,1=轻度过滤,2=强力过滤(有助于解决断音/破音问题)",
301
+ elem_id="robust_f0")
302
+
303
+ with gr.Accordion("🔬 高级CFG参数", open=True):
304
+ ds_cfg_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.05,
305
+ label="内容向量引导强度",
306
+ info="更高的值可以改善内容保留和咬字清晰度。过高会用力过猛。",
307
+ elem_id="ds_cfg_strength")
308
+ spk_cfg_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0,
309
+ label="说话人引导强度",
310
+ info="更高的值可以增强说话人相似度。过高可能导致音色失真。",
311
+ elem_id="spk_cfg_strength")
312
+ skip_cfg_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0,
313
+ label="层引导强度(实验性功能)",
314
+ info="增强指定层的特征渲染。效果取决于目标层的功能。",
315
+ elem_id="skip_cfg_strength")
316
+ cfg_skip_layers = gr.Number(value=6, label="CFG跳过层(实验性功能)", precision=0,
317
+ info="目标增强层下标",
318
+ elem_id="cfg_skip_layers")
319
+ cfg_rescale = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.7,
320
+ label="CFG重缩放因子",
321
+ info="约束整体引导强度。当引导效果过于强烈时使用调高该值。",
322
+ elem_id="cfg_rescale")
323
+ cvec_downsample_rate = gr.Radio(choices=[1, 2, 4, 8], value=2,
324
+ label="用于反向引导的内容向量下采样率",
325
+ info="更高的值(可能)可以提高内容清晰度。",
326
+ elem_id="cvec_downsample_rate")
327
+
328
+ with gr.Accordion("✂️ 切片参数", open=False):
329
+ slicer_threshold = gr.Slider(minimum=-60.0, maximum=-20.0, step=0.1, value=-30.0,
330
+ label="阈值 (dB)",
331
+ info="静音检测阈值",
332
+ elem_id="slicer_threshold")
333
+ slicer_min_length = gr.Slider(minimum=1000, maximum=10000, step=100, value=3000,
334
+ label="最小长度 (毫秒)",
335
+ info="最小片段长度",
336
+ elem_id="slicer_min_length")
337
+ slicer_min_interval = gr.Slider(minimum=10, maximum=500, step=10, value=100,
338
+ label="最小静音间隔 (毫秒)",
339
+ info="片段之间的最小静音间隔",
340
+ elem_id="slicer_min_interval")
341
+ slicer_hop_size = gr.Slider(minimum=1, maximum=50, step=1, value=10,
342
+ label="跳跃大小 (毫秒)",
343
+ info="分析窗口跳跃大小",
344
+ elem_id="slicer_hop_size")
345
+ slicer_max_sil_kept = gr.Slider(minimum=50, maximum=10000, step=10, value=200,
346
+ label="最大保留静音 (毫秒)",
347
+ info="边界处保留的最大静音",
348
+ elem_id="slicer_max_sil_kept")
349
+
350
+
351
+ # Right column (output)
352
+ with gr.Column(scale=1):
353
+ convert_btn = gr.Button("🎵 转换声音", variant="primary", elem_id="convert_btn")
354
+ gr.Markdown("### 📤 输出")
355
+ output_audio = gr.Audio(label="转换后的音频", elem_id="output_audio", autoplay=False, show_share_button=False)
356
+ output_message = gr.Markdown(init_message, elem_id="output_message", elem_classes="output-message")
357
+
358
+ gr.HTML("""
359
+ <div class="info-box">
360
+ <h4>🔍 快速提示</h4>
361
+ <ul>
362
+ <li><strong>音调调整:</strong> 以半音为单位上调或下调音高。</li>
363
+ <li><strong>推理步骤:</strong> 步骤越多 = 质量越好但速度越慢。</li>
364
+ <li><strong>音高滤波:</strong> 有助于提高具有挑战性的音频中的音高稳定性。</li>
365
+ <li><strong>CFG参数:</strong> 调整转换质量和音色。</li>
366
+ </ul>
367
+ </div>
368
+ """)
369
+
370
+ # Define button click events
371
+ reload_btn.click(
372
+ fn=initialize_models,
373
+ inputs=[model_path],
374
+ outputs=[speaker, output_message]
375
+ )
376
+
377
+ # Updated convert button click event
378
+ convert_btn.click(
379
+ fn=lambda: "⏳ 处理中... 请稍候。",
380
+ inputs=None,
381
+ outputs=output_message,
382
+ queue=False
383
+ ).then(
384
+ fn=process_with_progress,
385
+ inputs=[
386
+ input_audio, speaker, key_shift, infer_steps, robust_f0,
387
+ ds_cfg_strength, spk_cfg_strength, skip_cfg_strength, cfg_skip_layers, cfg_rescale, cvec_downsample_rate,
388
+ slicer_threshold, slicer_min_length, slicer_min_interval, slicer_hop_size, slicer_max_sil_kept
389
+ ],
390
+ outputs=[output_audio, output_message],
391
+ show_progress_on=output_audio
392
+ )
393
+
394
+ return app
395
+
396
+ if __name__ == "__main__":
397
+ app = create_ui()
398
+ app.launch()
infer.py ADDED
@@ -0,0 +1,493 @@
1
+ import click
2
+ import librosa
3
+ import numpy as np
4
+ import pyloudnorm as pyln
5
+ import torch
6
+ import torchaudio
7
+ from pathlib import Path
8
+ from tqdm import tqdm
9
+
10
+ from rift_svc import DiT, RF
11
+ from rift_svc.feature_extractors import HubertModelWithFinalProj, RMSExtractor, get_mel_spectrogram
12
+ from rift_svc.nsf_hifigan import NsfHifiGAN
13
+ from rift_svc.rmvpe import RMVPE
14
+ from rift_svc.utils import linear_interpolate_tensor, post_process_f0, f0_ensemble, f0_ensemble_light, get_f0_pw, get_f0_pm
15
+ from slicer import Slicer
16
+
17
+
18
+ torch.set_grad_enabled(False)
19
+
20
+
21
+ def extract_state_dict(ckpt):
22
+ state_dict = ckpt['state_dict']
23
+ new_state_dict = {}
24
+ for k, v in state_dict.items():
25
+ if k.startswith('model.'):
26
+ new_k = k.replace('model.', '')
27
+ new_state_dict[new_k] = v
28
+ spk2idx = ckpt['hyper_parameters']['cfg']['spk2idx']
29
+ model_cfg = ckpt['hyper_parameters']['cfg']['model']
30
+ dataset_cfg = ckpt['hyper_parameters']['cfg']['dataset']
31
+ return new_state_dict, spk2idx, model_cfg, dataset_cfg
32
+
33
+
34
+ def load_models(model_path, device):
35
+ """Load all required models and return them"""
36
+ click.echo("Loading models...")
37
+
38
+ # Load the conversion model
39
+ ckpt = torch.load(model_path, map_location='cpu')
40
+ state_dict, spk2idx, dit_cfg, dataset_cfg = extract_state_dict(ckpt)
41
+
42
+ transformer = DiT(num_speaker=len(spk2idx), **dit_cfg)
43
+ svc_model = RF(transformer=transformer)
44
+ svc_model.load_state_dict(state_dict)
45
+ svc_model = svc_model.to(device)
46
+ svc_model.eval()
47
+
48
+ # Load additional models
49
+ vocoder = NsfHifiGAN('pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt').to(device)
50
+ rmvpe = RMVPE(model_path="pretrained/rmvpe/model.pt", hop_length=160, device=device)
51
+ hubert = HubertModelWithFinalProj.from_pretrained("pretrained/content-vec-best").to(device)
52
+ rms_extractor = RMSExtractor().to(device)
53
+
54
+ return svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg
55
+
56
+
57
+ def load_audio(file_path, target_sr):
58
+ """Load and preprocess audio file"""
59
+ click.echo("Loading audio...")
60
+ audio, sr = torchaudio.load(file_path)
61
+ if sr != target_sr:
62
+ audio = torchaudio.functional.resample(audio, sr, target_sr)
63
+
64
+ if len(audio.shape) > 1:
65
+ audio = audio.mean(dim=0, keepdim=True)
66
+
67
+ return audio.numpy().squeeze()
68
+
69
+
70
+ def apply_fade(audio, fade_samples, fade_in=True):
71
+ """Apply fade in/out using half of a Hanning window"""
72
+ fade_window = np.hanning(fade_samples * 2)
73
+ if fade_in:
74
+ fade_curve = fade_window[:fade_samples]
75
+ else:
76
+ fade_curve = fade_window[fade_samples:]
77
+ audio[:fade_samples] *= fade_curve
78
+ return audio
79
+
80
+
81
+ def extract_features(audio_segment, sample_rate, hop_length, rmvpe, hubert, rms_extractor,
82
+ device, key_shift=0, ds_cfg_strength=0.0, cvec_downsample_rate=2, target_loudness=-18.0,
83
+ robust_f0=0):
84
+ """Extract all required features from an audio segment"""
85
+ # Normalize input segment
86
+ meter = pyln.Meter(sample_rate, block_size=0.1)
87
+ original_loudness = meter.integrated_loudness(audio_segment)
88
+ normalized_audio = pyln.normalize.loudness(audio_segment, original_loudness, target_loudness)
89
+
90
+ # Handle potential clipping
91
+ max_amp = np.max(np.abs(normalized_audio))
92
+ if max_amp > 1.0:
93
+ normalized_audio = normalized_audio * (0.99 / max_amp)
94
+
95
+ audio_tensor = torch.from_numpy(normalized_audio).float().unsqueeze(0).to(device)
96
+ audio_16khz = torch.from_numpy(librosa.resample(normalized_audio, orig_sr=sample_rate, target_sr=16000)).float().unsqueeze(0).to(device)
97
+
98
+ # Extract mel spectrogram
99
+ mel = get_mel_spectrogram(
100
+ audio_tensor,
101
+ sampling_rate=sample_rate,
102
+ n_fft=2048,
103
+ num_mels=128,
104
+ hop_size=512,
105
+ win_size=2048,
106
+ fmin=40,
107
+ fmax=16000
108
+ ).transpose(1, 2)
109
+
110
+ # Extract content vector
111
+ cvec = hubert(audio_16khz)["last_hidden_state"].squeeze(0)
112
+ cvec = linear_interpolate_tensor(cvec, mel.shape[1])[None, :]
113
+
114
+ # Create bad_cvec (downsampled) for classifier-free guidance
115
+ if ds_cfg_strength > 0:
116
+ cvec_ds = cvec.clone()
117
+ # Downsample and then interpolate back, similar to dataset.py
118
+ cvec_ds = cvec_ds[0, ::2, :] # Take every other frame
119
+ cvec_ds = linear_interpolate_tensor(cvec_ds, cvec_ds.shape[0]//cvec_downsample_rate)
120
+ cvec_ds = linear_interpolate_tensor(cvec_ds, mel.shape[1])[None, :]
121
+ else:
122
+ cvec_ds = None
123
+
124
+ # Extract f0
125
+ if robust_f0 > 0:
126
+ # Parameters for F0 extraction
127
+ time_step = hop_length / sample_rate
128
+ f0_min = 40
129
+ f0_max = 1100
130
+
131
+ # Extract F0 using multiple methods
132
+ rmvpe_f0 = rmvpe.infer_from_audio(audio_tensor, sample_rate=sample_rate, device=device)
133
+ rmvpe_f0 = post_process_f0(rmvpe_f0, sample_rate, hop_length, mel.shape[1], silence_front=0.0, cut_last=False)
134
+ pw_f0 = get_f0_pw(normalized_audio, sample_rate, time_step, f0_min, f0_max)
135
+ pmac_f0 = get_f0_pm(normalized_audio, sample_rate, time_step, f0_min, f0_max)
136
+
137
+ if robust_f0 == 1:
138
+ # Level 1: Light ensemble that preserves expressiveness
139
+ rms_np = rms_extractor(audio_tensor).squeeze().cpu().numpy()
140
+ f0 = f0_ensemble_light(rmvpe_f0, pw_f0, pmac_f0, rms=rms_np)
141
+ else:
142
+ # Level 2: Strong ensemble with more filtering
143
+ f0 = f0_ensemble(rmvpe_f0, pw_f0, pmac_f0)
144
+ else:
145
+ # Level 0: Use only RMVPE for F0 extraction (original method)
146
+ f0 = rmvpe.infer_from_audio(audio_tensor, sample_rate=sample_rate, device=device)
147
+ f0 = post_process_f0(f0, sample_rate, hop_length, mel.shape[1], silence_front=0.0, cut_last=False)
148
+
149
+ if key_shift != 0:
150
+ f0 = f0 * 2 ** (key_shift / 12)
151
+ f0 = torch.from_numpy(f0).float().to(device)[None, :]
152
+
153
+ # Extract RMS
154
+ rms = rms_extractor(audio_tensor)
155
+
156
+ return mel, cvec, cvec_ds, f0, rms, original_loudness
157
+
158
+
159
+ def run_inference(
160
+ model, mel, cvec, f0, rms, cvec_ds, spk_id,
161
+ infer_steps, ds_cfg_strength, spk_cfg_strength,
162
+ skip_cfg_strength, cfg_skip_layers, cfg_rescale,
163
+ sliced_inference=False
164
+ ):
165
+ """Run the actual inference through the model"""
166
+ if sliced_inference:
167
+ # Use sliced inference for long segments
168
+ sliced_len = 256
169
+ mel_crossfade_len = 8 # Number of frames to crossfade in mel domain
170
+
171
+ # If the segment is shorter than one slice, just process it directly
172
+ if mel.shape[1] <= sliced_len:
173
+ mel_out, _ = model.sample(
174
+ src_mel=mel,
175
+ spk_id=spk_id,
176
+ f0=f0,
177
+ rms=rms,
178
+ cvec=cvec,
179
+ steps=infer_steps,
180
+ bad_cvec=cvec_ds,
181
+ ds_cfg_strength=ds_cfg_strength,
182
+ spk_cfg_strength=spk_cfg_strength,
183
+ skip_cfg_strength=skip_cfg_strength,
184
+ cfg_skip_layers=cfg_skip_layers,
185
+ cfg_rescale=cfg_rescale,
186
+ )
187
+ return mel_out
188
+
189
+ # Create a tensor to hold the full output with crossfading
190
+ full_mel_out = torch.zeros_like(mel)
191
+
192
+ # Process each slice
193
+ for i in range(0, mel.shape[1], sliced_len - mel_crossfade_len):
194
+ # Determine slice boundaries
195
+ start_idx = i
196
+ end_idx = min(i + sliced_len, mel.shape[1])
197
+
198
+ # Skip if we're at the end
199
+ if start_idx >= mel.shape[1]:
200
+ break
201
+
202
+ # Extract slices for this window
203
+ mel_slice = mel[:, start_idx:end_idx, :]
204
+ cvec_slice = cvec[:, start_idx:end_idx, :]
205
+ f0_slice = f0[:, start_idx:end_idx]
206
+ rms_slice = rms[:, start_idx:end_idx]
207
+
208
+ # Slice the bad_cvec if it exists
209
+ cvec_ds_slice = None
210
+ if cvec_ds is not None:
211
+ cvec_ds_slice = cvec_ds[:, start_idx:end_idx, :]
212
+
213
+ # Process with model
214
+ mel_out_slice, _ = model.sample(
215
+ src_mel=mel_slice,
216
+ spk_id=spk_id,
217
+ f0=f0_slice,
218
+ rms=rms_slice,
219
+ cvec=cvec_slice,
220
+ steps=infer_steps,
221
+ bad_cvec=cvec_ds_slice,
222
+ ds_cfg_strength=ds_cfg_strength,
223
+ spk_cfg_strength=spk_cfg_strength,
224
+ skip_cfg_strength=skip_cfg_strength,
225
+ cfg_skip_layers=cfg_skip_layers,
226
+ cfg_rescale=cfg_rescale,
227
+ )
228
+
229
+ # Create crossfade weights
230
+ slice_len = end_idx - start_idx
231
+
232
+ # Apply different strategies depending on position
233
+ if i == 0: # First slice
234
+ # No crossfade at the beginning
235
+ weights = torch.ones((1, slice_len, 1), device=mel.device)
236
+ if i + sliced_len < mel.shape[1]: # If not the last slice too
237
+ # Fade out at the end - use the minimum of slice_len and mel_crossfade_len
238
+ actual_crossfade_len = min(mel_crossfade_len, slice_len)
239
+ if actual_crossfade_len > 0: # Only apply if we have space
240
+ fade_out = torch.linspace(1, 0, actual_crossfade_len, device=mel.device)
241
+ weights[:, -actual_crossfade_len:, :] = fade_out.view(1, -1, 1)
242
+ elif end_idx >= mel.shape[1]: # Last slice
243
+ # Fade in at the beginning - use the minimum of slice_len and mel_crossfade_len
244
+ weights = torch.ones((1, slice_len, 1), device=mel.device)
245
+ actual_crossfade_len = min(mel_crossfade_len, slice_len)
246
+ if actual_crossfade_len > 0: # Only apply if we have space
247
+ fade_in = torch.linspace(0, 1, actual_crossfade_len, device=mel.device)
248
+ weights[:, :actual_crossfade_len, :] = fade_in.view(1, -1, 1)
249
+ else: # Middle slices
250
+ # Crossfade both sides, handling the case where slice_len < 2*mel_crossfade_len
251
+ weights = torch.ones((1, slice_len, 1), device=mel.device)
252
+
253
+ # Determine the actual crossfade length (might be shorter for small slices)
254
+ actual_crossfade_len = min(mel_crossfade_len, slice_len // 2)
255
+ if actual_crossfade_len > 0:
256
+ fade_in = torch.linspace(0, 1, actual_crossfade_len, device=mel.device)
257
+ fade_out = torch.linspace(1, 0, actual_crossfade_len, device=mel.device)
258
+ weights[:, :actual_crossfade_len, :] = fade_in.view(1, -1, 1)
259
+ weights[:, -actual_crossfade_len:, :] = fade_out.view(1, -1, 1)
260
+
261
+ # Apply weights to current slice output
262
+ mel_out_slice = mel_out_slice * weights
263
+
264
+ # Add to the appropriate region of the output
265
+ full_mel_out[:, start_idx:end_idx, :] += mel_out_slice
266
+
267
+ # Return the full crossfaded output
268
+ mel_out = full_mel_out
269
+ else:
270
+ # Process the entire segment at once
271
+ mel_out, _ = model.sample(
272
+ src_mel=mel,
273
+ spk_id=spk_id,
274
+ f0=f0,
275
+ rms=rms,
276
+ cvec=cvec,
277
+ steps=infer_steps,
278
+ bad_cvec=cvec_ds,
279
+ ds_cfg_strength=ds_cfg_strength,
280
+ spk_cfg_strength=spk_cfg_strength,
281
+ skip_cfg_strength=skip_cfg_strength,
282
+ cfg_skip_layers=cfg_skip_layers,
283
+ cfg_rescale=cfg_rescale,
284
+ )
285
+
286
+ return mel_out
287
+
288
+
289
+ def generate_audio(vocoder, mel_out, f0, original_loudness=None, restore_loudness=True):
290
+ """Generate audio from mel spectrogram using vocoder"""
291
+ audio_out = vocoder(mel_out.transpose(1, 2), f0)
292
+ audio_out = audio_out.squeeze().cpu().numpy()
293
+
294
+ if restore_loudness and original_loudness is not None:
295
+ # Restore original loudness
296
+ meter = pyln.Meter(44100, block_size=0.1) # Using default sample rate for vocoder
297
+ audio_out_loudness = meter.integrated_loudness(audio_out)
298
+ audio_out = pyln.normalize.loudness(audio_out, audio_out_loudness, original_loudness)
299
+
300
+ # Handle clipping
301
+ max_amp = np.max(np.abs(audio_out))
302
+ if max_amp > 1.0:
303
+ audio_out = audio_out * (0.99 / max_amp)
304
+
305
+ return audio_out
306
+
307
+
308
+ def process_segment(
309
+ audio_segment,
310
+ svc_model, vocoder, rmvpe, hubert, rms_extractor,
311
+ speaker_id, sample_rate, hop_length, device,
312
+ key_shift=0,
313
+ infer_steps=32,
314
+ ds_cfg_strength=0.0,
315
+ spk_cfg_strength=0.0,
316
+ skip_cfg_strength=0.0,
317
+ cfg_skip_layers=None,
318
+ cfg_rescale=0.7,
319
+ cvec_downsample_rate=2,
320
+ target_loudness=-18.0,
321
+ restore_loudness=True,
322
+ sliced_inference=False,
323
+ robust_f0=0
324
+ ):
325
+ """Process a single audio segment and return the converted audio"""
326
+ # Extract features
327
+ mel, cvec, cvec_ds, f0, rms, original_loudness = extract_features(
328
+ audio_segment, sample_rate, hop_length, rmvpe, hubert, rms_extractor,
329
+ device, key_shift, ds_cfg_strength, cvec_downsample_rate, target_loudness,
330
+ robust_f0
331
+ )
332
+
333
+ # Prepare speaker ID
334
+ spk_id = torch.LongTensor([speaker_id]).to(device)
335
+
336
+ # Run inference
337
+ mel_out = run_inference(
338
+ svc_model, mel, cvec, f0, rms, cvec_ds, spk_id,
339
+ infer_steps, ds_cfg_strength, spk_cfg_strength,
340
+ skip_cfg_strength, cfg_skip_layers, cfg_rescale,
341
+ sliced_inference
342
+ )
343
+
344
+ # Generate audio
345
+ audio_out = generate_audio(
346
+ vocoder, mel_out, f0,
347
+ original_loudness if restore_loudness else None,
348
+ restore_loudness
349
+ )
350
+
351
+ return audio_out
352
+
353
+
354
+ @click.command()
355
+ @click.option('--model', type=click.Path(exists=True), required=True, help='Path to model checkpoint')
356
+ @click.option('--input', type=click.Path(exists=True), required=True, help='Input audio file')
357
+ @click.option('--output', type=click.Path(), required=True, help='Output audio file')
358
+ @click.option('--speaker', type=str, required=True, help='Target speaker')
359
+ @click.option('--key-shift', type=int, default=0, help='Pitch shift in semitones')
360
+ @click.option('--device', type=str, default=None, help='Device to use (cuda/cpu)')
361
+ @click.option('--infer-steps', type=int, default=32, help='Number of inference steps')
362
+ @click.option('--ds-cfg-strength', type=float, default=0.0, help='Downsampled content vector guidance strength')
363
+ @click.option('--spk-cfg-strength', type=float, default=0.0, help='Speaker guidance strength')
364
+ @click.option('--skip-cfg-strength', type=float, default=0.0, help='Skip layer guidance strength')
365
+ @click.option('--cfg-skip-layers', type=int, default=None, help='Layer to skip for classifier-free guidance')
366
+ @click.option('--cfg-rescale', type=float, default=0.7, help='Classifier-free guidance rescale factor')
367
+ @click.option('--cvec-downsample-rate', type=int, default=2, help='Downsampling rate for bad_cvec creation')
368
+ @click.option('--target-loudness', type=float, default=-18.0, help='Target loudness in LUFS for normalization')
369
+ @click.option('--restore-loudness/--no-restore-loudness', default=True, help='Restore loudness to original')
370
+ @click.option('--fade-duration', type=float, default=20.0, help='Fade duration in milliseconds')
371
+ @click.option('--sliced-inference', is_flag=True, default=False, help='Use sliced inference for processing long segments')
372
+ @click.option('--robust-f0', type=int, default=0, help='Level of robust f0 filtering (0=none, 1=light, 2=aggressive)')
373
+ @click.option('--slicer-threshold', type=float, default=-35.0, help='Threshold for audio slicing in dB')
374
+ @click.option('--slicer-min-length', type=int, default=3000, help='Minimum length of audio segments in milliseconds')
375
+ @click.option('--slicer-min-interval', type=int, default=100, help='Minimum interval between audio segments in milliseconds')
376
+ @click.option('--slicer-hop-size', type=int, default=10, help='Hop size for audio slicing in milliseconds')
377
+ @click.option('--slicer-max-sil-kept', type=int, default=300, help='Maximum silence kept in milliseconds')
378
+ def main(
379
+ model,
380
+ input,
381
+ output,
382
+ speaker,
383
+ key_shift,
384
+ device,
385
+ infer_steps,
386
+ ds_cfg_strength,
387
+ spk_cfg_strength,
388
+ skip_cfg_strength,
389
+ cfg_skip_layers,
390
+ cfg_rescale,
391
+ cvec_downsample_rate,
392
+ target_loudness,
393
+ restore_loudness,
394
+ fade_duration,
395
+ sliced_inference,
396
+ robust_f0,
397
+ slicer_threshold,
398
+ slicer_min_length,
399
+ slicer_min_interval,
400
+ slicer_hop_size,
401
+ slicer_max_sil_kept
402
+ ):
403
+ """Convert the voice in an audio file to a target speaker."""
404
+
405
+ # Setup device
406
+ if device is None:
407
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
408
+ device = torch.device(device)
409
+
410
+ # Load models
411
+ svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg = load_models(model, device)
412
+
413
+ try:
414
+ speaker_id = spk2idx[speaker]
415
+ except KeyError:
416
+ raise ValueError(f"Speaker {speaker} not found in the model's speaker list, valid speakers are {spk2idx.keys()}")
417
+
418
+ # Get config from loaded model
419
+ hop_length = 512
420
+ sample_rate = 44100
421
+
422
+ # Load audio
423
+ audio = load_audio(input, sample_rate)
424
+
425
+ # Initialize Slicer
426
+ slicer = Slicer(
427
+ sr=sample_rate,
428
+ threshold=slicer_threshold,
429
+ min_length=slicer_min_length,
430
+ min_interval=slicer_min_interval,
431
+ hop_size=slicer_hop_size,
432
+ max_sil_kept=slicer_max_sil_kept
433
+ )
434
+
435
+ # Step (1): Use slicer to segment the input audio and get positions
436
+ click.echo("Slicing audio...")
437
+ segments_with_pos = slicer.slice(audio) # Now returns list of (start_pos, chunk)
438
+
439
+ if restore_loudness:
440
+ click.echo(f"Will restore loudness to original")
441
+
442
+ # Calculate fade size in samples
443
+ fade_samples = int(fade_duration * sample_rate / 1000)
444
+
445
+ # Process segments
446
+ click.echo(f"Processing {len(segments_with_pos)} segments...")
447
+ result_audio = np.zeros(len(audio) + fade_samples) # Extra space for potential overlap
448
+
449
+ with torch.no_grad():
450
+ for idx, (start_sample, chunk) in enumerate(tqdm(segments_with_pos)):
451
+
452
+ # Process the segment
453
+ audio_out = process_segment(
454
+ chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
455
+ speaker_id, sample_rate, hop_length, device,
456
+ key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
457
+ skip_cfg_strength, cfg_skip_layers, cfg_rescale,
458
+ cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
459
+ robust_f0
460
+ )
461
+
462
+ # Ensure consistent length
463
+ expected_length = len(chunk)
464
+ if len(audio_out) > expected_length:
465
+ audio_out = audio_out[:expected_length]
466
+ elif len(audio_out) < expected_length:
467
+ audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')
468
+
469
+ # Apply fades
470
+ if idx > 0: # Not first segment
471
+ audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
472
+ result_audio[start_sample:start_sample + fade_samples] *= \
473
+ np.linspace(1, 0, fade_samples) # Fade out previous
474
+
475
+ if idx < len(segments_with_pos) - 1: # Not last segment
476
+ audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples) # Fade out
477
+
478
+ # Add to result
479
+ result_audio[start_sample:start_sample + len(audio_out)] += audio_out
480
+
481
+ # Trim any extra padding
482
+ result_audio = result_audio[:len(audio)]
483
+
484
+ # Save output
485
+ click.echo("Saving output...")
486
+ output_path = Path(output)
487
+ output_path.parent.mkdir(parents=True, exist_ok=True)
488
+ torchaudio.save(output, torch.from_numpy(result_audio).unsqueeze(0).float(), sample_rate)
489
+ click.echo("Done!")
490
+
491
+
492
+ if __name__ == '__main__':
493
+ main()
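
For reference, a minimal sketch of driving the helpers above programmatically, without the CLI's slicing and crossfading. It assumes the default checkpoint path from app.py and the pretrained assets are in place; `input.wav` and `output.wav` are placeholder file names chosen for illustration.

```python
import torch
import torchaudio

from infer import load_models, load_audio, process_segment

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg = load_models(
    "pretrained/dit-768-12_nanami.ckpt", device)

sample_rate, hop_length = 44100, 512          # fixed values used throughout this repo
audio = load_audio("input.wav", sample_rate)  # mono numpy array at 44.1 kHz
speaker_id = spk2idx[list(spk2idx)[0]]        # or index spk2idx by a specific speaker name

audio_out = process_segment(
    audio, svc_model, vocoder, rmvpe, hubert, rms_extractor,
    speaker_id, sample_rate, hop_length, device,
    key_shift=0, infer_steps=32)

torchaudio.save("output.wav", torch.from_numpy(audio_out).unsqueeze(0).float(), sample_rate)
```

For long inputs, the CLI path above (Slicer segmentation plus fade-based stitching) is the more robust route.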
pretrained/content-vec-best/.gitattributes ADDED
@@ -0,0 +1,34 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
pretrained/content-vec-best/.gitignore ADDED
@@ -0,0 +1 @@
1
+ content-vec-best-legacy-500.pt
pretrained/content-vec-best/README.md ADDED
@@ -0,0 +1,33 @@
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ # Content Vec Best
6
+ Official Repo: [ContentVec](https://github.com/auspicious3000/contentvec)
7
+ This repo brings fairseq ContentVec model to HuggingFace Transformers.
8
+
9
+ ## How to use
10
+ To use this model, you need to define
11
+ ```python
12
+ class HubertModelWithFinalProj(HubertModel):
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+
16
+ # The final projection layer is only used for backward compatibility.
17
+ # Following https://github.com/auspicious3000/contentvec/issues/6
18
+ # Removing this layer is necessary to achieve the desired outcome.
19
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
20
+ ```
21
+
22
+ and then load the model with
23
+ ```python
24
+ model = HubertModelWithFinalProj.from_pretrained("lengyue233/content-vec-best")
25
+
26
+ x = model(audio)["last_hidden_state"]
27
+ ```
28
+
29
+ ## How to convert
30
+ You need to download the ContentVec_legacy model from the official repo, and then run
31
+ ```bash
32
+ python convert.py
33
+ ```
pretrained/content-vec-best/config.json ADDED
@@ -0,0 +1,71 @@
1
+ {
2
+ "activation_dropout": 0.1,
3
+ "apply_spec_augment": true,
4
+ "architectures": [
5
+ "HubertModelWithFinalProj"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "bos_token_id": 1,
9
+ "classifier_proj_size": 256,
10
+ "conv_bias": false,
11
+ "conv_dim": [
12
+ 512,
13
+ 512,
14
+ 512,
15
+ 512,
16
+ 512,
17
+ 512,
18
+ 512
19
+ ],
20
+ "conv_kernel": [
21
+ 10,
22
+ 3,
23
+ 3,
24
+ 3,
25
+ 3,
26
+ 2,
27
+ 2
28
+ ],
29
+ "conv_stride": [
30
+ 5,
31
+ 2,
32
+ 2,
33
+ 2,
34
+ 2,
35
+ 2,
36
+ 2
37
+ ],
38
+ "ctc_loss_reduction": "sum",
39
+ "ctc_zero_infinity": false,
40
+ "do_stable_layer_norm": false,
41
+ "eos_token_id": 2,
42
+ "feat_extract_activation": "gelu",
43
+ "feat_extract_norm": "group",
44
+ "feat_proj_dropout": 0.0,
45
+ "feat_proj_layer_norm": true,
46
+ "final_dropout": 0.1,
47
+ "hidden_act": "gelu",
48
+ "hidden_dropout": 0.1,
49
+ "hidden_size": 768,
50
+ "initializer_range": 0.02,
51
+ "intermediate_size": 3072,
52
+ "layer_norm_eps": 1e-05,
53
+ "layerdrop": 0.1,
54
+ "mask_feature_length": 10,
55
+ "mask_feature_min_masks": 0,
56
+ "mask_feature_prob": 0.0,
57
+ "mask_time_length": 10,
58
+ "mask_time_min_masks": 2,
59
+ "mask_time_prob": 0.05,
60
+ "model_type": "hubert",
61
+ "num_attention_heads": 12,
62
+ "num_conv_pos_embedding_groups": 16,
63
+ "num_conv_pos_embeddings": 128,
64
+ "num_feat_extract_layers": 7,
65
+ "num_hidden_layers": 12,
66
+ "pad_token_id": 0,
67
+ "torch_dtype": "float32",
68
+ "transformers_version": "4.27.3",
69
+ "use_weighted_layer_sum": false,
70
+ "vocab_size": 32
71
+ }
pretrained/content-vec-best/convert.py ADDED
@@ -0,0 +1,150 @@
1
+ import torch
2
+ from torch import nn
3
+ from transformers import HubertConfig, HubertModel
4
+ import logging
5
+
6
+ # Ignore fairseq's logger
7
+ logging.getLogger("fairseq").setLevel(logging.WARNING)
8
+ logging.getLogger("torch.distributed.nn.jit.instantiator").setLevel(logging.WARNING)
9
+
10
+ from fairseq import checkpoint_utils
11
+
12
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
13
+ ["content-vec-best-legacy-500.pt"], suffix=""
14
+ )
15
+ model = models[0]
16
+ model.eval()
17
+ model.eval()
18
+
19
+
20
+ class HubertModelWithFinalProj(HubertModel):
21
+ def __init__(self, config):
22
+ super().__init__(config)
23
+
24
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
25
+
26
+
27
+ # Default Config
28
+ hubert = HubertModelWithFinalProj(HubertConfig())
29
+
30
+ # huggingface: fairseq
31
+ mapping = {
32
+ "masked_spec_embed": "mask_emb",
33
+ "encoder.layer_norm.bias": "encoder.layer_norm.bias",
34
+ "encoder.layer_norm.weight": "encoder.layer_norm.weight",
35
+ "encoder.pos_conv_embed.conv.bias": "encoder.pos_conv.0.bias",
36
+ "encoder.pos_conv_embed.conv.weight_g": "encoder.pos_conv.0.weight_g",
37
+ "encoder.pos_conv_embed.conv.weight_v": "encoder.pos_conv.0.weight_v",
38
+ "feature_projection.layer_norm.bias": "layer_norm.bias",
39
+ "feature_projection.layer_norm.weight": "layer_norm.weight",
40
+ "feature_projection.projection.bias": "post_extract_proj.bias",
41
+ "feature_projection.projection.weight": "post_extract_proj.weight",
42
+ "final_proj.bias": "final_proj.bias",
43
+ "final_proj.weight": "final_proj.weight",
44
+ }
45
+
46
+ # Convert encoder
47
+ for layer in range(12):
48
+ for j in ["q", "k", "v"]:
49
+ mapping[
50
+ f"encoder.layers.{layer}.attention.{j}_proj.weight"
51
+ ] = f"encoder.layers.{layer}.self_attn.{j}_proj.weight"
52
+ mapping[
53
+ f"encoder.layers.{layer}.attention.{j}_proj.bias"
54
+ ] = f"encoder.layers.{layer}.self_attn.{j}_proj.bias"
55
+
56
+ mapping[
57
+ f"encoder.layers.{layer}.final_layer_norm.bias"
58
+ ] = f"encoder.layers.{layer}.final_layer_norm.bias"
59
+ mapping[
60
+ f"encoder.layers.{layer}.final_layer_norm.weight"
61
+ ] = f"encoder.layers.{layer}.final_layer_norm.weight"
62
+
63
+ mapping[
64
+ f"encoder.layers.{layer}.layer_norm.bias"
65
+ ] = f"encoder.layers.{layer}.self_attn_layer_norm.bias"
66
+ mapping[
67
+ f"encoder.layers.{layer}.layer_norm.weight"
68
+ ] = f"encoder.layers.{layer}.self_attn_layer_norm.weight"
69
+
70
+ mapping[
71
+ f"encoder.layers.{layer}.attention.out_proj.bias"
72
+ ] = f"encoder.layers.{layer}.self_attn.out_proj.bias"
73
+ mapping[
74
+ f"encoder.layers.{layer}.attention.out_proj.weight"
75
+ ] = f"encoder.layers.{layer}.self_attn.out_proj.weight"
76
+
77
+ mapping[
78
+ f"encoder.layers.{layer}.feed_forward.intermediate_dense.bias"
79
+ ] = f"encoder.layers.{layer}.fc1.bias"
80
+ mapping[
81
+ f"encoder.layers.{layer}.feed_forward.intermediate_dense.weight"
82
+ ] = f"encoder.layers.{layer}.fc1.weight"
83
+
84
+ mapping[
85
+ f"encoder.layers.{layer}.feed_forward.output_dense.bias"
86
+ ] = f"encoder.layers.{layer}.fc2.bias"
87
+ mapping[
88
+ f"encoder.layers.{layer}.feed_forward.output_dense.weight"
89
+ ] = f"encoder.layers.{layer}.fc2.weight"
90
+
91
+ # Convert Conv Layers
92
+ for layer in range(7):
93
+ mapping[
94
+ f"feature_extractor.conv_layers.{layer}.conv.weight"
95
+ ] = f"feature_extractor.conv_layers.{layer}.0.weight"
96
+
97
+ if layer != 0:
98
+ continue
99
+
100
+ mapping[
101
+ f"feature_extractor.conv_layers.{layer}.layer_norm.weight"
102
+ ] = f"feature_extractor.conv_layers.{layer}.2.weight"
103
+ mapping[
104
+ f"feature_extractor.conv_layers.{layer}.layer_norm.bias"
105
+ ] = f"feature_extractor.conv_layers.{layer}.2.bias"
106
+
107
+ hf_keys = set(hubert.state_dict().keys())
108
+ fair_keys = set(model.state_dict().keys())
109
+
110
+ hf_keys -= set(mapping.keys())
111
+ fair_keys -= set(mapping.values())
112
+
113
+ for i, j in zip(sorted(hf_keys), sorted(fair_keys)):
114
+ print(i, j)
115
+
116
+ print(hf_keys, fair_keys)
117
+ print(len(hf_keys), len(fair_keys))
118
+
119
+ # try loading the weights
120
+ new_state_dict = {}
121
+ for k, v in mapping.items():
122
+ new_state_dict[k] = model.state_dict()[v]
123
+
124
+ x = hubert.load_state_dict(new_state_dict, strict=False)
125
+ print(x)
126
+ hubert.eval()
127
+
128
+ with torch.no_grad():
129
+ new_input = torch.randn(1, 16384)
130
+
131
+ result1 = hubert(new_input, output_hidden_states=True)["hidden_states"][9]
132
+ result1 = hubert.final_proj(result1)
133
+
134
+ result2 = model.extract_features(
135
+ **{
136
+ "source": new_input,
137
+ "padding_mask": torch.zeros(1, 16384, dtype=torch.bool),
138
+ # "features_only": True,
139
+ "output_layer": 9,
140
+ }
141
+ )[0]
142
+ result2 = model.final_proj(result2)
143
+
144
+ assert torch.allclose(result1, result2, atol=1e-3)
145
+
146
+ print("Sanity check passed")
147
+
148
+ # Save huggingface model
149
+ hubert.save_pretrained(".")
150
+ print("Saved model")
pretrained/download.py ADDED
@@ -0,0 +1,12 @@
1
+ from huggingface_hub import snapshot_download
2
+
3
+
4
+ if __name__ == "__main__":
5
+ model_path = snapshot_download(
6
+ repo_id="Pur1zumu/RIFT-SVC-modules",
7
+ local_dir='pretrained',
8
+ local_dir_use_symlinks=False, # Don't use symlinks
9
+ local_files_only=False, # Allow downloading new files
10
+ ignore_patterns=["*.git*"], # Ignore git-related files
11
+ resume_download=True # Resume interrupted downloads
12
+ )
pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/NOTICE.txt ADDED
@@ -0,0 +1,87 @@
1
+ --- DiffSinger Community Vocoder ---
2
+
3
+ ARCHITECTURE: NSF-HiFiGAN
4
+ RELEASE DATE: 2024-02-19
5
+
6
+ HYPER PARAMETERS:
7
+ - 44100 sample rate
8
+ - 128 mel bins
9
+ - 512 hop size
10
+ - 2048 window size
11
+ - fmin at 40Hz
12
+ - fmax at 16000Hz
13
+
14
+
15
+ NOTICE:
16
+
17
+ All model weights in the [DiffSinger Community Vocoder Project](https://openvpi.github.io/vocoders/), including
18
+ model weights in this directory, are provided by the [OpenVPI Team](https://github.com/openvpi/), under the
19
+ [Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.
20
+
21
+
22
+ ACKNOWLEDGEMENTS:
23
+
24
+ Training data of this vocoder is provided and permitted by the following organizations, societies and individuals:
25
+
26
+ 孙飒 https://www.qfssr.cn
27
+ 赤松_Akamatsu https://www.zhibin.club
28
+ 乐威 https://www.zhibin.club
29
+ 伯添 https://space.bilibili.com/24087011
30
+ 雲宇光 https://space.bilibili.com/660675050
31
+ 橙子言 https://space.bilibili.com/318486464
32
+ 人衣大人 https://space.bilibili.com/2270344
33
+ 玖蝶 https://space.bilibili.com/676771003
34
+ Yuuko
35
+ 白夜零BYL https://space.bilibili.com/1605040503
36
+ 嗷天 https://space.bilibili.com/5675252
37
+ 洛泠羽 https://space.bilibili.com/347373318
38
+ 灰条纹的灰猫君 https://space.bilibili.com/2083633
39
+ 幽寂 https://space.bilibili.com/478860
40
+ 恶魔王女 https://space.bilibili.com/2475098
41
+ AlexYHX 芮晴
42
+ 绮萱 https://y.qq.com/n/ryqq/singer/003HjD6H4aZn1K
43
+ 诗芸 https://y.qq.com/n/ryqq/singer/0005NInj142zm0
44
+ 汐蕾 https://y.qq.com/n/ryqq/singer/0023cWMH1Bq1PJ
45
+ 1262917464
46
+ 炜阳
47
+ 叶卡yolka
48
+ 幸の夏 https://space.bilibili.com/1017297686
49
+ 暮色未量 https://space.bilibili.com/272904686
50
+ 晓寞sama https://space.bilibili.com/3463394
51
+ 没头绪的节操君
52
+ 串串BunC https://space.bilibili.com/95817834
53
+ 落雨 https://space.bilibili.com/1292427
54
+ 长尾巴的翎艾 https://space.bilibili.com/1638666
55
+ 声闻计划 https://space.bilibili.com/392812269
56
+ 唐家大小姐 http://5sing.kugou.com/palmusic/default.html
57
+ 不伊子
58
+ 芸青岩 https://space.bilibili.com/35236775
59
+ 妖橙 https://space.bilibili.com/161975631
60
+ 双桨 https://space.bilibili.com/13245483
61
+ 灵滅 https://space.bilibili.com/276988145
62
+ AlexYHX https://space.bilibili.com/13303439
63
+ 祁唱 https://space.bilibili.com/11256670
64
+ 早稻叽 https://space.bilibili.com/1950658
65
+
66
+ The following public datasets are used:
67
+
68
+ Opencpop https://wenet.org.cn/opencpop/
69
+ CCMUSIC https://ccmusic-database.github.io/index.html
70
+ SingingVoiceDataset http://isophonics.net/SingingVoiceDataset
71
+
72
+ Training machines are provided by:
73
+
74
+ 花儿不哭 https://space.bilibili.com/5760446
75
+
76
+
77
+ TERMS OF REDISTRIBUTIONS:
78
+
79
+ 1. Do not sell this vocoder, or charge any fees from redistributing it, as prohibited by
80
+ the license.
81
+ 2. Include a copy of the CC BY-NC-SA 4.0 license, or a link referring to it.
82
+ 3. Include a copy of this notice, or any other notices informing that this vocoder is
83
+ provided by the OpenVPI Team, that this vocoder is licensed under CC BY-NC-SA 4.0, and
84
+ with a complete acknowledgement list as shown above.
85
+ 4. If you fine-tuned or modified the weights, leave a notice about what has been changed.
86
+ 5. (Optional) Leave a link to the official release page of the vocoder, and tell users
87
+ that other versions and future updates of this vocoder can be obtained from the website.
pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/NOTICE.zh-CN.txt ADDED
@@ -0,0 +1,85 @@
1
+ --- DiffSinger 社区声码器 ---
2
+
3
+ 架构:NSF-HiFiGAN
4
+ 发布日期:2024-02-19
5
+
6
+ 超参数:
7
+ - 44100 sample rate
8
+ - 128 mel bins
9
+ - 512 hop size
10
+ - 2048 window size
11
+ - fmin at 40Hz
12
+ - fmax at 16000Hz
13
+
14
+
15
+ 注意事项:
16
+
17
+ [DiffSinger 社区声码器企划](https://openvpi.github.io/vocoders/) 中的所有模型权重,
18
+ 包括此目录下的模型权重,均由 [OpenVPI Team](https://github.com/openvpi/) 提供,并基于
19
+ [Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/)
20
+ 进行许可。
21
+
22
+
23
+ 致谢:
24
+
25
+ 此声码器的训练数据由以下组织、社团和个人提供并许可:
26
+
27
+ 孙飒 https://www.qfssr.cn
28
+ 赤松_Akamatsu https://www.zhibin.club
29
+ 乐威 https://www.zhibin.club
30
+ 伯添 https://space.bilibili.com/24087011
31
+ 雲宇光 https://space.bilibili.com/660675050
32
+ 橙子言 https://space.bilibili.com/318486464
33
+ 人衣大人 https://space.bilibili.com/2270344
34
+ 玖蝶 https://space.bilibili.com/676771003
35
+ Yuuko
36
+ 白夜零BYL https://space.bilibili.com/1605040503
37
+ 嗷天 https://space.bilibili.com/5675252
38
+ 洛泠羽 https://space.bilibili.com/347373318
39
+ 灰条纹的灰猫君 https://space.bilibili.com/2083633
40
+ 幽寂 https://space.bilibili.com/478860
41
+ 恶魔王女 https://space.bilibili.com/2475098
42
+ 芮晴
43
+ 绮萱 https://y.qq.com/n/ryqq/singer/003HjD6H4aZn1K
44
+ 诗芸 https://y.qq.com/n/ryqq/singer/0005NInj142zm0
45
+ 汐蕾 https://y.qq.com/n/ryqq/singer/0023cWMH1Bq1PJ
46
+ 1262917464
47
+ 炜阳
48
+ 叶卡yolka
49
+ 幸の夏 https://space.bilibili.com/1017297686
50
+ 暮色未量 https://space.bilibili.com/272904686
51
+ 晓寞sama https://space.bilibili.com/3463394
52
+ 没头绪的节操君
53
+ 串串BunC https://space.bilibili.com/95817834
54
+ 落雨 https://space.bilibili.com/1292427
55
+ 长尾巴的翎艾 https://space.bilibili.com/1638666
56
+ 声闻计划 https://space.bilibili.com/392812269
57
+ 唐家大小姐 http://5sing.kugou.com/palmusic/default.html
58
+ 不伊子
59
+ 芸青岩 https://space.bilibili.com/35236775
60
+ 妖橙 https://space.bilibili.com/161975631
61
+ 双桨 https://space.bilibili.com/13245483
62
+ 灵滅 https://space.bilibili.com/276988145
63
+ AlexYHX https://space.bilibili.com/13303439
64
+ 祁唱 https://space.bilibili.com/11256670
65
+ 早稻叽 https://space.bilibili.com/1950658
66
+
67
+ 使用了以下公开数据集:
68
+
69
+ Opencpop https://wenet.org.cn/opencpop/
70
+ CCMUSIC https://ccmusic-database.github.io/index.html
71
+ SingingVoiceDataset http://isophonics.net/SingingVoiceDataset
72
+
73
+ 训练算力的提供者如下:
74
+
75
+ 花儿不哭 https://space.bilibili.com/5760446
76
+
77
+
78
+ 二次分发条款:
79
+
80
+ 1. 请勿售卖此声码器或从其二次分发过程中收取任何费用,因为此类行为受到许可证的禁止。
81
+ 2. 请在二次分发文件中包含一份 CC BY-NC-SA 4.0 许可证的副本或指向该许可证的链接。
82
+ 3. 请在二次分发文件中包含这份声明,或以其他形式声明此声码器由 OpenVPI Team 提供并基于 CC BY-NC-SA 4.0 许可,
83
+ 并附带上述完整的致谢名单。
84
+ 4. 如果您微调或修改了权重,请留下一份关于其受到了何种修改的说明。
85
+ 5.(可选)留下一份指向此声码器的官方发布页面的链接,并告知使用者可从该网站获取此声码器的其他版本和未来的更新。
pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/config.json ADDED
@@ -0,0 +1,56 @@
1
+ {
2
+ "discriminator_periods": [
3
+ 3,
4
+ 5,
5
+ 7,
6
+ 11,
7
+ 17,
8
+ 23,
9
+ 37
10
+ ],
11
+ "resblock": "1",
12
+ "resblock_dilation_sizes": [
13
+ [
14
+ 1,
15
+ 3,
16
+ 5
17
+ ],
18
+ [
19
+ 1,
20
+ 3,
21
+ 5
22
+ ],
23
+ [
24
+ 1,
25
+ 3,
26
+ 5
27
+ ]
28
+ ],
29
+ "resblock_kernel_sizes": [
30
+ 3,
31
+ 7,
32
+ 11
33
+ ],
34
+ "upsample_initial_channel": 512,
35
+ "upsample_kernel_sizes": [
36
+ 16,
37
+ 16,
38
+ 4,
39
+ 4,
40
+ 4
41
+ ],
42
+ "upsample_rates": [
43
+ 8,
44
+ 8,
45
+ 2,
46
+ 2,
47
+ 2
48
+ ],
49
+ "sampling_rate": 44100,
50
+ "num_mels": 128,
51
+ "hop_size": 512,
52
+ "n_fft": 2048,
53
+ "win_size": 2048,
54
+ "fmin": 40,
55
+ "fmax": 16000
56
+ }
pretrained/rmvpe/.gitkeep ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,28 @@
1
+ click
2
+ einops
3
+ gradio
4
+ huggingface_hub
5
+ hydra-core
6
+ jaxtyping
7
+ librosa
8
+ matplotlib
9
+ numpy
10
+ omegaconf
11
+ Pillow
12
+ praat-parselmouth
13
+ pyloudnorm
14
+ PyYAML
15
+ pytorch_lightning
16
+ resampy
17
+ schedulefree
18
+ scipy
19
+ soundfile
20
+ tensorboard
21
+ thop
22
+ torch
23
+ torchaudio
24
+ torchdiffeq
25
+ tqdm
26
+ transformers
27
+ wandb
28
+ x_transformers
rift_svc/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from rift_svc.rf import RF
2
+ from rift_svc.dit import DiT
3
+ from rift_svc.lightning_module import RIFTSVCLightningModule
rift_svc/dataset.py ADDED
@@ -0,0 +1,139 @@
1
+ import json
2
+ import os
3
+ import random
4
+ from functools import partial
5
+ from typing import Literal
6
+ import torch
7
+ from torch.nn.utils.rnn import pad_sequence
8
+ from torch.utils.data import Dataset
9
+
10
+ from rift_svc.utils import linear_interpolate_tensor, nearest_interpolate_tensor
11
+
12
+ pt_load = partial(torch.load, weights_only=True, map_location='cpu', mmap=True)
13
+
14
+
15
+ class SVCDataset(Dataset):
16
+ def __init__(
17
+ self,
18
+ data_dir: str,
19
+ meta_info_path: str,
20
+ max_frame_len = 256,
21
+ split = "train",
22
+ use_cvec_downsampled: bool = False,
23
+ cvec_downsample_rate: int = 2,
24
+ ):
25
+ self.data_dir = data_dir
26
+ self.max_frame_len = max_frame_len
27
+
28
+ with open(meta_info_path, 'r', encoding='utf-8') as f:
29
+ meta = json.load(f)
30
+
31
+ speakers = meta["speakers"]
32
+ self.num_speakers = len(speakers)
33
+ self.spk2idx = {spk: idx for idx, spk in enumerate(speakers)}
34
+ self.split = split
35
+ self.samples = meta[f"{split}_audios"]
36
+ self.use_cvec_downsampled = use_cvec_downsampled
37
+ self.cvec_downsample_rate = cvec_downsample_rate
38
+
39
+ def get_frame_len(self, index):
40
+ return self.samples[index]['frame_len']
41
+
42
+ def __len__(self):
43
+ return len(self.samples)
44
+
45
+ def __getitem__(self, index):
46
+
47
+ sample = self.samples[index]
48
+ spk = sample['speaker']
49
+ path = os.path.join(self.data_dir, spk, sample['file_name'])
50
+ spk_id = torch.LongTensor([self.spk2idx[spk]]) # [1]
51
+
52
+ mel = pt_load(path + ".mel.pt").squeeze(0).T
53
+ rms = pt_load(path + ".rms.pt").squeeze(0)
54
+ f0 = pt_load(path + ".f0.pt").squeeze(0)
55
+ cvec = pt_load(path + ".cvec.pt").squeeze(0)
56
+
57
+ cvec = linear_interpolate_tensor(cvec, mel.shape[0])
58
+ if self.use_cvec_downsampled:
59
+ cvec_ds = cvec[::2, :]
60
+ cvec_ds = linear_interpolate_tensor(cvec_ds, cvec_ds.shape[0]//self.cvec_downsample_rate)
61
+ cvec_ds = linear_interpolate_tensor(cvec_ds, mel.shape[0])
62
+
63
+ frame_len = mel.shape[0]
64
+
65
+ if frame_len > self.max_frame_len:
66
+ if self.split == "train":
67
+ # Keep trying until we find a good segment or hit max attempts
68
+ max_attempts = 10
69
+ attempt = 0
70
+ while attempt < max_attempts:
71
+ start = random.randint(0, frame_len - self.max_frame_len)
72
+ end = start + self.max_frame_len
73
+ f0_segment = f0[start:end]
74
+ # Check if more than 90% of f0 values are 0
75
+ zero_ratio = (f0_segment == 0).float().mean().item()
76
+ if zero_ratio < 0.9: # Found a good segment
77
+ break
78
+ attempt += 1
79
+ else:
80
+ start = 0
81
+ end = start + self.max_frame_len
82
+ mel = mel[start:end]
83
+ rms = rms[start:end]
84
+ f0 = f0[start:end]
85
+ cvec = cvec[start:end]
86
+ if self.use_cvec_downsampled:
87
+ cvec_ds = cvec_ds[start:end]
88
+ frame_len = self.max_frame_len
89
+
90
+ result = dict(
91
+ spk_id = spk_id,
92
+ mel = mel,
93
+ rms = rms,
94
+ f0 = f0,
95
+ cvec = cvec,
96
+ frame_len = frame_len
97
+ )
98
+
99
+ if self.use_cvec_downsampled:
100
+ result['cvec_ds'] = cvec_ds
101
+
102
+ return result
103
+
104
+
105
+ def collate_fn(batch):
106
+ spk_ids = [item['spk_id'] for item in batch]
107
+ mels = [item['mel'] for item in batch]
108
+ rmss = [item['rms'] for item in batch]
109
+ f0s = [item['f0'] for item in batch]
110
+ cvecs = [item['cvec'] for item in batch]
111
+ if 'cvec_ds' in batch[0]:
112
+ cvecs_ds = [item['cvec_ds'] for item in batch]
113
+
114
+ frame_lens = [item['frame_len'] for item in batch]
115
+
116
+ # Pad sequences to max length
117
+ mels_padded = pad_sequence(mels, batch_first=True)
118
+ rmss_padded = pad_sequence(rmss, batch_first=True)
119
+ f0s_padded = pad_sequence(f0s, batch_first=True)
120
+ cvecs_padded = pad_sequence(cvecs, batch_first=True)
121
+ if 'cvec_ds' in batch[0]:
122
+ cvecs_ds_padded = pad_sequence(cvecs_ds, batch_first=True)
123
+
124
+ spk_ids = torch.cat(spk_ids)
125
+ frame_len = torch.tensor(frame_lens)
126
+
127
+ result = {
128
+ 'spk_id': spk_ids,
129
+ 'mel': mels_padded,
130
+ 'rms': rmss_padded,
131
+ 'f0': f0s_padded,
132
+ 'cvec': cvecs_padded,
133
+ 'frame_len': frame_len
134
+ }
135
+
136
+ if 'cvec_ds' in batch[0]:
137
+ result['cvec_ds'] = cvecs_ds_padded
138
+
139
+ return result
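A minimal sketch of wiring SVCDataset and collate_fn into a DataLoader; the data_dir and meta_info_path below are placeholders, not paths shipped with this Space:

# Usage sketch only; assumes preprocessed .mel/.rms/.f0/.cvec tensors exist under data_dir.
from torch.utils.data import DataLoader
from rift_svc.dataset import SVCDataset, collate_fn

train_set = SVCDataset(
    data_dir="data",                       # placeholder
    meta_info_path="data/meta_info.json",  # placeholder
    max_frame_len=256,
    split="train",
)
loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=collate_fn)

batch = next(iter(loader))
# Keys follow collate_fn above:
# batch['mel'] -> (B, T, n_mels), batch['f0'] / batch['rms'] -> (B, T),
# batch['spk_id'] -> (B,), batch['frame_len'] -> (B,)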
rift_svc/dit.py ADDED
@@ -0,0 +1,227 @@
1
+ import math
2
+ from typing import Union, List
3
+
4
+ from einops import repeat
5
+ from jaxtyping import Bool, Float, Int
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ from x_transformers.x_transformers import RotaryEmbedding
10
+
11
+ from rift_svc.modules import (
12
+ AdaLayerNormZero_Final,
13
+ DiTBlock,
14
+ TimestepEmbedding,
15
+ LoRALinear,
16
+ )
17
+
18
+ # Conditional embedding for f0, rms, cvec
19
+ class CondEmbedding(nn.Module):
20
+ def __init__(self, cvec_dim: int, cond_dim: int):
21
+ super().__init__()
22
+ self.cvec_dim = cvec_dim
23
+ self.cond_dim = cond_dim
24
+
25
+ self.f0_embed = nn.Linear(1, cond_dim)
26
+ self.rms_embed = nn.Linear(1, cond_dim)
27
+ self.cvec_embed = nn.Linear(cvec_dim, cond_dim)
28
+ self.out = nn.Linear(cond_dim, cond_dim)
29
+
30
+ self.ln_cvec = nn.LayerNorm(cond_dim, elementwise_affine=False, eps=1e-6)
31
+ self.ln = nn.LayerNorm(cond_dim, elementwise_affine=True, eps=1e-6)
32
+
33
+
34
+ def forward(
35
+ self,
36
+ f0: Float[torch.Tensor, "b n"],
37
+ rms: Float[torch.Tensor, "b n"],
38
+ cvec: Float[torch.Tensor, "b n d"],
39
+ ):
40
+ if f0.ndim == 2:
41
+ f0 = f0.unsqueeze(-1)
42
+ if rms.ndim == 2:
43
+ rms = rms.unsqueeze(-1)
44
+
45
+ f0_embed = self.f0_embed(f0 / 1200)
46
+ rms_embed = self.rms_embed(rms)
47
+ cvec_embed = self.ln_cvec(self.cvec_embed(cvec))
48
+
49
+ cond = f0_embed + rms_embed + cvec_embed
50
+ cond = self.ln(self.out(cond))
51
+ return cond
52
+
53
+
54
+ # noised input audio and context mixing embedding
55
+ class InputEmbedding(nn.Module):
56
+ def __init__(self, mel_dim: int, out_dim: int):
57
+ super().__init__()
58
+ self.mel_embed = nn.Linear(mel_dim, out_dim)
59
+ self.proj = nn.Linear(2 * out_dim, out_dim)
60
+ self.ln = nn.LayerNorm(out_dim, elementwise_affine=False, eps=1e-6)
61
+
62
+ def forward(self, x: Float[torch.Tensor, "b n d1"], cond_embed: Float[torch.Tensor, "b n d2"]):
63
+ x = self.mel_embed(x)
64
+ x = torch.cat((x, cond_embed), dim = -1)
65
+ x = self.proj(x)
66
+ x = self.ln(x)
67
+ return x
68
+
69
+
70
+ # backbone using DiT blocks
71
+ class DiT(nn.Module):
72
+ def __init__(self,
73
+ dim: int, depth: int, head_dim: int = 64, dropout: float = 0.0, ff_mult: int = 4,
74
+ n_mel_channels: int = 128, num_speaker: int = 1, cvec_dim: int = 768,
75
+ kernel_size: int = 31, zero_null_spk: bool = False,
76
+ init_std: float = 1):
77
+ super().__init__()
78
+
79
+ self.num_speaker = num_speaker
80
+ self.spk_embed = nn.Embedding(num_speaker, dim)
81
+ self.null_spk_embed = nn.Embedding(1, dim)
82
+ self.tembed = TimestepEmbedding(dim)
83
+ self.cond_embed = CondEmbedding(cvec_dim, dim)
84
+ self.input_embed = InputEmbedding(n_mel_channels, dim)
85
+
86
+ self.rotary_embed = RotaryEmbedding(head_dim)
87
+
88
+ self.dim = dim
89
+ self.depth = depth
90
+ self.transformer_blocks = nn.ModuleList(
91
+ [
92
+ DiTBlock(
93
+ dim = dim,
94
+ head_dim = head_dim,
95
+ ff_mult = ff_mult,
96
+ dropout = dropout,
97
+ kernel_size = kernel_size,
98
+ )
99
+ for _ in range(depth)
100
+ ]
101
+ )
102
+
103
+ self.norm_out = AdaLayerNormZero_Final(dim)
104
+ self.output = nn.Linear(dim, n_mel_channels)
105
+
106
+ self.init_std = init_std
107
+ self.apply(self._init_weights)
108
+ for block in self.transformer_blocks:
109
+ torch.nn.init.constant_(block.attn_norm.proj.weight, 0)
110
+ torch.nn.init.constant_(block.attn_norm.proj.bias, 0)
111
+
112
+ torch.nn.init.constant_(self.norm_out.proj.weight, 0)
113
+ torch.nn.init.constant_(self.norm_out.proj.bias, 0)
114
+ torch.nn.init.constant_(self.output.weight, 0)
115
+ torch.nn.init.constant_(self.output.bias, 0)
116
+
117
+ if zero_null_spk:
118
+ self.null_spk_embed.weight.data.zero_()
119
+ self.null_spk_embed.requires_grad = False
120
+
121
+ def _init_weights(self, module: nn.Module):
122
+ if isinstance(module, nn.Linear):
123
+ fan_out, fan_in = module.weight.shape
124
+ # Spectral parameterization from the [paper](https://arxiv.org/abs/2310.17813).
125
+ init_std = (self.init_std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in))
126
+ torch.nn.init.normal_(module.weight, mean=0.0, std=init_std)
127
+ if module.bias is not None:
128
+ torch.nn.init.zeros_(module.bias)
129
+ elif isinstance(module, nn.Conv1d):
130
+ # weight shape: (out_channels, in_channels/groups, kernel_size)
131
+ fan_out = module.weight.shape[0] # out_channels
132
+ fan_in = module.weight.shape[1] * module.weight.shape[2] # (in_channels/groups) * kernel_size
133
+ init_std = (self.init_std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in))
134
+ torch.nn.init.normal_(module.weight, mean=0.0, std=init_std)
135
+ if module.bias is not None:
136
+ torch.nn.init.zeros_(module.bias)
137
+ elif isinstance(module, nn.Embedding):
138
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.init_std/math.sqrt(self.dim))
139
+
140
+
141
+ def forward(
142
+ self,
143
+ x: Float[torch.Tensor, "b n d1"], # noised input mel
144
+ spk: Int[torch.Tensor, "b"], # speaker
145
+ f0: Float[torch.Tensor, "b n"],
146
+ rms: Float[torch.Tensor, "b n"],
147
+ cvec: Float[torch.Tensor, "b n d2"],
148
+ time: Float[torch.Tensor, "b"], # time step
149
+ drop_speaker: Union[bool, Bool[torch.Tensor, "b"]] = False,
150
+ mask: Bool[torch.Tensor, "b n"] | None = None,
151
+ skip_layers: Union[int, List[int], None] = None,
152
+ ):
153
+ batch, seq_len = x.shape[0], x.shape[1]
154
+ if time.ndim == 0:
155
+ time = repeat(time, ' -> b', b = batch)
156
+
157
+ if isinstance(drop_speaker, bool):
158
+ drop_speaker = torch.full((batch,), drop_speaker, dtype=torch.bool, device=x.device)
159
+
160
+ spk_embeds = self.spk_embed(spk)
161
+ null_spk_embeds = self.null_spk_embed(torch.zeros_like(spk, dtype=torch.long))
162
+ spk_embeds = torch.where(drop_speaker.unsqueeze(-1), null_spk_embeds, spk_embeds)
163
+
164
+ t = self.tembed(time)
165
+ t = t + spk_embeds
166
+
167
+ cond_embed = self.cond_embed(f0, rms, cvec)
168
+ x = self.input_embed(x, cond_embed)
169
+
170
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
171
+
172
+ if skip_layers is not None:
173
+ if isinstance(skip_layers, int):
174
+ skip_layers = [skip_layers]
175
+
176
+ for i, block in enumerate(self.transformer_blocks):
177
+ if skip_layers is not None and i in skip_layers:
178
+ continue
179
+ x = block(x, t, mask = mask, rope = rope)
180
+
181
+ x = self.norm_out(x, t)
182
+ output = self.output(x)
183
+
184
+ return output
185
+
186
+
187
+ def apply_lora(self, rank, alpha):
188
+ for n, p in self.named_parameters():
189
+ p.requires_grad = False
190
+ self.spk_embed.weight.requires_grad = True
191
+ # Apply LoRA to k_proj and v_proj in each attention block
192
+ for block in self.transformer_blocks:
193
+ block.attn.k_proj = LoRALinear(block.attn.k_proj, rank, alpha)
194
+ block.attn.v_proj = LoRALinear(block.attn.v_proj, rank, alpha)
195
+
196
+
197
+ def merge_lora(self):
198
+ # Iterate over each transformer block in the DiT backbone
199
+ for block in self.transformer_blocks:
200
+ # Merge for k_proj if it is a LoRALinear instance
201
+ if isinstance(block.attn.k_proj, LoRALinear):
202
+ with torch.no_grad():
203
+ # Compute delta update: B @ A^T
204
+ delta = block.attn.k_proj.B @ block.attn.k_proj.A.T
205
+ # The underlying linear layer has weight of shape (out_features, in_features)
206
+ # and its forward computes x * weight.T
207
+ # Note: delta.T equals A @ B^T, so merging works correctly:
208
+ block.attn.k_proj.linear.weight.add_(delta)
209
+ # Replace the LoRALinear module with the merged linear layer
210
+ block.attn.k_proj = block.attn.k_proj.linear
211
+
212
+ # Merge for v_proj in the same way
213
+ if isinstance(block.attn.v_proj, LoRALinear):
214
+ with torch.no_grad():
215
+ delta = block.attn.v_proj.B @ block.attn.v_proj.A.T
216
+ block.attn.v_proj.linear.weight.add_(delta)
217
+ block.attn.v_proj = block.attn.v_proj.linear
218
+
219
+
220
+ def freeze_adaln_and_tembed(self):
221
+ for p in self.tembed.parameters():
222
+ p.requires_grad = False
223
+ for p in self.norm_out.parameters():
224
+ p.requires_grad = False
225
+ for block in self.transformer_blocks:
226
+ for p in block.attn_norm.parameters():
227
+ p.requires_grad = False
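A minimal forward-pass sketch of the DiT backbone on dummy tensors; the sizes below are illustrative choices, not values read from a config in this diff:

# Illustration only; shapes follow the jaxtyping annotations in forward().
import torch
from rift_svc.dit import DiT

model = DiT(dim=768, depth=12, n_mel_channels=128, num_speaker=1, cvec_dim=768)

b, n = 2, 256
x = torch.randn(b, n, 128)       # noised mel
spk = torch.zeros(b, dtype=torch.long)
f0 = torch.rand(b, n) * 400      # Hz
rms = torch.rand(b, n)
cvec = torch.randn(b, n, 768)
t = torch.rand(b)                # flow time in [0, 1]

out = model(x, spk, f0, rms, cvec, t)
print(out.shape)                 # torch.Size([2, 256, 128])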
rift_svc/feature_extractors.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch
2
+ from torch import nn
3
+ from jaxtyping import Float
4
+ from librosa.filters import mel as librosa_mel_fn
5
+ from transformers import HubertModel
6
+
7
+
8
+ def dynamic_range_compression_torch(
9
+ x: Float[torch.Tensor, "n_mels mel_len"],
10
+ C: float = 1,
11
+ clip_val: float = 1e-5
12
+ ) -> Float[torch.Tensor, "n_mels mel_len"]:
13
+ return torch.log(torch.clamp(x, min=clip_val) * C)
14
+
15
+
16
+ def spectral_normalize_torch(
17
+ magnitudes: Float[torch.Tensor, "n_mels mel_len"]
18
+ ) -> Float[torch.Tensor, "n_mels mel_len"]:
19
+ return dynamic_range_compression_torch(magnitudes)
20
+
21
+
22
+ mel_basis_cache = {}
23
+ hann_window_cache = {}
24
+
25
+
26
+ def get_mel_spectrogram(
27
+ y: Float[torch.Tensor, "n"],
28
+ n_fft: int = 2048,
29
+ num_mels: int = 128,
30
+ sampling_rate: int = 44100,
31
+ hop_size: int = 512,
32
+ win_size: int = 2048,
33
+ fmin: int = 40,
34
+ fmax: int | None = 16000,
35
+ center: bool = False,
36
+ ) -> Float[torch.Tensor, "n_mels mel_len"]:
37
+ """
38
+ Calculate the mel spectrogram of an input signal.
39
+ This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
40
+
41
+ Args:
42
+ y (torch.Tensor): Input signal with shape (n,).
43
+ n_fft (int, optional): FFT size. Defaults to 1024.
44
+ num_mels (int, optional): Number of mel bins. Defaults to 128.
45
+ sampling_rate (int, optional): Sampling rate of the input signal. Defaults to 44100.
46
+ hop_size (int, optional): Hop size for STFT. Defaults to 256.
47
+ win_size (int, optional): Window size for STFT. Defaults to 1024.
48
+ fmin (int, optional): Minimum frequency for mel filterbank. Defaults to 0.
49
+ fmax (int | None, optional): Maximum frequency for mel filterbank. If None, defaults to sr/2.0. Defaults to None.
50
+ center (bool, optional): Whether to pad the input to center the frames. Defaults to False.
51
+
52
+ Returns:
53
+ torch.Tensor: Mel spectrogram with shape (n_mels, mel_len).
54
+ """
55
+ if torch.min(y) < -1.0:
56
+ print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
57
+ if torch.max(y) > 1.0:
58
+ print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
59
+
60
+ device = y.device
61
+ key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
62
+
63
+ if key not in mel_basis_cache:
64
+ mel = librosa_mel_fn(
65
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
66
+ )
67
+ mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
68
+ hann_window_cache[key] = torch.hann_window(win_size).to(device)
69
+
70
+ mel_basis = mel_basis_cache[key]
71
+ hann_window = hann_window_cache[key]
72
+
73
+ padding = (n_fft - hop_size) // 2
74
+ y = torch.nn.functional.pad(
75
+ y.unsqueeze(1), (padding, padding), mode="reflect"
76
+ ).squeeze(1)
77
+
78
+ spec = torch.stft(
79
+ y,
80
+ n_fft,
81
+ hop_length=hop_size,
82
+ win_length=win_size,
83
+ window=hann_window,
84
+ center=center,
85
+ pad_mode="reflect",
86
+ normalized=False,
87
+ onesided=True,
88
+ return_complex=True,
89
+ )
90
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
91
+
92
+ mel_spec = torch.matmul(mel_basis, spec)
93
+ mel_spec = spectral_normalize_torch(mel_spec)
94
+
95
+ return mel_spec
96
+
97
+
98
+ class RMSExtractor(nn.Module):
99
+ def __init__(self, hop_length=512, window_length=2048):
100
+ """
101
+ Initializes the RMSExtractor with the specified hop_length.
102
+
103
+ Args:
104
+ hop_length (int): Number of samples between successive frames.
105
+ """
106
+ super(RMSExtractor, self).__init__()
107
+ self.hop_length = hop_length
108
+ self.window_length = window_length
109
+
110
+ def forward(self, inp):
111
+ """
112
+ Extracts RMS energy from the input audio tensor.
113
+
114
+ Args:
115
+ inp (Tensor): Audio tensor of shape (batch, samples).
116
+
117
+ Returns:
118
+ Tensor: RMS energy tensor of shape (batch, frames).
119
+ """
120
+ # Square the audio signal
121
+ audio_squared = inp ** 2
122
+
123
+ # Use the same padding as mel spectrogram
124
+ padding = (self.window_length - self.hop_length) // 2
125
+ audio_padded = torch.nn.functional.pad(
126
+ audio_squared, (padding, padding), mode='reflect'
127
+ )
128
+
129
+ # Unfold to create frames with window_length instead of hop_length
130
+ frames = audio_padded.unfold(1, self.window_length, self.hop_length) # Shape: (batch, frames, window_length)
131
+
132
+ # Compute mean energy per frame
133
+ mean_energy = frames.mean(dim=-1) # Shape: (batch, frames)
134
+
135
+ # Compute RMS by taking square root
136
+ rms = torch.sqrt(mean_energy) # Shape: (batch, frames)
137
+
138
+ return rms
139
+
140
+
141
+ class HubertModelWithFinalProj(HubertModel):
142
+ def __init__(self, config):
143
+ super().__init__(config)
144
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
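A small sketch of running the two feature extractors above on a dummy batch-first waveform:

# Sketch only; roughly one second of low-amplitude noise at 44.1 kHz.
import torch
from rift_svc.feature_extractors import get_mel_spectrogram, RMSExtractor

wav = torch.randn(1, 44100) * 0.01        # (batch, samples), within [-1, 1]
mel = get_mel_spectrogram(wav)            # (batch, 128, mel_len) with the defaults above
rms = RMSExtractor()(wav)                 # (batch, frames), frames == mel_len
print(mel.shape, rms.shape)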
rift_svc/lightning_module.py ADDED
@@ -0,0 +1,389 @@
1
+ import gc
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torchaudio
7
+ import wandb
8
+ from functools import partial
9
+ import inspect
10
+
11
+ from pytorch_lightning import LightningModule
12
+
13
+ from rift_svc.metrics import mcd, psnr, si_snr
14
+ from rift_svc.feature_extractors import get_mel_spectrogram
15
+ from rift_svc.nsf_hifigan import NsfHifiGAN
16
+ from rift_svc.utils import draw_mel_specs, l2_grad_norm
17
+
18
+
19
+ class RIFTSVCLightningModule(LightningModule):
20
+ def __init__(
21
+ self,
22
+ model,
23
+ optimizer,
24
+ cfg,
25
+ lr_scheduler=None,
26
+ ):
27
+ super().__init__()
28
+ self.model = model
29
+ self.optimizer = optimizer
30
+ self.lr_scheduler = lr_scheduler
31
+ self.cfg = cfg
32
+ self.eval_sample_steps = cfg['training']['eval_sample_steps']
33
+ self.model.sample = partial(
34
+ self.model.sample,
35
+ steps=self.eval_sample_steps,
36
+ )
37
+ self.log_media_per_steps = cfg['training']['log_media_per_steps']
38
+ self.drop_spk_prob = cfg['training']['drop_spk_prob']
39
+
40
+ self.vocoder = None
41
+ self.save_hyperparameters(ignore=['model', 'optimizer', 'vocoder'])
42
+
43
+ def configure_optimizers(self):
44
+ if self.lr_scheduler is None:
45
+ return self.optimizer
46
+ return {
47
+ "optimizer": self.optimizer,
48
+ "lr_scheduler": {
49
+ "scheduler": self.lr_scheduler,
50
+ "interval": "step",
51
+ }
52
+ }
53
+
54
+ def training_step(self, batch, batch_idx):
55
+ mel = batch['mel']
56
+ spk_id = batch['spk_id']
57
+ f0 = batch['f0']
58
+ rms = batch['rms']
59
+ cvec = batch['cvec']
60
+ frame_len = batch['frame_len']
61
+
62
+ drop_speaker = False
63
+ if self.drop_spk_prob > 0:
64
+ batch_size = spk_id.shape[0]
65
+ num_drop = int(batch_size * self.drop_spk_prob)
66
+ drop_speaker = torch.zeros(batch_size, dtype=torch.bool, device=spk_id.device)
67
+ drop_speaker[:num_drop] = True
68
+ # Randomly shuffle the drop mask
69
+ drop_speaker = drop_speaker[torch.randperm(batch_size)]
70
+
71
+ loss, _ = self.model(
72
+ mel,
73
+ spk_id=spk_id,
74
+ f0=f0,
75
+ rms=rms,
76
+ cvec=cvec,
77
+ drop_speaker=drop_speaker,
78
+ frame_len=frame_len,
79
+ )
80
+
81
+ # Log metrics - compatible with both loggers
82
+ self._log_scalar("train/loss", loss.item(), prog_bar=True)
83
+
84
+ return loss
85
+
86
+ def on_validation_start(self):
87
+ if hasattr(self.optimizer, 'eval'):
88
+ self.optimizer.eval()
89
+ if not self.trainer.is_global_zero:
90
+ return
91
+
92
+ if self.vocoder is None:
93
+ self.vocoder = NsfHifiGAN(
94
+ 'pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt').to(self.device)
95
+ else:
96
+ self.vocoder = self.vocoder.to(self.device)
97
+
98
+ self.mcd = []
99
+ self.si_snr = []
100
+ self.psnr = []
101
+ self.mse = []
102
+
103
+
104
+ def on_validation_end(self, log=True):
105
+ if hasattr(self.optimizer, 'train'):
106
+ self.optimizer.train()
107
+ if not self.trainer.is_global_zero:
108
+ return
109
+
110
+ if self.vocoder is not None:
111
+ self.vocoder = self.vocoder.cpu()
112
+ gc.collect()
113
+ torch.cuda.empty_cache()
114
+
115
+ metrics = {
116
+ 'val/mcd': np.mean(self.mcd),
117
+ 'val/si_snr': np.mean(self.si_snr),
118
+ 'val/psnr': np.mean(self.psnr),
119
+ 'val/mse': np.mean(self.mse)
120
+ }
121
+
122
+ if log:
123
+ # Log metrics - compatible with both loggers
124
+ for metric_name, metric_value in metrics.items():
125
+ self._log_scalar(metric_name, metric_value)
126
+
127
+
128
+ def validation_step(self, batch, batch_idx, log=True):
129
+ """
130
+ Process validation step and log metrics and media.
131
+
132
+ Args:
133
+ batch: Input batch
134
+ batch_idx: Batch index
135
+ log: Whether to log or not
136
+ """
137
+ # Skip if not the main process or logging is disabled
138
+ if not self.trainer.is_global_zero:
139
+ return
140
+
141
+ # Get step and interval info
142
+ global_step = self.global_step
143
+ log_media_every_n_steps = self.log_media_every_n_steps
144
+
145
+ # Extract input data
146
+ spk_id = batch['spk_id']
147
+ mel_gt = batch['mel']
148
+ rms = batch['rms']
149
+ f0 = batch['f0']
150
+ cvec = batch['cvec']
151
+ frame_len = batch['frame_len']
152
+ cvec_ds = batch.get('cvec_ds', None)
153
+
154
+ # Generate output
155
+ mel_gen, _ = self.model.sample(
156
+ src_mel=mel_gt,
157
+ spk_id=spk_id,
158
+ f0=f0,
159
+ rms=rms,
160
+ cvec=cvec,
161
+ frame_len=frame_len,
162
+ bad_cvec=cvec_ds,
163
+ )
164
+ mel_gen = mel_gen.float()
165
+ mel_gt = mel_gt.float()
166
+
167
+ # Process each sample in the batch
168
+ for i in range(mel_gen.shape[0]):
169
+ sample_idx = batch_idx * mel_gen.shape[0] + i
170
+
171
+ # Generate audio using vocoder
172
+ wav_gen = self.vocoder(mel_gen[i:i+1, :frame_len[i], :].transpose(1, 2), f0[i:i+1, :frame_len[i]])
173
+ wav_gt = self.vocoder(mel_gt[i:i+1, :frame_len[i], :].transpose(1, 2), f0[i:i+1, :frame_len[i]])
174
+ wav_gen = wav_gen.squeeze(0)
175
+ wav_gt = wav_gt.squeeze(0)
176
+
177
+ # Generate mel spectrograms
178
+ mel_gen_i = get_mel_spectrogram(wav_gen).transpose(1, 2)
179
+ mel_gt_i = get_mel_spectrogram(wav_gt).transpose(1, 2)
180
+
181
+ # Clip values to valid range
182
+ mel_min, mel_max = self.model.mel_min, self.model.mel_max
183
+ mel_gen_i = torch.clip(mel_gen_i, min=mel_min, max=mel_max)
184
+ mel_gt_i = torch.clip(mel_gt_i, min=mel_min, max=mel_max)
185
+
186
+ # Calculate metrics
187
+ self.mcd.append(mcd(mel_gen_i, mel_gt_i).cpu().item())
188
+ self.si_snr.append(si_snr(mel_gen_i, mel_gt_i).cpu().item())
189
+ self.psnr.append(psnr(mel_gen_i, mel_gt_i).cpu().item())
190
+ self.mse.append(F.mse_loss(mel_gen_i, mel_gt_i).cpu().item())
191
+
192
+ if log:
193
+ # Create cache directory if it doesn't exist
194
+ os.makedirs('.cache', exist_ok=True)
195
+
196
+ # Log generated audio at specified intervals
197
+ if global_step % log_media_every_n_steps == 0:
198
+ audio_path = f".cache/spk-{spk_id[i].item()}_{sample_idx}_gen.wav"
199
+ torchaudio.save(audio_path, wav_gen.cpu().to(torch.float32), 44100)
200
+ self._log_audio(self.logger, f"val-audio/spk-{spk_id[i].item()}_{sample_idx}-gen", audio_path, global_step)
201
+
202
+ # Log ground truth audio only at the first step
203
+ if global_step == 0:
204
+ gt_audio_path = f".cache/spk-{spk_id[i].item()}_{sample_idx}_gt.wav"
205
+ torchaudio.save(gt_audio_path, wav_gt.cpu().to(torch.float32), 44100)
206
+ self._log_audio(self.logger, f"val-audio/spk-{spk_id[i].item()}_{sample_idx}-gt", gt_audio_path, global_step)
207
+
208
+ # Log mel spectrograms at specified intervals
209
+ if global_step % log_media_every_n_steps == 0:
210
+ # Create mel spectrogram visualization
211
+ data_gt = mel_gt_i.squeeze().T.cpu().numpy()
212
+ data_gen = mel_gen_i.squeeze().T.cpu().numpy()
213
+ data_abs_diff = data_gen - data_gt
214
+ cache_path = f".cache/{sample_idx}_mel.jpg"
215
+ draw_mel_specs(data_gt, data_gen, data_abs_diff, cache_path)
216
+ self._log_image(self.logger, f"val-mel/{sample_idx}_mel", cache_path, global_step)
217
+
218
+ def on_test_start(self):
219
+ self.on_validation_start()
220
+
221
+ def on_test_end(self):
222
+ self.on_validation_end(log=False)
223
+
224
+ def test_step(self, batch, batch_idx):
225
+ self.validation_step(batch, batch_idx, log=False)
226
+
227
+ def on_before_optimizer_step(self, optimizer):
228
+ # Calculate gradient norm
229
+ norm = l2_grad_norm(self.model)
230
+
231
+ # Log gradient norm
232
+ self._log_scalar("train/grad_norm", norm)
233
+
234
+ @property
235
+ def global_step(self):
236
+ return self.trainer.global_step
237
+
238
+ @property
239
+ def log_media_every_n_steps(self):
240
+ if self.log_media_per_steps is not None:
241
+ return self.log_media_per_steps
242
+ if self.save_every_n_steps is None:
243
+ return self.trainer.val_check_interval
244
+ return self.save_every_n_steps
245
+
246
+ @property
247
+ def save_every_n_steps(self):
248
+ for callback in self.trainer.callbacks:
249
+ if hasattr(callback, '_every_n_train_steps'):
250
+ return callback._every_n_train_steps
251
+ return None
252
+
253
+ @property
254
+ def is_using_wandb(self):
255
+ """
256
+ Check if WandB logger is being used.
257
+
258
+ Returns:
259
+ bool: True if WandB logger is being used, False otherwise
260
+ """
261
+ from pytorch_lightning.loggers import WandbLogger
262
+ if isinstance(self.logger, WandbLogger):
263
+ return True
264
+ return False
265
+
266
+ @property
267
+ def is_using_tensorboard(self):
268
+ """
269
+ Check if TensorBoard logger is being used.
270
+
271
+ Returns:
272
+ bool: True if TensorBoard logger is being used, False otherwise
273
+ """
274
+ from pytorch_lightning.loggers import TensorBoardLogger
275
+ if isinstance(self.logger, TensorBoardLogger):
276
+ return True
277
+ return False
278
+
279
+ @property
280
+ def logger_type(self):
281
+ """
282
+ Get a string representation of the logger type.
283
+
284
+ Returns:
285
+ str: 'wandb', 'tensorboard', or 'unknown'
286
+ """
287
+ if self.is_using_wandb:
288
+ return 'wandb'
289
+ elif self.is_using_tensorboard:
290
+ return 'tensorboard'
291
+ else:
292
+ return 'unknown'
293
+
294
+ def state_dict(self, *args, **kwargs):
295
+ # Temporarily store vocoder
296
+ vocoder = self.vocoder
297
+ self.vocoder = None
298
+
299
+ # Get state dict without vocoder
300
+ state = super().state_dict(*args, **kwargs)
301
+
302
+ # Restore vocoder
303
+ self.vocoder = vocoder
304
+ return state
305
+
306
+ # Add helper methods for logging with different logger types
307
+ def _log_scalar(self, name, value, step=None, **kwargs):
308
+ """
309
+ Log a scalar value to the appropriate logger.
310
+
311
+ Args:
312
+ name: Name of the metric
313
+ value: Value of the metric
314
+ step: Step value (defaults to current global step if None)
315
+ **kwargs: Additional arguments to pass to the logger
316
+ """
317
+ if step is None:
318
+ step = self.global_step
319
+
320
+ # Special handling for on_validation_end or on_test_end
321
+ # Get the caller function name to determine if we're in on_validation_end
322
+ caller_frame = inspect.currentframe().f_back
323
+ caller_function = caller_frame.f_code.co_name
324
+
325
+ if caller_function in ['on_validation_end', 'on_test_end']:
326
+ # Use logger.experiment directly as self.log() is not allowed in these hooks
327
+ if self.is_using_wandb:
328
+ self.logger.experiment.log({name: value}, step=step)
329
+ elif self.is_using_tensorboard:
330
+ self.logger.experiment.add_scalar(name, value, step)
331
+ # Add other logger types here if needed
332
+ else:
333
+ # Use PyTorch Lightning's built-in logging system for scalars
334
+ # This handles different logger types automatically
335
+ self.log(name, value, **kwargs)
336
+
337
+ def _log_audio(self, logger, name, file_path, step):
338
+ """
339
+ Log audio to the appropriate logger.
340
+
341
+ Args:
342
+ logger: The logger instance
343
+ name: Name of the audio
344
+ file_path: Path to the audio file
345
+ step: Step value
346
+ """
347
+ try:
348
+ if hasattr(logger, 'experiment') and hasattr(logger.experiment, 'log'):
349
+ # WandbLogger
350
+ import wandb
351
+ logger.experiment.log({
352
+ name: wandb.Audio(file_path, sample_rate=44100)
353
+ }, step=step)
354
+ elif hasattr(logger, 'experiment') and hasattr(logger.experiment, 'add_audio'):
355
+ # TensorBoardLogger
356
+ import soundfile as sf
357
+ audio, sample_rate = sf.read(file_path)
358
+ logger.experiment.add_audio(name, audio, step, sample_rate=44100)
359
+ except Exception as e:
360
+ print(f"Warning: Failed to log audio {name}: {e}")
361
+
362
+ def _log_image(self, logger, name, file_path, step):
363
+ """
364
+ Log an image to the appropriate logger.
365
+
366
+ Args:
367
+ logger: The logger instance
368
+ name: Name of the image
369
+ file_path: Path to the image file
370
+ step: Step value
371
+ """
372
+ try:
373
+ if hasattr(logger, 'experiment') and hasattr(logger.experiment, 'log'):
374
+ # WandbLogger
375
+ import wandb
376
+ logger.experiment.log({
377
+ name: wandb.Image(file_path)
378
+ }, step=step)
379
+ elif hasattr(logger, 'experiment') and hasattr(logger.experiment, 'add_image'):
380
+ # TensorBoardLogger
381
+ import PIL.Image
382
+ import numpy as np
383
+ import torch
384
+ image = PIL.Image.open(file_path)
385
+ image_array = np.array(image)
386
+ image_tensor = torch.from_numpy(image_array).permute(2, 0, 1) # HWC to CHW
387
+ logger.experiment.add_image(name, image_tensor, step)
388
+ except Exception as e:
389
+ print(f"Warning: Failed to log image {name}: {e}")
rift_svc/metrics.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+
3
+
4
+ def psnr(estimated, target, max_val=None):
5
+ """Calculate Peak Signal-to-Noise Ratio (PSNR)
6
+ Args:
7
+ estimated (torch.Tensor): Estimated mel spectrogram [B, len, n_mel]
8
+ target (torch.Tensor): Target mel spectrogram [B, len, n_mel]
9
+ max_val (float): Maximum value of the signal. If None, uses max of target
10
+ Returns:
11
+ torch.Tensor: PSNR value in dB [B]
12
+ """
13
+ if max_val is None:
14
+ # Use the maximum absolute value between both tensors
15
+ max_val = max(torch.abs(target).max(), torch.abs(estimated).max())
16
+
17
+ # Ensure max_val is not zero
18
+ max_val = max(max_val, torch.finfo(target.dtype).eps)
19
+
20
+ mse = torch.mean((estimated - target) ** 2, dim=(1, 2))
21
+ # Add eps to avoid log of zero
22
+ eps = torch.finfo(target.dtype).eps
23
+ psnr = 20 * torch.log10(max_val + eps) - 10 * torch.log10(mse + eps)
24
+ return psnr
25
+
26
+ def si_snr(estimated, target, eps=1e-8):
27
+ """Calculate Scale-Invariant Signal-to-Noise Ratio (SI-SNR)
28
+ Args:
29
+ estimated (torch.Tensor): Estimated mel spectrogram [B, len, n_mel]
30
+ target (torch.Tensor): Target mel spectrogram [B, len, n_mel]
31
+ eps (float): Small value to avoid division by zero
32
+ Returns:
33
+ torch.Tensor: SI-SNR value in dB [B]
34
+ """
35
+ # Flatten the mel dimension
36
+ estimated = estimated.reshape(estimated.shape[0], -1)
37
+ target = target.reshape(target.shape[0], -1)
38
+
39
+ # Zero-mean normalization
40
+ estimated = estimated - torch.mean(estimated, dim=1, keepdim=True)
41
+ target = target - torch.mean(target, dim=1, keepdim=True)
42
+
43
+ # SI-SNR
44
+ alpha = torch.sum(estimated * target, dim=1, keepdim=True) / (
45
+ torch.sum(target ** 2, dim=1, keepdim=True) + eps)
46
+ target_scaled = alpha * target
47
+
48
+ si_snr = 10 * torch.log10(
49
+ torch.sum(target_scaled ** 2, dim=1) /
50
+ (torch.sum((estimated - target_scaled) ** 2, dim=1) + eps) + eps
51
+ )
52
+ return si_snr
53
+
54
+ def mcd(estimated, target):
55
+ """Calculate Mel-Cepstral Distortion (MCD)
56
+ Args:
57
+ estimated (torch.Tensor): Estimated mel spectrogram [B, len, n_mel]
58
+ target (torch.Tensor): Target mel spectrogram [B, len, n_mel]
59
+ Returns:
60
+ torch.Tensor: MCD value [B], averaged over time steps
61
+ """
62
+ # Convert to log scale
63
+ estimated = torch.log10(torch.clamp(estimated, min=1e-8))
64
+ target = torch.log10(torch.clamp(target, min=1e-8))
65
+
66
+ # Calculate MCD
67
+ diff = estimated - target
68
+ mcd = torch.sqrt(2 * torch.sum(diff ** 2, dim=2)) # [B, len]
69
+ # Average over time dimension
70
+ mcd = mcd.mean(dim=1) # [B]
71
+ return mcd
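A quick sanity check of the three metrics on random mel tensors, following the [B, len, n_mel] convention documented above; identical inputs should give near-zero MCD and large PSNR/SI-SNR:

# Illustration only.
import torch
from rift_svc.metrics import mcd, psnr, si_snr

gen = torch.rand(2, 100, 128)
gt = gen.clone()
print(mcd(gen, gt))     # ~0 per batch element
print(psnr(gen, gt))    # large, since MSE is ~0
print(si_snr(gen, gt))  # large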
rift_svc/modules.py ADDED
@@ -0,0 +1,261 @@
1
+ import math
2
+
3
+ from einops import rearrange
4
+ from jaxtyping import Float, Bool
5
+
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+
10
+ from x_transformers.x_transformers import apply_rotary_pos_emb
11
+
12
+
13
+ class LoRALinear(nn.Module):
14
+ def __init__(self, linear, rank, alpha):
15
+ super().__init__()
16
+ self.linear = linear
17
+ self.rank = rank
18
+ self.alpha = alpha
19
+ self.scale = alpha / math.sqrt(rank)
20
+ in_features = linear.in_features
21
+ out_features = linear.out_features
22
+ self.A = nn.Parameter(torch.zeros(in_features, rank))
23
+ self.B = nn.Parameter(torch.zeros(out_features, rank))
24
+ # Initialize LoRA parameters
25
+ nn.init.normal_(self.A, mean=0, std=math.sqrt(self.rank) / self.linear.in_features)
26
+ nn.init.zeros_(self.B)
27
+ # Freeze original linear layer parameters
28
+ self.linear.weight.requires_grad = False
29
+ if self.linear.bias is not None:
30
+ self.linear.bias.requires_grad = False
31
+
32
+ def forward(self, x):
33
+ original_out = self.linear(x)
34
+ lora_out = (x @ self.A) @ self.B.T
35
+ return original_out + lora_out * self.scale
36
+
37
+
38
+ # AdaLayerNormZero
39
+ # return with modulated x for attn input, and params for later mlp modulation
40
+ class AdaLayerNormZero(nn.Module):
41
+ def __init__(self, dim):
42
+ super().__init__()
43
+
44
+ self.silu = nn.SiLU()
45
+ self.proj = nn.Linear(dim, dim * 6)
46
+
47
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
48
+
49
+ def forward(self, x, emb = None):
50
+ emb = self.proj(self.silu(emb))
51
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
52
+
53
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
54
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
55
+
56
+
57
+ # AdaLayerNormZero for final layer
58
+ # return only the modulated x for the attention input, since there is no further MLP modulation
59
+ class AdaLayerNormZero_Final(nn.Module):
60
+ def __init__(self, dim):
61
+ super().__init__()
62
+
63
+ self.silu = nn.SiLU()
64
+ self.proj = nn.Linear(dim, dim * 2)
65
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
66
+
67
+ def forward(self, x, emb):
68
+ emb = self.proj(self.silu(emb))
69
+ scale, shift = torch.chunk(emb, 2, dim=1)
70
+
71
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
72
+ return x
73
+
74
+ # ReLU^2
75
+ class ReLU2(nn.Module):
76
+ def forward(self, x):
77
+ return F.relu(x, inplace=True).square()
78
+
79
+ # FeedForward
80
+ class ConvMLP(nn.Module):
81
+ def __init__(self, dim: int, dim_out: int | None = None, mult: float = 4, dropout: float = 0.0, kernel_size: int = 7):
82
+ super().__init__()
83
+ inner_dim = int(dim * mult)
84
+ dim_out = dim_out if dim_out is not None else dim
85
+
86
+ #self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
87
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim)
88
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
89
+ self.activation = ReLU2()
90
+ self.dropout = nn.Dropout(dropout)
91
+ self.mlp_proj = nn.Linear(dim, inner_dim)
92
+ self.mlp_out = nn.Linear(inner_dim, dim_out)
93
+
94
+ def forward(self, x):
95
+ x = x.permute(0, 2, 1)
96
+ x = self.dwconv(x)
97
+ x = x.permute(0, 2, 1)
98
+ x = self.norm(x)
99
+ x = self.mlp_proj(x)
100
+ x = self.activation(x)
101
+ x = self.dropout(x)
102
+ x = self.mlp_out(x)
103
+ return x
104
+
105
+
106
+ class Attention(nn.Module):
107
+ def __init__(
108
+ self,
109
+ dim: int,
110
+ head_dim: int = 64,
111
+ dropout: float = 0.0,
112
+ ):
113
+ super().__init__()
114
+
115
+ if not hasattr(F, "scaled_dot_product_attention"):
116
+ raise ImportError("Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
117
+
118
+ self.dim = dim
119
+ assert dim % head_dim == 0
120
+ self.head_dim = head_dim
121
+ self.num_heads = int(dim // head_dim)
122
+ self.inner_dim = dim
123
+ self.dropout = dropout
124
+ self.scale = 1 / dim
125
+
126
+ self.q_proj = nn.Linear(dim, self.inner_dim)
127
+ self.k_proj = nn.Linear(dim, self.inner_dim)
128
+ self.v_proj = nn.Linear(dim, self.inner_dim)
129
+
130
+ self.norm_q = nn.LayerNorm(self.head_dim, elementwise_affine=False, eps=1e-6)
131
+ self.norm_k = nn.LayerNorm(self.head_dim, elementwise_affine=False, eps=1e-6)
132
+
133
+ self.attn_out = nn.Linear(self.inner_dim, dim)
134
+ self.attn_dropout = nn.Dropout(dropout)
135
+
136
+ def forward(
137
+ self,
138
+ x: Float[torch.Tensor, "b n d"],
139
+ mask: Bool[torch.Tensor, "b n"] | None = None,
140
+ rope = None,
141
+ ) -> Float[torch.Tensor, "b n d"]:
142
+ batch_size = x.shape[0]
143
+
144
+ # projections
145
+ query = self.q_proj(x)
146
+ key = self.k_proj(x)
147
+ value = self.v_proj(x)
148
+
149
+ # apply rotary position embedding
150
+ if rope is not None:
151
+ freqs, xpos_scale = rope
152
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
153
+
154
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
155
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
156
+
157
+ # attention
158
+ inner_dim = key.shape[-1]
159
+ head_dim = inner_dim // self.num_heads
160
+ query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
161
+ key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
162
+ value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
163
+
164
+ query = self.norm_q(query)
165
+ key = self.norm_k(key)
166
+
167
+ # mask
168
+ if mask is not None:
169
+ attn_mask = mask
170
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
171
+ attn_mask = attn_mask.expand(batch_size, self.num_heads, query.shape[-2], key.shape[-2])
172
+ else:
173
+ attn_mask = None
174
+
175
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False, scale=self.scale)
176
+ x = x.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
177
+ x = x.to(query.dtype)
178
+
179
+ # linear proj and dropout
180
+ x = self.attn_out(x)
181
+ x = self.attn_dropout(x)
182
+
183
+ if mask is not None:
184
+ mask = rearrange(mask, 'b n -> b n 1')
185
+ x = x.masked_fill(~mask, 0.)
186
+
187
+ return x
188
+
189
+
190
+ # DiT Block
191
+ class DiTBlock(nn.Module):
192
+
193
+ def __init__(
194
+ self, dim: int, head_dim: int, ff_mult: float = 4,
195
+ dropout: float = 0.0, kernel_size: int = 31):
196
+ super().__init__()
197
+
198
+ self.attn_norm = AdaLayerNormZero(dim)
199
+ self.attn = Attention(
200
+ dim = dim,
201
+ head_dim = head_dim,
202
+ dropout = dropout,
203
+ )
204
+
205
+ self.mlp_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
206
+ self.mlp = ConvMLP(dim = dim, mult = ff_mult, dropout = dropout, kernel_size=kernel_size)
207
+
208
+ def forward(
209
+ self,
210
+ x: Float[torch.Tensor, "b n d"],
211
+ t: Float[torch.Tensor, "b d"],
212
+ mask: Bool[torch.Tensor, "b n"] | None = None,
213
+ rope: Float[torch.Tensor, "b d"] | None = None,
214
+ ) -> Float[torch.Tensor, "b n d"]:
215
+ # pre-norm & modulation for attention input
216
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
217
+
218
+ # attention
219
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
220
+
221
+ # process attention output for input x
222
+ x = x + gate_msa.unsqueeze(1) * attn_output
223
+
224
+ norm = self.mlp_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
225
+ mlp_output = self.mlp(norm)
226
+ x = x + gate_mlp.unsqueeze(1) * mlp_output
227
+
228
+ return x
229
+
230
+
231
+ # sinusoidal position embedding
232
+ class SinusPositionEmbedding(nn.Module):
233
+ def __init__(self, dim: int):
234
+ super().__init__()
235
+ self.dim = dim
236
+
237
+ def forward(self, x: Float[torch.Tensor, "b"], scale: float = 1000) -> Float[torch.Tensor, "b d"]:
238
+ device = x.device
239
+ half_dim = self.dim // 2
240
+ emb = math.log(10000) / (half_dim - 1)
241
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
242
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
243
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
244
+ return emb
245
+
246
+
247
+ # time step conditioning embedding
248
+ class TimestepEmbedding(nn.Module):
249
+ def __init__(self, dim: int, freq_embed_dim: int = 256):
250
+ super().__init__()
251
+ self.time2emb = SinusPositionEmbedding(freq_embed_dim)
252
+ self.time_emb = nn.Linear(freq_embed_dim, dim)
253
+ self.act = nn.SiLU()
254
+ self.proj = nn.Linear(dim, dim)
255
+
256
+ def forward(self, timestep: Float[torch.Tensor, "b"]) -> Float[torch.Tensor, "b d"]:
257
+ time = self.time2emb(timestep)
258
+ time = self.time_emb(time)
259
+ time = self.act(time)
260
+ time = self.proj(time)
261
+ return time
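A small sketch of LoRALinear on a plain linear layer, mirroring what apply_lora/merge_lora in dit.py do for the attention k/v projections; rank and alpha here are arbitrary:

# Illustration only.
import torch
from torch import nn
from rift_svc.modules import LoRALinear

base = nn.Linear(768, 768)
lora = LoRALinear(base, rank=16, alpha=16)

x = torch.randn(2, 100, 768)
print(lora(x).shape)                                # torch.Size([2, 100, 768])
print(base.weight.requires_grad)                    # False: wrapped weight is frozen
print(lora.A.requires_grad, lora.B.requires_grad)   # True True: only the low-rank factors train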
rift_svc/nsf_hifigan/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .nvSTFT import STFT
+ from .vocoder import Vocoder, NsfHifiGAN
rift_svc/nsf_hifigan/env.py ADDED
@@ -0,0 +1,15 @@
+ import os
+ import shutil
+ 
+ 
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+ 
+ 
+ def build_env(config, config_name, path):
+     t_path = os.path.join(path, config_name)
+     if config != t_path:
+         os.makedirs(path, exist_ok=True)
+         shutil.copyfile(config, os.path.join(path, config_name))
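AttrDict just exposes dict keys as attributes, which is how the vocoder config is consumed elsewhere (e.g. h.sampling_rate in models.py); a tiny example:

from rift_svc.nsf_hifigan.env import AttrDict

h = AttrDict({"sampling_rate": 44100, "num_mels": 128})
print(h.sampling_rate, h["num_mels"])  # 44100 128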
rift_svc/nsf_hifigan/models.py ADDED
@@ -0,0 +1,427 @@
1
+ import os
2
+ import json
3
+ from .env import AttrDict
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
9
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
10
+ from .utils import init_weights, get_padding
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+
15
+ def load_model(model_path, device='cuda'):
16
+ h = load_config(model_path)
17
+
18
+ generator = Generator(h).to(device)
19
+
20
+ cp_dict = torch.load(model_path, map_location=device, weights_only=True)
21
+ generator.load_state_dict(cp_dict['generator'])
22
+ generator.eval()
23
+ generator.remove_weight_norm()
24
+ del cp_dict
25
+ return generator, h
26
+
27
+ def load_config(model_path):
28
+ config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
29
+ with open(config_file) as f:
30
+ data = f.read()
31
+
32
+ json_config = json.loads(data)
33
+ h = AttrDict(json_config)
34
+ return h
35
+
36
+
37
+ class ResBlock1(torch.nn.Module):
38
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
39
+ super(ResBlock1, self).__init__()
40
+ self.h = h
41
+ self.convs1 = nn.ModuleList([
42
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
43
+ padding=get_padding(kernel_size, dilation[0]))),
44
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
45
+ padding=get_padding(kernel_size, dilation[1]))),
46
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
47
+ padding=get_padding(kernel_size, dilation[2])))
48
+ ])
49
+ self.convs1.apply(init_weights)
50
+
51
+ self.convs2 = nn.ModuleList([
52
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
53
+ padding=get_padding(kernel_size, 1))),
54
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
55
+ padding=get_padding(kernel_size, 1))),
56
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
57
+ padding=get_padding(kernel_size, 1)))
58
+ ])
59
+ self.convs2.apply(init_weights)
60
+
61
+ def forward(self, x):
62
+ for c1, c2 in zip(self.convs1, self.convs2):
63
+ xt = F.leaky_relu(x, LRELU_SLOPE)
64
+ xt = c1(xt)
65
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
66
+ xt = c2(xt)
67
+ x = xt + x
68
+ return x
69
+
70
+ def remove_weight_norm(self):
71
+ for l in self.convs1:
72
+ remove_weight_norm(l)
73
+ for l in self.convs2:
74
+ remove_weight_norm(l)
75
+
76
+
77
+ class ResBlock2(torch.nn.Module):
78
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
79
+ super(ResBlock2, self).__init__()
80
+ self.h = h
81
+ self.convs = nn.ModuleList([
82
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
83
+ padding=get_padding(kernel_size, dilation[0]))),
84
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
85
+ padding=get_padding(kernel_size, dilation[1])))
86
+ ])
87
+ self.convs.apply(init_weights)
88
+
89
+ def forward(self, x):
90
+ for c in self.convs:
91
+ xt = F.leaky_relu(x, LRELU_SLOPE)
92
+ xt = c(xt)
93
+ x = xt + x
94
+ return x
95
+
96
+ def remove_weight_norm(self):
97
+ for l in self.convs:
98
+ remove_weight_norm(l)
99
+
100
+
101
+ class SineGen(torch.nn.Module):
102
+ """ Definition of sine generator
103
+ SineGen(samp_rate, harmonic_num = 0,
104
+ sine_amp = 0.1, noise_std = 0.003,
105
+ voiced_threshold = 0,
106
+ flag_for_pulse=False)
107
+ samp_rate: sampling rate in Hz
108
+ harmonic_num: number of harmonic overtones (default 0)
109
+ sine_amp: amplitude of sine-waveform (default 0.1)
110
+ noise_std: std of Gaussian noise (default 0.003)
111
+ voiced_threshold: F0 threshold for U/V classification (default 0)
112
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
113
+ Note: when flag_for_pulse is True, the first time step of a voiced
114
+ segment is always sin(np.pi) or cos(0)
115
+ """
116
+
117
+ def __init__(self, samp_rate, harmonic_num=0,
118
+ sine_amp=0.1, noise_std=0.003,
119
+ voiced_threshold=0):
120
+ super(SineGen, self).__init__()
121
+ self.sine_amp = sine_amp
122
+ self.noise_std = noise_std
123
+ self.harmonic_num = harmonic_num
124
+ self.dim = self.harmonic_num + 1
125
+ self.sampling_rate = samp_rate
126
+ self.voiced_threshold = voiced_threshold
127
+
128
+ def _f02uv(self, f0):
129
+ # generate uv signal
130
+ uv = torch.ones_like(f0)
131
+ uv = uv * (f0 > self.voiced_threshold)
132
+ return uv
133
+
134
+ def _f02sine(self, f0, upp):
135
+ """ f0: (batchsize, length, dim)
136
+ where dim indicates fundamental tone and overtones
137
+ """
138
+ rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, device=f0.device)
139
+ rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5
140
+ rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0)
141
+ rad += F.pad(rad_acc, (0, 0, 1, -1))
142
+ rad = rad.reshape(f0.shape[0], -1, 1)
143
+ rad = torch.multiply(rad, torch.arange(1, self.dim + 1, device=f0.device).reshape(1, 1, -1))
144
+ rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
145
+ rand_ini[..., 0] = 0
146
+ rad += rand_ini
147
+ sines = torch.sin(2 * np.pi * rad)
148
+ return sines
149
+
150
+ @torch.no_grad()
151
+ def forward(self, f0, upp):
152
+ """ sine_tensor, uv = forward(f0)
153
+ input F0: tensor(batchsize=1, length, dim=1)
154
+ f0 for unvoiced steps should be 0
155
+ output sine_tensor: tensor(batchsize=1, length, dim)
156
+ output uv: tensor(batchsize=1, length, 1)
157
+ """
158
+ f0 = f0.unsqueeze(-1)
159
+ sine_waves = self._f02sine(f0, upp) * self.sine_amp
160
+ uv = (f0 > self.voiced_threshold).float()
161
+ uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
162
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
163
+ noise = noise_amp * torch.randn_like(sine_waves)
164
+ sine_waves = sine_waves * uv + noise
165
+ return sine_waves
166
+
167
+
168
+ class SourceModuleHnNSF(torch.nn.Module):
169
+ """ SourceModule for hn-nsf
170
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
171
+ add_noise_std=0.003, voiced_threshod=0)
172
+ sampling_rate: sampling_rate in Hz
173
+ harmonic_num: number of harmonic above F0 (default: 0)
174
+ sine_amp: amplitude of sine source signal (default: 0.1)
175
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
176
+ note that amplitude of noise in unvoiced is decided
177
+ by sine_amp
178
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
179
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
180
+ F0_sampled (batchsize, length, 1)
181
+ Sine_source (batchsize, length, 1)
182
+ noise_source (batchsize, length 1)
183
+ uv (batchsize, length, 1)
184
+ """
185
+
186
+ def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
187
+ add_noise_std=0.003, voiced_threshod=0):
188
+ super(SourceModuleHnNSF, self).__init__()
189
+
190
+ self.sine_amp = sine_amp
191
+ self.noise_std = add_noise_std
192
+
193
+ # to produce sine waveforms
194
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
195
+ sine_amp, add_noise_std, voiced_threshod)
196
+
197
+ # to merge source harmonics into a single excitation
198
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
199
+ self.l_tanh = torch.nn.Tanh()
200
+
201
+ def forward(self, x, upp):
202
+ sine_wavs = self.l_sin_gen(x, upp)
203
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
204
+ return sine_merge
205
+
206
+
207
+ class Generator(torch.nn.Module):
208
+ def __init__(self, h):
209
+ super(Generator, self).__init__()
210
+ self.h = h
211
+ self.num_kernels = len(h.resblock_kernel_sizes)
212
+ self.num_upsamples = len(h.upsample_rates)
213
+ self.m_source = SourceModuleHnNSF(
214
+ sampling_rate=h.sampling_rate,
215
+ harmonic_num=8
216
+ )
217
+ self.noise_convs = nn.ModuleList()
218
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
219
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
220
+
221
+ self.ups = nn.ModuleList()
222
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
223
+ c_cur = h.upsample_initial_channel // (2 ** (i + 1))
224
+ self.ups.append(weight_norm(
225
+ ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
226
+ k, u, padding=(k - u) // 2)))
227
+ if i + 1 < len(h.upsample_rates): #
228
+ stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
229
+ self.noise_convs.append(Conv1d(
230
+ 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
231
+ else:
232
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
233
+ self.resblocks = nn.ModuleList()
234
+ ch = h.upsample_initial_channel
235
+ for i in range(len(self.ups)):
236
+ ch //= 2
237
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
238
+ self.resblocks.append(resblock(h, ch, k, d))
239
+
240
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
241
+ self.ups.apply(init_weights)
242
+ self.conv_post.apply(init_weights)
243
+ self.upp = int(np.prod(h.upsample_rates))
244
+
245
+ def forward(self, x, f0):
246
+ har_source = self.m_source(f0, self.upp).transpose(1, 2)
247
+ x = self.conv_pre(x)
248
+ for i in range(self.num_upsamples):
249
+ x = F.leaky_relu(x, LRELU_SLOPE)
250
+ x = self.ups[i](x)
251
+ x_source = self.noise_convs[i](har_source)
252
+ x = x + x_source
253
+ xs = None
254
+ for j in range(self.num_kernels):
255
+ if xs is None:
256
+ xs = self.resblocks[i * self.num_kernels + j](x)
257
+ else:
258
+ xs += self.resblocks[i * self.num_kernels + j](x)
259
+ x = xs / self.num_kernels
260
+ x = F.leaky_relu(x)
261
+ x = self.conv_post(x)
262
+ x = torch.tanh(x)
263
+
264
+ return x
265
+
266
+ def remove_weight_norm(self):
267
+ print('Removing weight norm...')
268
+ for l in self.ups:
269
+ remove_weight_norm(l)
270
+ for l in self.resblocks:
271
+ l.remove_weight_norm()
272
+ remove_weight_norm(self.conv_pre)
273
+ remove_weight_norm(self.conv_post)
274
+
275
+
276
+ class DiscriminatorP(torch.nn.Module):
277
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
278
+ super(DiscriminatorP, self).__init__()
279
+ self.period = period
280
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
281
+ self.convs = nn.ModuleList([
282
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
283
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
284
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
285
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
286
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
287
+ ])
288
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
289
+
290
+ def forward(self, x):
291
+ fmap = []
292
+
293
+ # 1d to 2d
294
+ b, c, t = x.shape
295
+ if t % self.period != 0: # pad first
296
+ n_pad = self.period - (t % self.period)
297
+ x = F.pad(x, (0, n_pad), "reflect")
298
+ t = t + n_pad
299
+ x = x.view(b, c, t // self.period, self.period)
300
+
301
+ for l in self.convs:
302
+ x = l(x)
303
+ x = F.leaky_relu(x, LRELU_SLOPE)
304
+ fmap.append(x)
305
+ x = self.conv_post(x)
306
+ fmap.append(x)
307
+ x = torch.flatten(x, 1, -1)
308
+
309
+ return x, fmap
310
+
311
+
312
+ class MultiPeriodDiscriminator(torch.nn.Module):
313
+ def __init__(self, periods=None):
314
+ super(MultiPeriodDiscriminator, self).__init__()
315
+ self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
316
+ self.discriminators = nn.ModuleList()
317
+ for period in self.periods:
318
+ self.discriminators.append(DiscriminatorP(period))
319
+
320
+ def forward(self, y, y_hat):
321
+ y_d_rs = []
322
+ y_d_gs = []
323
+ fmap_rs = []
324
+ fmap_gs = []
325
+ for i, d in enumerate(self.discriminators):
326
+ y_d_r, fmap_r = d(y)
327
+ y_d_g, fmap_g = d(y_hat)
328
+ y_d_rs.append(y_d_r)
329
+ fmap_rs.append(fmap_r)
330
+ y_d_gs.append(y_d_g)
331
+ fmap_gs.append(fmap_g)
332
+
333
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
334
+
335
+
336
+ class DiscriminatorS(torch.nn.Module):
337
+ def __init__(self, use_spectral_norm=False):
338
+ super(DiscriminatorS, self).__init__()
339
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
340
+ self.convs = nn.ModuleList([
341
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
342
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
343
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
344
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
345
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
346
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
347
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
348
+ ])
349
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
350
+
351
+ def forward(self, x):
352
+ fmap = []
353
+ for l in self.convs:
354
+ x = l(x)
355
+ x = F.leaky_relu(x, LRELU_SLOPE)
356
+ fmap.append(x)
357
+ x = self.conv_post(x)
358
+ fmap.append(x)
359
+ x = torch.flatten(x, 1, -1)
360
+
361
+ return x, fmap
362
+
363
+
364
+ class MultiScaleDiscriminator(torch.nn.Module):
365
+ def __init__(self):
366
+ super(MultiScaleDiscriminator, self).__init__()
367
+ self.discriminators = nn.ModuleList([
368
+ DiscriminatorS(use_spectral_norm=True),
369
+ DiscriminatorS(),
370
+ DiscriminatorS(),
371
+ ])
372
+ self.meanpools = nn.ModuleList([
373
+ AvgPool1d(4, 2, padding=2),
374
+ AvgPool1d(4, 2, padding=2)
375
+ ])
376
+
377
+ def forward(self, y, y_hat):
378
+ y_d_rs = []
379
+ y_d_gs = []
380
+ fmap_rs = []
381
+ fmap_gs = []
382
+ for i, d in enumerate(self.discriminators):
383
+ if i != 0:
384
+ y = self.meanpools[i - 1](y)
385
+ y_hat = self.meanpools[i - 1](y_hat)
386
+ y_d_r, fmap_r = d(y)
387
+ y_d_g, fmap_g = d(y_hat)
388
+ y_d_rs.append(y_d_r)
389
+ fmap_rs.append(fmap_r)
390
+ y_d_gs.append(y_d_g)
391
+ fmap_gs.append(fmap_g)
392
+
393
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
394
+
395
+
396
+ def feature_loss(fmap_r, fmap_g):
397
+ loss = 0
398
+ for dr, dg in zip(fmap_r, fmap_g):
399
+ for rl, gl in zip(dr, dg):
400
+ loss += torch.mean(torch.abs(rl - gl))
401
+
402
+ return loss * 2
403
+
404
+
405
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
406
+ loss = 0
407
+ r_losses = []
408
+ g_losses = []
409
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
410
+ r_loss = torch.mean((1 - dr) ** 2)
411
+ g_loss = torch.mean(dg ** 2)
412
+ loss += (r_loss + g_loss)
413
+ r_losses.append(r_loss.item())
414
+ g_losses.append(g_loss.item())
415
+
416
+ return loss, r_losses, g_losses
417
+
418
+
419
+ def generator_loss(disc_outputs):
420
+ loss = 0
421
+ gen_losses = []
422
+ for dg in disc_outputs:
423
+ l = torch.mean((1 - dg) ** 2)
424
+ gen_losses.append(l)
425
+ loss += l
426
+
427
+ return loss, gen_losses
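
For context, a minimal sketch of how the discriminators and loss functions above are typically wired together in a HiFi-GAN-style training step. This is illustrative only and not part of this commit; y and y_hat are assumed to be real and generated waveform tensors of shape [B, 1, T].

mpd = MultiPeriodDiscriminator()
msd = MultiScaleDiscriminator()

def gan_losses(y, y_hat):
    # Discriminator side: score real vs. generated audio, with the generator detached
    y_df_r, y_df_g, _, _ = mpd(y, y_hat.detach())
    y_ds_r, y_ds_g, _, _ = msd(y, y_hat.detach())
    loss_d = discriminator_loss(y_df_r, y_df_g)[0] + discriminator_loss(y_ds_r, y_ds_g)[0]

    # Generator side: adversarial terms plus feature matching on both discriminators
    y_df_r, y_df_g, fmap_f_r, fmap_f_g = mpd(y, y_hat)
    y_ds_r, y_ds_g, fmap_s_r, fmap_s_g = msd(y, y_hat)
    loss_g = (generator_loss(y_df_g)[0] + generator_loss(y_ds_g)[0]
              + feature_loss(fmap_f_r, fmap_f_g) + feature_loss(fmap_s_r, fmap_s_g))
    return loss_d, loss_g
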
rift_svc/nsf_hifigan/nvSTFT.py ADDED
@@ -0,0 +1,124 @@
1
+ import math
2
+ import os
3
+ os.environ["LRU_CACHE_CAPACITY"] = "3"
4
+ import random
5
+ import torch
6
+ import torch.utils.data
7
+ import numpy as np
8
+ import librosa
9
+ from librosa.util import normalize
10
+ from librosa.filters import mel as librosa_mel_fn
11
+ from scipy.io.wavfile import read
12
+ import soundfile as sf
13
+ import torch.nn.functional as F
14
+
15
+ def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
16
+ sampling_rate = None
17
+ try:
18
+ data, sampling_rate = sf.read(full_path, always_2d=True)  # load the audio with soundfile
19
+ except Exception as ex:
20
+ print(f"'{full_path}' failed to load.\nException:")
21
+ print(ex)
22
+ if return_empty_on_exception:
23
+ return [], sampling_rate or target_sr or 48000
24
+ else:
25
+ raise Exception(ex)
26
+
27
+ if len(data.shape) > 1:
28
+ data = data[:, 0]
29
+ assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
30
+
31
+ if np.issubdtype(data.dtype, np.integer): # if audio data is type int
32
+ max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX
33
+ else: # if audio data is type fp32
34
+ max_mag = max(np.amax(data), -np.amin(data))
35
+ max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
36
+
37
+ data = torch.FloatTensor(data.astype(np.float32))/max_mag
38
+
39
+ if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
40
+ return [], sampling_rate or target_sr or 48000
41
+ if target_sr is not None and sampling_rate != target_sr:
42
+ data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
43
+ sampling_rate = target_sr
44
+
45
+ return data, sampling_rate
46
+
47
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
48
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
49
+
50
+ def dynamic_range_decompression(x, C=1):
51
+ return np.exp(x) / C
52
+
53
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
54
+ return torch.log(torch.clamp(x, min=clip_val) * C)
55
+
56
+ def dynamic_range_decompression_torch(x, C=1):
57
+ return torch.exp(x) / C
58
+
59
+ class STFT():
60
+ def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
61
+ self.target_sr = sr
62
+
63
+ self.n_mels = n_mels
64
+ self.n_fft = n_fft
65
+ self.win_size = win_size
66
+ self.hop_length = hop_length
67
+ self.fmin = fmin
68
+ self.fmax = fmax
69
+ self.clip_val = clip_val
70
+ self.mel_basis = {}
71
+ self.hann_window = {}
72
+
73
+ def get_mel(self, y, keyshift=0, speed=1, center=False):
74
+ sampling_rate = self.target_sr
75
+ n_mels = self.n_mels
76
+ n_fft = self.n_fft
77
+ win_size = self.win_size
78
+ hop_length = self.hop_length
79
+ fmin = self.fmin
80
+ fmax = self.fmax
81
+ clip_val = self.clip_val
82
+
83
+ factor = 2 ** (keyshift / 12)
84
+ n_fft_new = int(np.round(n_fft * factor))
85
+ win_size_new = int(np.round(win_size * factor))
86
+ hop_length_new = int(np.round(hop_length * speed))
87
+
88
+ mel_basis_key = str(fmax)+'_'+str(y.device)
89
+ if mel_basis_key not in self.mel_basis:
90
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
91
+ self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
92
+
93
+ keyshift_key = str(keyshift)+'_'+str(y.device)
94
+ if keyshift_key not in self.hann_window:
95
+ self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
96
+
97
+ pad_left = (win_size_new - hop_length_new) //2
98
+ pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left)
99
+ if pad_right < y.size(-1):
100
+ mode = 'reflect'
101
+ else:
102
+ mode = 'constant'
103
+ y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode)
104
+ y = y.squeeze(1)
105
+
106
+ spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key],
107
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
108
+ spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
109
+ if keyshift != 0:
110
+ size = n_fft // 2 + 1
111
+ resize = spec.size(1)
112
+ if resize < size:
113
+ spec = F.pad(spec, (0, 0, 0, size-resize))
114
+ spec = spec[:, :size, :] * win_size / win_size_new
115
+ spec = torch.matmul(self.mel_basis[mel_basis_key], spec)
116
+ spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
117
+ return spec
118
+
119
+ def __call__(self, audiopath):
120
+ audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
121
+ spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
122
+ return spect
123
+
124
+ stft = STFT()
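
A brief usage sketch for the STFT helper above. The 44.1 kHz parameter values and the wav path are illustrative placeholders, not values taken from this repository's configs.

stft_44k = STFT(sr=44100, n_mels=128, n_fft=2048, win_size=2048, hop_length=512, fmin=40, fmax=16000)
audio, sr = load_wav_to_torch('input.wav', target_sr=44100)  # hypothetical file
mel = stft_44k.get_mel(audio.unsqueeze(0))                   # [1, n_mels, n_frames], log-compressed
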
rift_svc/nsf_hifigan/utils.py ADDED
@@ -0,0 +1,67 @@
1
+ import glob
2
+ import os
3
+ import matplotlib
4
+ import torch
5
+ from torch.nn.utils import weight_norm
6
+ matplotlib.use("Agg")
7
+ import matplotlib.pylab as plt
8
+
9
+
10
+ def plot_spectrogram(spectrogram):
11
+ fig, ax = plt.subplots(figsize=(10, 2))
12
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13
+ interpolation='none')
14
+ plt.colorbar(im, ax=ax)
15
+
16
+ fig.canvas.draw()
17
+ plt.close()
18
+
19
+ return fig
20
+
21
+
22
+ def init_weights(m, mean=0.0, std=0.01):
23
+ classname = m.__class__.__name__
24
+ if classname.find("Conv") != -1:
25
+ m.weight.data.normal_(mean, std)
26
+
27
+
28
+ def apply_weight_norm(m):
29
+ classname = m.__class__.__name__
30
+ if classname.find("Conv") != -1:
31
+ weight_norm(m)
32
+
33
+
34
+ def get_padding(kernel_size, dilation=1):
35
+ return int((kernel_size*dilation - dilation)/2)
36
+
37
+
38
+ def load_checkpoint(filepath, device):
39
+ assert os.path.isfile(filepath)
40
+ print("Loading '{}'".format(filepath))
41
+ checkpoint_dict = torch.load(filepath, map_location=device)
42
+ print("Complete.")
43
+ return checkpoint_dict
44
+
45
+
46
+ def save_checkpoint(filepath, obj):
47
+ print("Saving checkpoint to {}".format(filepath))
48
+ torch.save(obj, filepath)
49
+ print("Complete.")
50
+
51
+
52
+ def del_old_checkpoints(cp_dir, prefix, n_models=2):
53
+ pattern = os.path.join(cp_dir, prefix + '????????')
54
+ cp_list = glob.glob(pattern) # get checkpoint paths
55
+ cp_list = sorted(cp_list)# sort by iter
56
+ if len(cp_list) > n_models: # if more than n_models models are found
57
+ for cp in cp_list[:-n_models]:  # delete the oldest models other than the latest n_models
58
+ open(cp, 'w').close()# empty file contents
59
+ os.unlink(cp)# delete file (move to trash when using Colab)
60
+
61
+
62
+ def scan_checkpoint(cp_dir, prefix):
63
+ pattern = os.path.join(cp_dir, prefix + '????????')
64
+ cp_list = glob.glob(pattern)
65
+ if len(cp_list) == 0:
66
+ return None
67
+ return sorted(cp_list)[-1]
rift_svc/nsf_hifigan/vocoder.py ADDED
@@ -0,0 +1,123 @@
1
+ import os
2
+ import yaml
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from .nvSTFT import STFT
8
+ from .models import load_model,load_config
9
+ from torchaudio.transforms import Resample
10
+ from jaxtyping import Float
11
+
12
+
13
+ class DotDict(dict):
14
+ def __getattr__(*args):
15
+ val = dict.get(*args)
16
+ return DotDict(val) if type(val) is dict else val
17
+
18
+ __setattr__ = dict.__setitem__
19
+ __delattr__ = dict.__delitem__
20
+
21
+
22
+ def load_model_vocoder(
23
+ model_path,
24
+ device='cpu'):
25
+ config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
26
+ with open(config_file, "r") as config:
27
+ args = yaml.safe_load(config)
28
+ args = DotDict(args)
29
+
30
+ # load vocoder
31
+ vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device)
32
+
33
+ return vocoder, args
34
+
35
+
36
+ class Vocoder:
37
+ def __init__(self, vocoder_type, vocoder_ckpt, device = None):
38
+ if device is None:
39
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
40
+ self.device = device
41
+
42
+ if vocoder_type == 'nsf-hifigan':
43
+ self.vocoder = NsfHifiGAN(vocoder_ckpt, device = device)
44
+ elif vocoder_type == 'nsf-hifigan-log10':
45
+ self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device = device)
46
+ else:
47
+ raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")
48
+
49
+ self.resample_kernel = {}
50
+ self.vocoder_sample_rate = self.vocoder.sample_rate()
51
+ self.vocoder_hop_size = self.vocoder.hop_size()
52
+ self.dimension = self.vocoder.dimension()
53
+
54
+ def extract(self, audio, sample_rate=0, keyshift=0):
55
+
56
+ # resample
57
+ if sample_rate == self.vocoder_sample_rate or sample_rate == 0:
58
+ audio_res = audio
59
+ else:
60
+ key_str = str(sample_rate)
61
+ if key_str not in self.resample_kernel:
62
+ self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device)
63
+ audio_res = self.resample_kernel[key_str](audio)
64
+
65
+ # extract
66
+ mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins
67
+ return mel
68
+
69
+ def infer(self, mel, f0):
70
+ f0 = f0[:,:mel.size(1),0] # B, n_frames
71
+ audio = self.vocoder(mel, f0)
72
+ return audio
73
+
74
+
75
+ class NsfHifiGAN(torch.nn.Module):
76
+ def __init__(self, model_path, device=None):
77
+ super().__init__()
78
+ if device is None:
79
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
80
+ self.device = device
81
+ self.model_path = model_path
82
+ self.model = None
83
+ self.h = load_config(model_path)
84
+ self.stft = STFT(
85
+ self.h.sampling_rate,
86
+ self.h.num_mels,
87
+ self.h.n_fft,
88
+ self.h.win_size,
89
+ self.h.hop_size,
90
+ self.h.fmin,
91
+ self.h.fmax)
92
+
93
+ def sample_rate(self):
94
+ return self.h.sampling_rate
95
+
96
+ def hop_size(self):
97
+ return self.h.hop_size
98
+
99
+ def dimension(self):
100
+ return self.h.num_mels
101
+
102
+ def extract(self, audio, keyshift=0):
103
+ mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins
104
+ return mel
105
+
106
+ def forward(self, mel: Float[torch.Tensor, "batch bins n_frames"], f0: Float[torch.Tensor, "batch n_frames"]):
107
+ if self.model is None:
108
+ print('| Load HifiGAN: ', self.model_path)
109
+ self.model, self.h = load_model(self.model_path, device=self.device)
110
+ with torch.no_grad():
111
+ audio = self.model(mel, f0)
112
+ return audio
113
+
114
+
115
+ class NsfHifiGANLog10(NsfHifiGAN):
116
+ def forward(self, mel, f0):
117
+ if self.model is None:
118
+ print('| Load HifiGAN: ', self.model_path)
119
+ self.model, self.h = load_model(self.model_path, device=self.device)
120
+ with torch.no_grad():
121
+ c = 0.434294 * mel.transpose(1, 2)
122
+ audio = self.model(c, f0)
123
+ return audio
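
A short usage sketch for the Vocoder wrapper above. The checkpoint path is an example only; audio is assumed to be a [B, n_samples] tensor and f0 a [B, n_frames, 1] tensor on the same device.

vocoder = Vocoder('nsf-hifigan', 'pretrained/nsf_hifigan/model', device='cuda')
mel = vocoder.extract(audio, sample_rate=44100)  # [B, n_frames, n_mels]
wav = vocoder.infer(mel, f0)                     # reconstructed waveform
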
rift_svc/optim.py ADDED
@@ -0,0 +1,103 @@
1
+ from schedulefree import AdamWScheduleFree
2
+ from torch.optim import AdamW
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+ import math
5
+
6
+
7
+ def get_optimizer(
8
+ optimizer_type, model, lr, betas, weight_decay, warmup_steps,
9
+ lora_training=False, **kwargs):
10
+ from collections import defaultdict
11
+ param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
12
+ if not lora_training:
13
+ specp_decay_params = defaultdict(list)
14
+ specp_decay_lr = {}
15
+ decay_params = []
16
+ nodecay_params = []
17
+ for n, p in param_dict.items():
18
+ if p.dim() >= 2:
19
+ if n.endswith('out.weight') or n.endswith('proj.weight'):
20
+ fan_out, fan_in = p.shape[-2:]
21
+ fan_ratio = fan_out / fan_in
22
+ specp_decay_params[f"specp_decay_{fan_ratio:.2f}"].append(p)
23
+ specp_decay_lr[f"specp_decay_{fan_ratio:.2f}"] = lr * fan_ratio
24
+ else:
25
+ decay_params.append(p)
26
+ else:
27
+ nodecay_params.append(p)
28
+
29
+ optim_groups = [
30
+ {'params': decay_params, 'weight_decay': weight_decay, 'lr': lr},
31
+ {'params': nodecay_params, 'weight_decay': 0.0, 'lr': lr}
32
+ ] + [
33
+ {'params': params, 'weight_decay': weight_decay, 'lr': specp_decay_lr[group_name]}
34
+ for group_name, params in specp_decay_params.items()
35
+ ]
36
+ else:
37
+ lora_a_or_spk_embed_params = []
38
+ lora_b_params = []
39
+ for n, p in param_dict.items():
40
+ if n.endswith('.A.weight') or n.endswith('.spk_embed.weight'):
41
+ lora_a_or_spk_embed_params.append(p)
42
+ elif n.endswith('.B.weight'):
43
+ lora_b_params.append(p)
44
+ dim = model.transformer.dim
45
+ rank = model.transformer.transformer_blocks[0].attn.k_proj.rank
46
+ optim_groups = [
47
+ {'params': lora_a_or_spk_embed_params, 'weight_decay': weight_decay, 'lr': lr},
48
+ {'params': lora_b_params, 'weight_decay': weight_decay, 'lr': lr*math.sqrt(dim/rank)}
49
+ ]
50
+ if optimizer_type == 'adamwsf':
51
+ optimizer = AdamWScheduleFree(optim_groups, betas=betas, warmup_steps=warmup_steps)
52
+ return optimizer, None
53
+ elif optimizer_type == 'adamw':
54
+ optimizer = AdamW(optim_groups, betas=betas, weight_decay=weight_decay)
55
+ max_steps = kwargs['max_steps']
56
+ min_lr = kwargs.get('min_lr', 0.0)
57
+ lr_scheduler = LinearWarmupDecayLR(optimizer, warmup_steps, max_steps, min_lr=min_lr)
58
+ return optimizer, lr_scheduler
59
+ else:
60
+ raise ValueError(f"Invalid optimizer type: {optimizer_type}")
61
+
62
+
63
+ class LinearWarmupDecayLR(_LRScheduler):
64
+ """
65
+ Linear learning rate scheduler with warmup and minimum lr.
66
+
67
+ During warmup, the LR increases linearly from 0 to the base LR.
68
+ After warmup, the LR decays linearly from the base LR down to min_lr.
69
+
70
+ Args:
71
+ optimizer (Optimizer): Wrapped optimizer.
72
+ warmup_steps (int): Number of steps to linearly increase LR.
73
+ total_steps (int): Total number of steps for training (warmup + decay).
74
+ min_lr (float): Minimum learning rate after decay.
75
+ last_epoch (int): The index of last epoch. Default: -1.
76
+ """
77
+
78
+ def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0.0, last_epoch=-1):
79
+ if total_steps <= warmup_steps:
80
+ raise ValueError(
81
+ "Total steps must be larger than warmup_steps for decay to happen."
82
+ )
83
+ self.warmup_steps = warmup_steps
84
+ self.total_steps = total_steps
85
+ self.min_lr = min_lr
86
+ super(LinearWarmupDecayLR, self).__init__(optimizer, last_epoch)
87
+
88
+ def get_lr(self):
89
+ """Compute learning rate using linear warmup and then linear decay."""
90
+ # Note: self.last_epoch is incremented by the base _LRScheduler.step() before calling get_lr().
91
+ if self.last_epoch < self.warmup_steps:
92
+ # Warmup phase: increase linearly from 0 (or a small value) to base_lr.
93
+ return [
94
+ base_lr * float(self.last_epoch + 1) / float(self.warmup_steps)
95
+ for base_lr in self.base_lrs
96
+ ]
97
+ else:
98
+ # Decay phase: decrease linearly from base_lr to min_lr.
99
+ progress = float(self.last_epoch - self.warmup_steps) / float(self.total_steps - self.warmup_steps)
100
+ return [
101
+ max(base_lr * (1.0 - progress) + self.min_lr * progress, self.min_lr)
102
+ for base_lr in self.base_lrs
103
+ ]
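
A minimal sketch of how get_optimizer and the scheduler above might be driven; the model and hyperparameter values are placeholders.

optimizer, scheduler = get_optimizer(
    'adamw', model, lr=1e-4, betas=(0.9, 0.98), weight_decay=0.01,
    warmup_steps=1000, max_steps=100000, min_lr=1e-6)
# inside the training loop, after loss.backward():
optimizer.step()
if scheduler is not None:  # the 'adamwsf' variant returns no scheduler
    scheduler.step()
optimizer.zero_grad()
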
rift_svc/rf.py ADDED
@@ -0,0 +1,215 @@
1
+ from typing import Union, List, Literal
2
+ from jaxtyping import Bool
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from torchdiffeq import odeint
8
+
9
+ from einops import rearrange
10
+
11
+ from rift_svc.utils import (
12
+ exists,
13
+ lens_to_mask,
14
+ )
15
+
16
+
17
+ def sample_time(time_schedule: Literal['uniform', 'lognorm'], size: int, device: torch.device):
18
+ if time_schedule == 'uniform':
19
+ t = torch.rand((size,), device=device)
20
+ elif time_schedule == 'lognorm':
21
+ # stratified sampling of normals
22
+ # first stratified sample from uniform
23
+ quantiles = torch.linspace(0, 1, size + 1).to(device)
24
+ z = quantiles[:-1] + torch.rand((size,)).to(device) / size
25
+ # now transform to normal
26
+ z = torch.erfinv(2 * z - 1) * math.sqrt(2)
27
+ t = torch.sigmoid(z)
28
+ return t
29
+
30
+
31
+ class RF(nn.Module):
32
+ def __init__(
33
+ self,
34
+ transformer: nn.Module,
35
+ time_schedule: Literal['uniform', 'lognorm'] = 'lognorm',
36
+ odeint_kwargs: dict = dict(
37
+ method='euler'
38
+ ),
39
+ ):
40
+ super().__init__()
41
+
42
+ self.transformer = transformer
43
+ dim = transformer.dim
44
+ self.dim = dim
45
+
46
+ # Sampling related parameters
47
+ self.odeint_kwargs = odeint_kwargs
48
+ self.time_schedule = time_schedule
49
+
50
+ self.mel_min = -12
51
+ self.mel_max = 2
52
+
53
+
54
+ @property
55
+ def device(self):
56
+ return next(self.parameters()).device
57
+
58
+ @torch.no_grad()
59
+ def sample(
60
+ self,
61
+ src_mel: torch.Tensor, # [b n d]
62
+ spk_id: torch.Tensor, # [b]
63
+ f0: torch.Tensor, # [b n]
64
+ rms: torch.Tensor, # [b n]
65
+ cvec: torch.Tensor, # [b n d]
66
+ frame_len: torch.Tensor | None = None, # [b]
67
+ steps: int = 32,
68
+ bad_cvec: torch.Tensor | None = None,
69
+ ds_cfg_strength: float = 0.0,
70
+ spk_cfg_strength: float = 0.0,
71
+ skip_cfg_strength: float = 0.0,
72
+ cfg_skip_layers: Union[int, List[int], None] = None,
73
+ cfg_rescale: float = 0.7,
74
+ ):
75
+ self.eval()
76
+
77
+ batch, mel_seq_len, num_mel_channels = src_mel.shape
78
+ device = src_mel.device
79
+
80
+ if not exists(frame_len):
81
+ frame_len = torch.full((batch,), mel_seq_len, device=device)
82
+
83
+ mask = lens_to_mask(frame_len)
84
+
85
+ # Define the ODE function
86
+ def fn(t, x):
87
+ pred = self.transformer(
88
+ x=x,
89
+ spk=spk_id,
90
+ f0=f0,
91
+ rms=rms,
92
+ cvec=cvec,
93
+ time=t,
94
+ mask=mask
95
+ )
96
+ cfg_flag = (ds_cfg_strength > 1e-5) or (skip_cfg_strength > 1e-5) or (spk_cfg_strength > 1e-5)
97
+ if cfg_rescale > 1e-5 and cfg_flag:
98
+ std_pred = pred.std()
99
+
100
+ if ds_cfg_strength > 1e-5:
101
+ assert exists(bad_cvec), "bad_cvec is required when cfg_strength is greater than 0"
102
+ bad_cvec_pred = self.transformer(
103
+ x=x,
104
+ spk=spk_id,
105
+ f0=f0,
106
+ rms=rms,
107
+ cvec=bad_cvec,
108
+ time=t,
109
+ mask=mask,
110
+ skip_layers=cfg_skip_layers
111
+ )
112
+
113
+ pred = pred + (pred - bad_cvec_pred) * ds_cfg_strength
114
+
115
+ if skip_cfg_strength > 1e-5:
116
+ skip_pred = self.transformer(
117
+ x=x,
118
+ spk=spk_id,
119
+ f0=f0,
120
+ rms=rms,
121
+ cvec=cvec,
122
+ time=t,
123
+ mask=mask,
124
+ skip_layers=cfg_skip_layers
125
+ )
126
+
127
+ pred = pred + (pred - skip_pred) * skip_cfg_strength
128
+
129
+ if spk_cfg_strength > 1e-5:
130
+ null_spk_pred = self.transformer(
131
+ x=x,
132
+ spk=spk_id,
133
+ f0=f0,
134
+ rms=rms,
135
+ cvec=cvec,
136
+ time=t,
137
+ mask=mask,
138
+ drop_speaker=True
139
+ )
140
+
141
+ pred = pred + (pred - null_spk_pred) * spk_cfg_strength
142
+
143
+ if cfg_rescale > 1e-5 and cfg_flag:
144
+ std_cfg = pred.std()
145
+ pred_rescaled = pred * (std_pred / std_cfg)
146
+ pred = cfg_rescale * pred_rescaled + (1 - cfg_rescale) * pred
147
+
148
+ return pred
149
+
150
+ # Noise input
151
+ y0 = torch.randn(batch, mel_seq_len, num_mel_channels, device=self.device)
152
+ # mask out the padded tokens
153
+ y0 = y0.masked_fill(~mask.unsqueeze(-1), 0)
154
+
155
+ t_start = 0
156
+ t = torch.linspace(t_start, 1, steps, device=self.device)
157
+
158
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
159
+
160
+ sampled = trajectory[-1]
161
+ out = self.denorm_mel(sampled)
162
+ out = torch.where(mask.unsqueeze(-1), out, src_mel)
163
+
164
+ return out, trajectory
165
+
166
+ def forward(
167
+ self,
168
+ mel: torch.Tensor, # mel
169
+ spk_id: torch.Tensor, # [b]
170
+ f0: torch.Tensor, # [b n]
171
+ rms: torch.Tensor, # [b n]
172
+ cvec: torch.Tensor, # [b n d]
173
+ frame_len: torch.Tensor | None = None,
174
+ drop_speaker: Union[bool, Bool[torch.Tensor, "b"]] = False,
175
+ ):
176
+ batch, seq_len, dtype, device = *mel.shape[:2], mel.dtype, self.device
177
+
178
+ # Handle lengths and masks
179
+ if not exists(frame_len):
180
+ frame_len = torch.full((batch,), seq_len, device=device)
181
+
182
+ mask = lens_to_mask(frame_len, length=seq_len) # Typically padded to max length in batch
183
+
184
+ x1 = self.norm_mel(mel)
185
+ x0 = torch.randn_like(x1)
186
+
187
+ # uniform time steps sampling
188
+ time = sample_time(self.time_schedule, batch, self.device)
189
+
190
+ t = rearrange(time, 'b -> b 1 1')
191
+ xt = (1 - t) * x0 + t * x1
192
+ flow = x1 - x0
193
+
194
+ pred = self.transformer(
195
+ x=xt,
196
+ spk=spk_id,
197
+ f0=f0,
198
+ rms=rms,
199
+ cvec=cvec,
200
+ time=time,
201
+ drop_speaker=drop_speaker,
202
+ mask=mask
203
+ )
204
+
205
+ # Flow matching loss
206
+ loss = F.mse_loss(pred, flow, reduction='none')
207
+ loss = loss[mask]
208
+
209
+ return loss.mean(), pred
210
+
211
+ def norm_mel(self, mel: torch.Tensor):
212
+ return (mel - self.mel_min) / (self.mel_max - self.mel_min) * 2 - 1
213
+
214
+ def denorm_mel(self, mel: torch.Tensor):
215
+ return (mel + 1) / 2 * (self.mel_max - self.mel_min) + self.mel_min
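
A compact sketch of how the RF wrapper above is used for training and inference. All tensors are placeholders with the shapes noted in the signatures, and transformer stands for the conditioned DiT backbone this class expects.

rf = RF(transformer)

# training: flow-matching loss between noise and the normalized target mel
loss, _ = rf(mel, spk_id=spk_id, f0=f0, rms=rms, cvec=cvec, frame_len=frame_len)
loss.backward()

# inference: integrate the ODE from noise, conditioned on the source features
gen_mel, _ = rf.sample(src_mel, spk_id, f0, rms, cvec, steps=32)
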
rift_svc/rmvpe/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .constants import *
2
+ from .model import E2E, E2E0
3
+ from .utils import to_local_average_f0, to_viterbi_f0
4
+ from .inference import RMVPE
5
+ from .spec import MelSpectrogram
rift_svc/rmvpe/constants.py ADDED
@@ -0,0 +1,9 @@
1
+ SAMPLE_RATE = 16000
2
+
3
+ N_CLASS = 360
4
+
5
+ N_MELS = 128
6
+ MEL_FMIN = 30
7
+ MEL_FMAX = 8000
8
+ WINDOW_LENGTH = 1024
9
+ CONST = 1997.3794084376191
rift_svc/rmvpe/deepunet.py ADDED
@@ -0,0 +1,189 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .constants import N_MELS
4
+
5
+
6
+ class ConvBlockRes(nn.Module):
7
+ def __init__(self, in_channels, out_channels, momentum=0.01):
8
+ super(ConvBlockRes, self).__init__()
9
+ self.conv = nn.Sequential(
10
+ nn.Conv2d(in_channels=in_channels,
11
+ out_channels=out_channels,
12
+ kernel_size=(3, 3),
13
+ stride=(1, 1),
14
+ padding=(1, 1),
15
+ bias=False),
16
+ nn.BatchNorm2d(out_channels, momentum=momentum),
17
+ nn.ReLU(),
18
+
19
+ nn.Conv2d(in_channels=out_channels,
20
+ out_channels=out_channels,
21
+ kernel_size=(3, 3),
22
+ stride=(1, 1),
23
+ padding=(1, 1),
24
+ bias=False),
25
+ nn.BatchNorm2d(out_channels, momentum=momentum),
26
+ nn.ReLU(),
27
+ )
28
+ if in_channels != out_channels:
29
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
30
+ self.is_shortcut = True
31
+ else:
32
+ self.is_shortcut = False
33
+
34
+ def forward(self, x):
35
+ if self.is_shortcut:
36
+ return self.conv(x) + self.shortcut(x)
37
+ else:
38
+ return self.conv(x) + x
39
+
40
+
41
+ class ResEncoderBlock(nn.Module):
42
+ def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
43
+ super(ResEncoderBlock, self).__init__()
44
+ self.n_blocks = n_blocks
45
+ self.conv = nn.ModuleList()
46
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
47
+ for i in range(n_blocks - 1):
48
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
49
+ self.kernel_size = kernel_size
50
+ if self.kernel_size is not None:
51
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
52
+
53
+ def forward(self, x):
54
+ for i in range(self.n_blocks):
55
+ x = self.conv[i](x)
56
+ if self.kernel_size is not None:
57
+ return x, self.pool(x)
58
+ else:
59
+ return x
60
+
61
+
62
+ class ResDecoderBlock(nn.Module):
63
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
64
+ super(ResDecoderBlock, self).__init__()
65
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
66
+ self.n_blocks = n_blocks
67
+ self.conv1 = nn.Sequential(
68
+ nn.ConvTranspose2d(in_channels=in_channels,
69
+ out_channels=out_channels,
70
+ kernel_size=(3, 3),
71
+ stride=stride,
72
+ padding=(1, 1),
73
+ output_padding=out_padding,
74
+ bias=False),
75
+ nn.BatchNorm2d(out_channels, momentum=momentum),
76
+ nn.ReLU(),
77
+ )
78
+ self.conv2 = nn.ModuleList()
79
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
80
+ for i in range(n_blocks-1):
81
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
82
+
83
+ def forward(self, x, concat_tensor):
84
+ x = self.conv1(x)
85
+ x = torch.cat((x, concat_tensor), dim=1)
86
+ for i in range(self.n_blocks):
87
+ x = self.conv2[i](x)
88
+ return x
89
+
90
+
91
+ class Encoder(nn.Module):
92
+ def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
93
+ super(Encoder, self).__init__()
94
+ self.n_encoders = n_encoders
95
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
96
+ self.layers = nn.ModuleList()
97
+ self.latent_channels = []
98
+ for i in range(self.n_encoders):
99
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
100
+ self.latent_channels.append([out_channels, in_size])
101
+ in_channels = out_channels
102
+ out_channels *= 2
103
+ in_size //= 2
104
+ self.out_size = in_size
105
+ self.out_channel = out_channels
106
+
107
+ def forward(self, x):
108
+ concat_tensors = []
109
+ x = self.bn(x)
110
+ for i in range(self.n_encoders):
111
+ _, x = self.layers[i](x)
112
+ concat_tensors.append(_)
113
+ return x, concat_tensors
114
+
115
+
116
+ class Intermediate(nn.Module):
117
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
118
+ super(Intermediate, self).__init__()
119
+ self.n_inters = n_inters
120
+ self.layers = nn.ModuleList()
121
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
122
+ for i in range(self.n_inters-1):
123
+ self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
124
+
125
+ def forward(self, x):
126
+ for i in range(self.n_inters):
127
+ x = self.layers[i](x)
128
+ return x
129
+
130
+
131
+ class Decoder(nn.Module):
132
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
133
+ super(Decoder, self).__init__()
134
+ self.layers = nn.ModuleList()
135
+ self.n_decoders = n_decoders
136
+ for i in range(self.n_decoders):
137
+ out_channels = in_channels // 2
138
+ self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
139
+ in_channels = out_channels
140
+
141
+ def forward(self, x, concat_tensors):
142
+ for i in range(self.n_decoders):
143
+ x = self.layers[i](x, concat_tensors[-1-i])
144
+ return x
145
+
146
+
147
+ class TimbreFilter(nn.Module):
148
+ def __init__(self, latent_rep_channels):
149
+ super(TimbreFilter, self).__init__()
150
+ self.layers = nn.ModuleList()
151
+ for latent_rep in latent_rep_channels:
152
+ self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0]))
153
+
154
+ def forward(self, x_tensors):
155
+ out_tensors = []
156
+ for i, layer in enumerate(self.layers):
157
+ out_tensors.append(layer(x_tensors[i]))
158
+ return out_tensors
159
+
160
+
161
+ class DeepUnet(nn.Module):
162
+ def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
163
+ super(DeepUnet, self).__init__()
164
+ self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
165
+ self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
166
+ self.tf = TimbreFilter(self.encoder.latent_channels)
167
+ self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
168
+
169
+ def forward(self, x):
170
+ x, concat_tensors = self.encoder(x)
171
+ x = self.intermediate(x)
172
+ concat_tensors = self.tf(concat_tensors)
173
+ x = self.decoder(x, concat_tensors)
174
+ return x
175
+
176
+
177
+ class DeepUnet0(nn.Module):
178
+ def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
179
+ super(DeepUnet0, self).__init__()
180
+ self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
181
+ self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
182
+ self.tf = TimbreFilter(self.encoder.latent_channels)
183
+ self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
184
+
185
+ def forward(self, x):
186
+ x, concat_tensors = self.encoder(x)
187
+ x = self.intermediate(x)
188
+ x = self.decoder(x, concat_tensors)
189
+ return x
rift_svc/rmvpe/inference.py ADDED
@@ -0,0 +1,51 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchaudio.transforms import Resample
5
+ from .constants import *
6
+ from .model import E2E0, E2E
7
+ from .spec import MelSpectrogram
8
+ from .utils import to_local_average_f0, to_viterbi_f0
9
+
10
+ class RMVPE:
11
+ def __init__(self, model_path, hop_length=160, device='cpu'):
12
+ self.resample_kernel = {}
13
+ model = E2E0(4, 1, (2, 2))
14
+ ckpt = torch.load(model_path, weights_only=True)
15
+ model.load_state_dict(ckpt['model'], strict=False)
16
+ model.eval()
17
+ self.model = model
18
+ self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX)
19
+ self.resample_kernel = {}
20
+ self.model = self.model.to(device)
21
+ self.mel_extractor = self.mel_extractor.to(device)
22
+
23
+ def mel2hidden(self, mel):
24
+ with torch.no_grad():
25
+ n_frames = mel.shape[-1]
26
+ mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant')
27
+ hidden = self.model(mel)
28
+ return hidden[:, :n_frames]
29
+
30
+ def decode(self, hidden, thred=0.03, use_viterbi=False):
31
+ if use_viterbi:
32
+ f0 = to_viterbi_f0(hidden, thred=thred)
33
+ else:
34
+ f0 = to_local_average_f0(hidden, thred=thred)
35
+ return f0
36
+
37
+ def infer_from_audio(self, audio, sample_rate=16000, device=None, thred=0.03, use_viterbi=False):
38
+ #audio = torch.from_numpy(audio).float().unsqueeze(0).to(device)
39
+ if sample_rate == 16000:
40
+ audio_res = audio
41
+ else:
42
+ key_str = str(sample_rate)
43
+ if key_str not in self.resample_kernel:
44
+ self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
45
+ self.resample_kernel[key_str] = self.resample_kernel[key_str].to(device)
46
+ audio_res = self.resample_kernel[key_str](audio)
47
+
48
+ mel = self.mel_extractor(audio_res, center=True)
49
+ hidden = self.mel2hidden(mel)
50
+ f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi)
51
+ return f0
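
A short usage sketch for the RMVPE wrapper above. The checkpoint path is an example, and wav is assumed to be a float32 numpy waveform.

rmvpe = RMVPE('pretrained/rmvpe/model.pt', hop_length=160, device='cuda')
audio = torch.from_numpy(wav).float().unsqueeze(0).to('cuda')  # [1, n_samples]
f0 = rmvpe.infer_from_audio(audio, sample_rate=44100, device='cuda', thred=0.03)  # per-frame f0 (numpy)
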
rift_svc/rmvpe/model.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+ from torch import nn
3
+ from .deepunet import DeepUnet, DeepUnet0
4
+ from .constants import *
5
+ from .spec import MelSpectrogram
6
+ from .seq import BiGRU
7
+
8
+
9
+ class E2E(nn.Module):
10
+ def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
11
+ en_out_channels=16):
12
+ super(E2E, self).__init__()
13
+ self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
14
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
15
+ if n_gru:
16
+ self.fc = nn.Sequential(
17
+ BiGRU(3 * N_MELS, 256, n_gru),
18
+ nn.Linear(512, N_CLASS),
19
+ nn.Dropout(0.25),
20
+ nn.Sigmoid()
21
+ )
22
+ else:
23
+ self.fc = nn.Sequential(
24
+ nn.Linear(3 * N_MELS, N_CLASS),
25
+ nn.Dropout(0.25),
26
+ nn.Sigmoid()
27
+ )
28
+
29
+ def forward(self, mel):
30
+ mel = mel.transpose(-1, -2).unsqueeze(1)
31
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
32
+ x = self.fc(x)
33
+ return x
34
+
35
+
36
+ class E2E0(nn.Module):
37
+ def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
38
+ en_out_channels=16):
39
+ super(E2E0, self).__init__()
40
+ self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
41
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
42
+ if n_gru:
43
+ self.fc = nn.Sequential(
44
+ BiGRU(3 * N_MELS, 256, n_gru),
45
+ nn.Linear(512, N_CLASS),
46
+ nn.Dropout(0.25),
47
+ nn.Sigmoid()
48
+ )
49
+ else:
50
+ self.fc = nn.Sequential(
51
+ nn.Linear(3 * N_MELS, N_CLASS),
52
+ nn.Dropout(0.25),
53
+ nn.Sigmoid()
54
+ )
55
+
56
+ def forward(self, mel):
57
+ mel = mel.transpose(-1, -2).unsqueeze(1)
58
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
59
+ x = self.fc(x)
60
+ return x
rift_svc/rmvpe/seq.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ class BiGRU(nn.Module):
5
+ def __init__(self, input_features, hidden_features, num_layers):
6
+ super(BiGRU, self).__init__()
7
+ self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
8
+
9
+ def forward(self, x):
10
+ return self.gru(x)[0]
11
+
12
+
13
+ class BiLSTM(nn.Module):
14
+ def __init__(self, input_features, hidden_features, num_layers):
15
+ super(BiLSTM, self).__init__()
16
+ self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
17
+
18
+ def forward(self, x):
19
+ return self.lstm(x)[0]
20
+
rift_svc/rmvpe/spec.py ADDED
@@ -0,0 +1,66 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from librosa.filters import mel
5
+
6
+ class MelSpectrogram(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ n_mel_channels,
10
+ sampling_rate,
11
+ win_length,
12
+ hop_length,
13
+ n_fft=None,
14
+ mel_fmin=0,
15
+ mel_fmax=None,
16
+ clamp = 1e-5
17
+ ):
18
+ super().__init__()
19
+ n_fft = win_length if n_fft is None else n_fft
20
+ self.hann_window = {}
21
+ mel_basis = mel(
22
+ sr=sampling_rate,
23
+ n_fft=n_fft,
24
+ n_mels=n_mel_channels,
25
+ fmin=mel_fmin,
26
+ fmax=mel_fmax,
27
+ htk=True)
28
+ mel_basis = torch.from_numpy(mel_basis).float()
29
+ self.register_buffer("mel_basis", mel_basis)
30
+ self.n_fft = win_length if n_fft is None else n_fft
31
+ self.hop_length = hop_length
32
+ self.win_length = win_length
33
+ self.sampling_rate = sampling_rate
34
+ self.n_mel_channels = n_mel_channels
35
+ self.clamp = clamp
36
+
37
+ def forward(self, audio, keyshift=0, speed=1, center=True):
38
+ factor = 2 ** (keyshift / 12)
39
+ n_fft_new = int(np.round(self.n_fft * factor))
40
+ win_length_new = int(np.round(self.win_length * factor))
41
+ hop_length_new = int(np.round(self.hop_length * speed))
42
+
43
+ keyshift_key = str(keyshift)+'_'+str(audio.device)
44
+ if keyshift_key not in self.hann_window:
45
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
46
+
47
+ fft = torch.stft(
48
+ audio,
49
+ n_fft=n_fft_new,
50
+ hop_length=hop_length_new,
51
+ win_length=win_length_new,
52
+ window=self.hann_window[keyshift_key],
53
+ center=center,
54
+ return_complex=True)
55
+ magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
56
+
57
+ if keyshift != 0:
58
+ size = self.n_fft // 2 + 1
59
+ resize = magnitude.size(1)
60
+ if resize < size:
61
+ magnitude = F.pad(magnitude, (0, 0, 0, size-resize))
62
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
63
+
64
+ mel_output = torch.matmul(self.mel_basis, magnitude)
65
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
66
+ return log_mel_spec
rift_svc/rmvpe/utils.py ADDED
@@ -0,0 +1,142 @@
1
+ import sys
2
+ import numpy as np
3
+ import librosa
4
+ import torch
5
+ from functools import reduce
6
+ from .constants import *
7
+ from torch.nn.modules.module import _addindent
8
+
9
+
10
+ def cycle(iterable):
11
+ while True:
12
+ for item in iterable:
13
+ yield item
14
+
15
+
16
+ def summary(model, file=sys.stdout):
17
+ def repr(model):
18
+ # We treat the extra repr like the sub-module, one item per line
19
+ extra_lines = []
20
+ extra_repr = model.extra_repr()
21
+ # empty string will be split into list ['']
22
+ if extra_repr:
23
+ extra_lines = extra_repr.split('\n')
24
+ child_lines = []
25
+ total_params = 0
26
+ for key, module in model._modules.items():
27
+ mod_str, num_params = repr(module)
28
+ mod_str = _addindent(mod_str, 2)
29
+ child_lines.append('(' + key + '): ' + mod_str)
30
+ total_params += num_params
31
+ lines = extra_lines + child_lines
32
+
33
+ for name, p in model._parameters.items():
34
+ if hasattr(p, 'shape'):
35
+ total_params += reduce(lambda x, y: x * y, p.shape)
36
+
37
+ main_str = model._get_name() + '('
38
+ if lines:
39
+ # simple one-liner info, which most builtin Modules will use
40
+ if len(extra_lines) == 1 and not child_lines:
41
+ main_str += extra_lines[0]
42
+ else:
43
+ main_str += '\n ' + '\n '.join(lines) + '\n'
44
+
45
+ main_str += ')'
46
+ if file is sys.stdout:
47
+ main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
48
+ else:
49
+ main_str += ', {:,} params'.format(total_params)
50
+ return main_str, total_params
51
+
52
+ string, count = repr(model)
53
+ if file is not None:
54
+ if isinstance(file, str):
55
+ file = open(file, 'w')
56
+ print(string, file=file)
57
+ file.flush()
58
+
59
+ return count
60
+
61
+
62
+ def to_local_average_cents(salience, center=None, thred=0.03):
63
+ """
64
+ find the weighted average cents near the argmax bin
65
+ """
66
+
67
+ if not hasattr(to_local_average_cents, 'cents_mapping'):
68
+ # the bin number-to-cents mapping
69
+ to_local_average_cents.cents_mapping = (
70
+ 20 * np.arange(N_CLASS) + CONST)
71
+
72
+ if salience.ndim == 1:
73
+ if center is None:
74
+ center = int(np.argmax(salience))
75
+ start = max(0, center - 4)
76
+ end = min(len(salience), center + 5)
77
+ salience = salience[start:end]
78
+ product_sum = np.sum(
79
+ salience * to_local_average_cents.cents_mapping[start:end])
80
+ weight_sum = np.sum(salience)
81
+ return product_sum / weight_sum if np.max(salience) > thred else 0
82
+ if salience.ndim == 2:
83
+ return np.array([to_local_average_cents(salience[i, :], None, thred) for i in
84
+ range(salience.shape[0])])
85
+
86
+ raise Exception("label should be either 1d or 2d ndarray")
87
+
88
+ def to_viterbi_cents(salience, thred=0.03):
89
+ # Create viterbi transition matrix
90
+ if not hasattr(to_viterbi_cents, 'transition'):
91
+ xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS))
92
+ transition = np.maximum(30 - abs(xx - yy), 0)
93
+ transition = transition / transition.sum(axis=1, keepdims=True)
94
+ to_viterbi_cents.transition = transition
95
+
96
+ # Convert to probability
97
+ prob = salience.T
98
+ prob = prob / prob.sum(axis=0)
99
+
100
+ # Perform viterbi decoding
101
+ path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64)
102
+
103
+ return np.array([to_local_average_cents(salience[i, :], path[i], thred) for i in
104
+ range(len(path))])
105
+
106
+ def to_local_average_f0(hidden, center=None, thred=0.03):
107
+ idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N]
108
+ idx_cents = idx * 20 + CONST # [B=1, N]
109
+ if center is None:
110
+ center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1]
111
+ start = torch.clip(center - 4, min=0) # [B, T, 1]
112
+ end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1]
113
+ idx_mask = (idx >= start) & (idx < end) # [B, T, N]
114
+ weights = hidden * idx_mask # [B, T, N]
115
+ product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T]
116
+ weight_sum = torch.sum(weights, dim=2) # [B, T]
117
+ cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T]
118
+ f0 = 10 * 2 ** (cents / 1200)
119
+ uv = hidden.max(dim=2)[0] < thred # [B, T]
120
+ f0 = f0 * ~uv
121
+ return f0.squeeze(0).cpu().numpy()
122
+
123
+ def to_viterbi_f0(hidden, thred=0.03):
124
+ # Create viterbi transition matrix
125
+ if not hasattr(to_viterbi_cents, 'transition'):
126
+ xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS))
127
+ transition = np.maximum(30 - abs(xx - yy), 0)
128
+ transition = transition / transition.sum(axis=1, keepdims=True)
129
+ to_viterbi_cents.transition = transition
130
+
131
+ # Convert to probability
132
+ prob = hidden.squeeze(0).cpu().numpy()
133
+ prob = prob.T
134
+ prob = prob / prob.sum(axis=0)
135
+
136
+ # Perform viterbi decoding
137
+ path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64)
138
+ center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device)
139
+
140
+ return to_local_average_f0(hidden, center=center, thred=thred)
141
+
142
+
rift_svc/utils.py ADDED
@@ -0,0 +1,364 @@
1
+ import io
2
+ import os
3
+ import random
4
+ import time
5
+ from typing import Any
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from jaxtyping import Bool, Int
12
+ from PIL import Image
13
+ from pytorch_lightning.callbacks import TQDMProgressBar
14
+ import parselmouth as pm
15
+ import librosa
16
+ import pyworld as pw
17
+
18
+
19
+ def seed_everything(seed: int = 0):
20
+ random.seed(seed)
21
+ os.environ['PYTHONHASHSEED'] = str(seed)
22
+ torch.manual_seed(seed)
23
+ torch.cuda.manual_seed(seed)
24
+ torch.cuda.manual_seed_all(seed)
25
+ torch.backends.cudnn.deterministic = True
26
+ torch.backends.cudnn.benchmark = False
27
+
28
+ # helpers
29
+
30
+ def exists(v: Any) -> bool:
31
+ return v is not None
32
+
33
+ def default(v: Any, d: Any) -> Any:
34
+ return v if exists(v) else d
35
+
36
+
37
+ def draw_mel_specs(gt: np.ndarray, gen: np.ndarray, diff: np.ndarray, cache_path: str):
38
+ vmin = min(gt.min(), gen.min())
39
+ vmax = max(gt.max(), gen.max())
40
+
41
+ # Create figure with space for colorbar
42
+ fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 15), sharex=True, gridspec_kw={'hspace': 0})
43
+
44
+ # Plot all spectrograms with the same scale
45
+ im1 = ax1.imshow(gt, origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
46
+ ax1.set_ylabel('GT', fontsize=14)
47
+ ax1.set_xticks([])
48
+
49
+ im2 = ax2.imshow(gen, origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
50
+ ax2.set_ylabel('Gen', fontsize=14)
51
+ ax2.set_xticks([])
52
+
53
+ # Find symmetric limits for difference plot
54
+ diff_abs_max = max(abs(diff.min()), abs(diff.max()))
55
+
56
+ im3 = ax3.imshow(diff, origin='lower', aspect='auto',
57
+ cmap='RdBu_r', # Red-White-Blue colormap (reversed)
58
+ vmin=-diff_abs_max, vmax=diff_abs_max)
59
+ ax3.set_ylabel('Diff', fontsize=14)
60
+
61
+ fig.colorbar(im1, ax=[ax1, ax2], location='right', label='Magnitude')
62
+ fig.colorbar(im3, ax=[ax3], location='right', label='Difference')
63
+
64
+ buf = io.BytesIO()
65
+ plt.savefig(buf, format='png', bbox_inches='tight')
66
+ plt.close()
67
+ buf.seek(0)
68
+
69
+ # Open with PIL and save as compressed JPEG
70
+ img = Image.open(buf)
71
+ img = img.convert('RGB')
72
+ img.save(cache_path, 'JPEG', quality=85, optimize=True)
73
+ buf.close()
74
+
75
+
76
+
77
+ # tensor helpers
78
+
79
+ def lens_to_mask(
80
+ t: Int[torch.Tensor, "b"],
81
+ length: int | None = None
82
+ ) -> Bool[torch.Tensor, "b n"]:
83
+
84
+ if not exists(length):
85
+ length = t.amax()
86
+
87
+ seq = torch.arange(length, device = t.device)
88
+ return seq < t[..., None]
89
+
90
+
91
+ def l2_grad_norm(model: torch.nn.Module):
92
+ return torch.cat([p.grad.data.flatten() for p in model.parameters() if p.grad is not None]).norm(2)
93
+
94
+
95
+ def nearest_interpolate_tensor(tensor, new_size):
96
+ # Add two dummy dimensions to make it [1, 1, n, d]
97
+ tensor = tensor.unsqueeze(0).unsqueeze(0)
98
+
99
+ # Interpolate
100
+ interpolated = F.interpolate(tensor, size=(new_size, tensor.shape[-1]), mode='nearest')
101
+
102
+ # Remove the dummy dimensions
103
+ interpolated = interpolated.squeeze(0).squeeze(0)
104
+
105
+ return interpolated
106
+
107
+
108
+ def linear_interpolate_tensor(tensor, new_size):
109
+ # Assumes input tensor shape is [n, d]
110
+ # Rearrange tensor to shape [1, d, n] to prepare for linear interpolation
111
+ tensor = tensor.transpose(0, 1).unsqueeze(0)
112
+
113
+ # Interpolate along the length dimension (last dimension) using linear interpolation.
114
+ # align_corners=True preserves the boundary values; adjust this flag if needed.
115
+ interpolated = F.interpolate(tensor, size=new_size, mode='linear', align_corners=True)
116
+
117
+ # Restore the tensor to shape [new_size, d]
118
+ return interpolated.squeeze(0).transpose(0, 1)
119
+
120
+
121
+ # f0 helpers
122
+
123
+ def post_process_f0(f0, sample_rate, hop_length, n_frames, silence_front=0.0, cut_last=True):
124
+ """
125
+ Post-process the extracted f0 to align with Mel spectrogram frames.
126
+
127
+ Args:
128
+ f0 (numpy.ndarray): Extracted f0 array.
129
+ sample_rate (int): Sample rate of the audio.
130
+ hop_length (int): Hop length used during processing.
131
+ n_frames (int): Total number of frames (for alignment).
132
+ silence_front (float): Seconds of silence to remove from the front.
133
+
134
+ Returns:
135
+ numpy.ndarray: Processed f0 array aligned with Mel spectrogram frames.
136
+ """
137
+ # Calculate number of frames to skip based on silence_front
138
+ start_frame = int(silence_front * sample_rate / hop_length)
139
+ real_silence_front = start_frame * hop_length / sample_rate
140
+ # Assuming silence_front has been handled during RMVPE inference if needed
141
+
142
+ # Handle unvoiced frames by interpolation
143
+ uv = f0 == 0
144
+ if np.any(~uv):
145
+ f0_interp = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
146
+ f0[uv] = f0_interp
147
+ else:
148
+ # If no voiced frames, set all to zero
149
+ f0 = np.zeros_like(f0)
150
+
151
+ # Align with hop_length frames
152
+ origin_time = 0.01 * np.arange(len(f0)) # Placeholder: Adjust based on RMVPE's timing
153
+ target_time = hop_length / sample_rate * np.arange(n_frames - start_frame)
154
+ f0 = np.interp(target_time, origin_time, f0)
155
+ uv = np.interp(target_time, origin_time, uv.astype(float)) > 0.5
156
+ f0[uv] = 0
157
+
158
+ # Pad the silence_front if needed
159
+ f0 = np.pad(f0, (start_frame, 0), mode='constant')
160
+
161
+ if cut_last:
162
+ return f0[:-1]
163
+ else:
164
+ return f0
165
+
166
+ # pyworld
167
+ def get_f0_pw(audio, sr, time_step, f0_min, f0_max):
168
+ pw_pre_f0, times = pw.dio(
169
+ audio.astype(np.double), sr,
170
+ f0_floor=f0_min, f0_ceil=f0_max,
171
+ frame_period=time_step*1000) # raw pitch extractor
172
+ pw_post_f0 = pw.stonemask(audio.astype(np.double), pw_pre_f0, times, sr) # pitch refinement
173
+ pw_post_f0[pw_post_f0==0] = np.nan
174
+ pw_post_f0 = slide_nanmedian(pw_post_f0, 3)
175
+ return pw_post_f0
176
+
177
+ # parselmouth
178
+ def get_f0_pm(audio, sr, time_step, f0_min, f0_max):
179
+ pmac_pitch = pm.Sound(audio, sampling_frequency=sr).to_pitch_ac(
180
+ time_step=time_step, voicing_threshold=0.6,
181
+ pitch_floor=f0_min, pitch_ceiling=f0_max,
182
+ very_accurate=True, octave_jump_cost=0.5)
183
+ pmac_f0 = pmac_pitch.selected_array['frequency']
184
+ pmac_f0[pmac_f0==0] = np.nan
185
+ pmac_f0 = slide_nanmedian(pmac_f0, 3)
186
+ return pmac_f0
187
+
188
+ from numba import njit
189
+ @njit
190
+ def slide_nanmedian(signals=np.array([]), win_length=3):
191
+ """Filters a sequence, ignoring nan values
192
+
193
+ Arguments
194
+ signals (numpy.ndarray (shape=(time)))
195
+ The signals to filter
196
+ win_length
197
+ The size of the analysis window
198
+
199
+ Returns
200
+ filtered (numpy.ndarray (shape=(time)))
201
+ """
202
+ # Output buffer
203
+ filtered = np.empty_like(signals)
204
+
205
+ # Loop over frames
206
+ for i in range(signals.shape[0]):
207
+
208
+ # Get analysis window bounds
209
+ start = max(0, i - win_length // 2)
210
+ end = min(signals.shape[0], i + win_length // 2 + 1)
211
+
212
+ # Apply filter to window
213
+ filtered[i] = np.nanmedian(signals[start:end])
214
+
215
+ return filtered
216
+
217
+
218
+ def f0_ensemble(rmvpe_f0, pw_f0, pmac_f0):
219
+ trunc_len = len(rmvpe_f0)
220
+ pw_f0 = pw_f0[:trunc_len]
221
+ # pad pmac_f0
222
+ pmac_f0 = np.concatenate(
223
+ [pmac_f0, np.full(len(pw_f0)-len(pmac_f0), np.nan, dtype=pmac_f0.dtype)])
224
+
225
+ stack_f0 = np.stack([pw_f0, pmac_f0, rmvpe_f0], axis=0)
226
+
227
+ meadian_f0 = np.nanmedian(stack_f0, axis=0)
228
+ nan_nums = np.sum(np.isnan(stack_f0), axis=0)
229
+ meadian_f0[nan_nums>=2] = np.nan
230
+
231
+ slide_meadian_f0 = slide_nanmedian(meadian_f0, 41)
232
+
233
+ f0_dev = np.abs(meadian_f0-slide_meadian_f0)
234
+ meadian_f0[f0_dev>96] = slide_meadian_f0[f0_dev>96]
235
+
236
+ nan1_f0_min = np.nanmin(stack_f0[:, nan_nums==1], axis=0)
237
+ nan1_f0_max = np.nanmax(stack_f0[:, nan_nums==1], axis=0)
238
+
239
+ nan1_f0 = np.where(
240
+ np.abs(nan1_f0_min-slide_meadian_f0[nan_nums==1])<np.abs(nan1_f0_max-slide_meadian_f0[nan_nums==1]),
241
+ nan1_f0_min, nan1_f0_max)
242
+ meadian_f0[nan_nums==1] = nan1_f0
243
+
244
+ meadian_f0 = slide_nanmedian(meadian_f0, 3)
245
+ meadian_f0[nan_nums>=2] = np.nan
246
+ meadian_f0[np.isnan(meadian_f0)] = 0
247
+
248
+ return meadian_f0
249
+
250
+
251
+ def f0_ensemble_light(rmvpe_f0, pw_f0, pmac_f0, rms=None, rms_threshold=0.05):
252
+ """
253
+ A lighter version of f0 ensemble that preserves RMVPE's expressiveness.
254
+ Only applies corrections when RMVPE shows abnormalities.
255
+
256
+ Args:
257
+ rmvpe_f0 (numpy.ndarray): F0 from RMVPE
258
+ pw_f0 (numpy.ndarray): F0 from WORLD
259
+ pmac_f0 (numpy.ndarray): F0 from Parselmouth
260
+ rms (numpy.ndarray, optional): RMS energy values, used to detect voiced segments
261
+ rms_threshold (float, optional): Threshold for RMS to consider a segment as voiced
262
+
263
+ Returns:
264
+ numpy.ndarray: Corrected F0 values
265
+ """
266
+ trunc_len = len(rmvpe_f0)
267
+ pw_f0 = pw_f0[:trunc_len]
268
+
269
+ # Pad pmac_f0 if needed
270
+ pmac_f0 = np.concatenate(
271
+ [pmac_f0, np.full(max(0, len(pw_f0)-len(pmac_f0)), np.nan, dtype=pmac_f0.dtype)])
272
+
273
+ # Create a copy of rmvpe_f0 to preserve most of its values
274
+ corrected_f0 = rmvpe_f0.copy()
275
+
276
+ # Stack all F0 values
277
+ stack_f0 = np.stack([pw_f0, pmac_f0, rmvpe_f0], axis=0)
278
+
279
+ # Count non-NaN values for each frame
280
+ valid_count = np.sum(~np.isnan(stack_f0), axis=0)
281
+
282
+ # Identify frames where RMVPE shows zero but other methods detect pitch
283
+ zero_rmvpe_mask = (rmvpe_f0 == 0)
284
+
285
+ # For frames where RMVPE is zero but at least one other method has a valid F0
286
+ # and there's voice activity (if RMS is provided)
287
+ other_methods_valid = ((~np.isnan(pw_f0) & (pw_f0 > 0)) |
288
+ (~np.isnan(pmac_f0) & (pmac_f0 > 0)))
289
+
290
+ correction_mask = zero_rmvpe_mask & other_methods_valid
291
+
292
+ # If RMS is provided, only correct frames with voice activity
293
+ if rms is not None:
294
+ voice_activity = rms > rms_threshold
295
+ correction_mask = correction_mask & voice_activity
296
+
297
+ # For frames needing correction, use median of available values
298
+ if np.any(correction_mask):
299
+ # For each frame needing correction, calculate median of non-NaN values
300
+ for i in np.where(correction_mask)[0]:
301
+ valid_values = stack_f0[:, i][~np.isnan(stack_f0[:, i]) & (stack_f0[:, i] > 0)]
302
+ if len(valid_values) > 0:
303
+ corrected_f0[i] = np.median(valid_values)
304
+
305
+ # Handle any remaining NaN values
306
+ corrected_f0[np.isnan(corrected_f0)] = 0
307
+
308
+ return corrected_f0
309
+
310
+
311
+ # progress bar helper
312
+
313
+ class CustomProgressBar(TQDMProgressBar):
314
+ def __init__(self):
315
+ super().__init__()
316
+ self.start_time = None
317
+ self.step_start_time = None
318
+ self.total_steps = None
319
+
320
+ def on_train_start(self, trainer, pl_module):
321
+ super().on_train_start(trainer, pl_module)
322
+ self.start_time = time.time()
323
+ self.total_steps = trainer.max_steps
324
+
325
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
326
+ super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
327
+
328
+ current_step = trainer.global_step
329
+ total_steps = self.total_steps
330
+
331
+ # Calculate elapsed time since training started
332
+ elapsed_time = time.time() - self.start_time
333
+
334
+ # Estimate average step time and remaining time
335
+ average_step_time = elapsed_time / current_step if current_step > 0 else 0
336
+ remaining_steps = total_steps - current_step
337
+ remaining_time = average_step_time * remaining_steps if total_steps > 0 else 0
338
+
339
+ # Format times with no leading zeros for hours
340
+ def format_time(seconds):
341
+ hours = int(seconds // 3600)
342
+ minutes = int((seconds % 3600) // 60)
343
+ seconds = int(seconds % 60)
344
+ return f"{hours}:{minutes:02d}:{seconds:02d}"
345
+
346
+ elapsed_time_str = format_time(elapsed_time)
347
+ remaining_time_str = format_time(remaining_time)
348
+
349
+ # Update the progress bar with loss, elapsed time, remaining time, and remaining steps
350
+ self.train_progress_bar.set_postfix({
351
+ "loss": f"{outputs['loss'].item():.4f}",
352
+ "elapsed_time": elapsed_time_str + "/" + remaining_time_str,
353
+ "remaining_steps": str(remaining_steps) + "/" + str(total_steps)
354
+ })
355
+
356
+
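As a quick sanity check on the arithmetic above (hypothetical numbers): with trainer.max_steps = 100000 and global_step = 2000 reached after 1000 s of training, average_step_time is 0.5 s, so the postfix would read elapsed_time as 0:16:40/13:36:40 and remaining_steps as 98000/100000.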
357
+ # state dict helpers
358
+
359
+ def load_state_dict(model, state_dict, strict=False):
360
+ """Load state dict while handling 'model.' prefix"""
361
+ if any(k.startswith('model.') for k in state_dict.keys()):
362
+ # Remove 'model.' prefix
363
+ state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
364
+ return model.load_state_dict(state_dict, strict=strict)
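With the load_state_dict helper above in scope, a minimal usage sketch (the Linear layers are placeholders for the real network; the prefixed dict mimics how a Lightning checkpoint looks when the network is wrapped as `self.model`):

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
# Simulate a Lightning-style state dict in which every key carries a leading "model." prefix.
prefixed = {f"model.{k}": v for k, v in net.state_dict().items()}

fresh = nn.Linear(4, 2)
load_state_dict(fresh, prefixed)             # the helper strips the prefix before loading
assert torch.equal(fresh.weight, net.weight)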
slicer.py ADDED
@@ -0,0 +1,252 @@
1
+ import logging
2
+ import warnings
3
+
4
+ import librosa
5
+
6
+ warnings.filterwarnings('ignore')
7
+
8
+ # Configure module-level logging
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class Slicer:
14
+ def __init__(self,
15
+ sr: int,
16
+ threshold: float = -30.,
17
+ min_length: int = 3000,
18
+ min_interval: int = 100,
19
+ hop_size: int = 20,
20
+ max_sil_kept: int = 5000):
21
+ if not min_length >= min_interval >= hop_size:
22
+ raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
23
+ if not max_sil_kept >= hop_size:
24
+ raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
25
+ min_interval = sr * min_interval / 1000
26
+ self.sr = sr
27
+ self.threshold = 10 ** (threshold / 20.)
28
+ self.hop_size = round(sr * hop_size / 1000)
29
+ self.win_size = min(round(min_interval), 4 * self.hop_size)
30
+ self.min_length = round(sr * min_length / 1000 / self.hop_size)
31
+ self.min_interval = round(min_interval / self.hop_size)
32
+ self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
33
+
34
+
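For orientation (illustrative figures, not part of the original file): with the defaults above at sr = 44100, the -30 dB threshold becomes a linear RMS of 10 ** (-30 / 20) ≈ 0.0316, hop_size = 20 ms is 882 samples per frame, min_interval = 100 ms is 5 frames, min_length = 3000 ms is 150 frames, and max_sil_kept = 5000 ms is 250 frames.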
35
+ def _apply_slice(self, waveform, begin, end):
36
+ if len(waveform.shape) > 1:
37
+ return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
38
+ else:
39
+ return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
40
+
41
+
42
+ def slice(self, waveform):
43
+ if len(waveform.shape) > 1:
44
+ samples = librosa.to_mono(waveform)
45
+ else:
46
+ samples = waveform
47
+ if samples.shape[0] <= self.min_length * self.hop_size:  # min_length is in frames, so convert back to samples
48
+ # Return the entire audio as a single chunk
49
+ return [(0, waveform)]
50
+
51
+ rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
52
+ sil_tags = []
53
+ silence_start = None
54
+ clip_start = 0
55
+ for i, rms in enumerate(rms_list):
56
+ # Keep looping while frame is silent.
57
+ if rms < self.threshold:
58
+ # Record start of silent frames.
59
+ if silence_start is None:
60
+ silence_start = i
61
+ continue
62
+ # Keep looping while frame is not silent and silence start has not been recorded.
63
+ if silence_start is None:
64
+ continue
65
+ # Clear recorded silence start if interval is not enough or clip is too short
66
+ is_leading_silence = silence_start == 0 and i > self.max_sil_kept
67
+ need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
68
+ if not is_leading_silence and not need_slice_middle:
69
+ silence_start = None
70
+ continue
71
+ # Need slicing. Record the range of silent frames to be removed.
72
+ if i - silence_start <= self.max_sil_kept:
73
+ pos = rms_list[silence_start: i + 1].argmin() + silence_start
74
+ if silence_start == 0:
75
+ sil_tags.append((0, pos))
76
+ else:
77
+ sil_tags.append((pos, pos))
78
+ clip_start = pos
79
+ elif i - silence_start <= self.max_sil_kept * 2:
80
+ pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
81
+ pos += i - self.max_sil_kept
82
+ pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
83
+ pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
84
+ if silence_start == 0:
85
+ sil_tags.append((0, pos_r))
86
+ clip_start = pos_r
87
+ else:
88
+ sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
89
+ clip_start = max(pos_r, pos)
90
+ else:
91
+ pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
92
+ pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
93
+ if silence_start == 0:
94
+ sil_tags.append((0, pos_r))
95
+ else:
96
+ sil_tags.append((pos_l, pos_r))
97
+ clip_start = pos_r
98
+ silence_start = None
99
+
100
+ # Deal with trailing silence.
101
+ total_frames = rms_list.shape[0]
102
+ if silence_start is not None and total_frames - silence_start >= self.min_interval:
103
+ silence_end = min(total_frames, silence_start + self.max_sil_kept)
104
+ pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
105
+ sil_tags.append((pos, total_frames + 1))
106
+
107
+ # Apply and return slices.
108
+ if len(sil_tags) == 0:
109
+ # Return the entire audio as a single chunk if no silence detected
110
+ return [(0, waveform)]
111
+
112
+ # Extract non-silence chunks
113
+ non_silence_chunks = []
114
+
115
+ # Add first non-silence chunk if it exists
116
+ if sil_tags[0][0] > 0:
117
+ start_pos = 0
118
+ end_frame = sil_tags[0][0]
119
+ chunk = self._apply_slice(waveform, 0, end_frame)
120
+ non_silence_chunks.append((start_pos, chunk))
121
+
122
+ # Add middle non-silence chunks
123
+ for i in range(1, len(sil_tags)):
124
+ start_frame = sil_tags[i-1][1]
125
+ end_frame = sil_tags[i][0]
126
+ if start_frame < end_frame: # Only add if there's actual non-silence content
127
+ start_pos = start_frame * self.hop_size
128
+ chunk = self._apply_slice(waveform, start_frame, end_frame)
129
+ non_silence_chunks.append((start_pos, chunk))
130
+
131
+ # Add last non-silence chunk if it exists
132
+ if sil_tags[-1][1] * self.hop_size < samples.shape[0]:  # compare against the sample count (len(waveform) is the channel count for stereo input)
133
+ start_frame = sil_tags[-1][1]
134
+ start_pos = start_frame * self.hop_size
135
+ chunk = self._apply_slice(waveform, start_frame, total_frames)
136
+ non_silence_chunks.append((start_pos, chunk))
137
+
138
+ for i, (start_pos, chunk) in enumerate(non_silence_chunks):
139
+ # Calculate start and end times in seconds
140
+ start_time_sec = start_pos / self.sr
141
+ end_time_sec = start_time_sec + (len(chunk) if len(chunk.shape) == 1 else chunk.shape[1]) / self.sr
142
+ duration_sec = end_time_sec - start_time_sec
143
+
144
+ # Format start and end times as mm:ss
145
+ start_min, start_sec = divmod(start_time_sec, 60)
146
+ end_min, end_sec = divmod(end_time_sec, 60)
147
+
148
+ # Log the information
149
+ logger.info(f"Chunk {i}: Start={int(start_min):02d}:{start_sec:05.2f}, End={int(end_min):02d}:{end_sec:05.2f}, Duration={duration_sec:.2f}s")
150
+
151
+ return non_silence_chunks
152
+
153
+
154
+
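A minimal sketch of driving the class directly from Python, assuming the module is importable as slicer and using a synthetic signal instead of a real recording (sample rate and tone layout are arbitrary):

import numpy as np
from slicer import Slicer

sr = 44100
t = lambda sec: np.arange(int(sec * sr)) / sr
# 3 s of a 440 Hz tone, 1.5 s of near-silence, then 3 s of tone again.
audio = np.concatenate([
    0.5 * np.sin(2 * np.pi * 440 * t(3.0)),
    0.0005 * np.random.randn(int(1.5 * sr)),
    0.5 * np.sin(2 * np.pi * 440 * t(3.0)),
]).astype(np.float32)

slicer_obj = Slicer(sr=sr, threshold=-30.0, min_length=2000, min_interval=300, hop_size=20, max_sil_kept=500)
for start_sample, chunk in slicer_obj.slice(audio):
    print(f"chunk at {start_sample / sr:.2f} s, {len(chunk) / sr:.2f} s long")
# Expected: one chunk per tone burst, with the long pause in between removed.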
155
+ def main():
156
+ import os.path
157
+ from argparse import ArgumentParser
158
+ import librosa
159
+ import soundfile
160
+ from pathlib import Path
161
+
162
+ parser = ArgumentParser()
163
+ parser.add_argument('audio', type=str, help='The audio file or directory to be sliced')
164
+ parser.add_argument('--out', type=str, help='Output directory of the sliced audio clips')
165
+ parser.add_argument('--db_thresh', type=float, required=False, default=-30,
166
+ help='The dB threshold for silence detection')
167
+ parser.add_argument('--min_length', type=int, required=False, default=3000,
168
+ help='The minimum milliseconds required for each sliced audio clip')
169
+ parser.add_argument('--min_interval', type=int, required=False, default=100,
170
+ help='The minimum milliseconds for a silence part to be sliced')
171
+ parser.add_argument('--hop_size', type=int, required=False, default=20,
172
+ help='Frame length in milliseconds')
173
+ parser.add_argument('--max_sil_kept', type=int, required=False, default=5000,
174
+ help='The maximum silence length kept around the sliced clip, presented in milliseconds')
175
+ args = parser.parse_args()
176
+
177
+ # Determine if the input is a file or directory
178
+ audio_path = Path(args.audio)
179
+ is_directory = audio_path.is_dir()
180
+
181
+ # Prepare output directory
182
+ out = args.out
183
+ if out is None:
184
+ if is_directory:
185
+ out = os.path.abspath(args.audio)
186
+ else:
187
+ out = os.path.dirname(os.path.abspath(args.audio))
188
+
189
+ if not os.path.exists(out):
190
+ os.makedirs(out)
191
+
192
+ # Audio file extensions to process
193
+ audio_extensions = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
194
+
195
+ # Process a single file or all files in a directory
196
+ if is_directory:
197
+ logger.info(f"Processing all audio files in directory: {args.audio}")
198
+ audio_files = []
199
+ for ext in audio_extensions:
200
+ audio_files.extend(list(audio_path.glob(f'*{ext}')))
201
+
202
+ if not audio_files:
203
+ logger.warning(f"No audio files found in {args.audio}")
204
+ return
205
+
206
+ logger.info(f"Found {len(audio_files)} audio files to process")
207
+ for audio_file in audio_files:
208
+ process_audio_file(audio_file, out, args)
209
+ else:
210
+ # Process a single audio file
211
+ logger.info(f"Processing single audio file: {args.audio}")
212
+ process_audio_file(audio_path, out, args)
213
+
214
+ def process_audio_file(audio_file, out_dir, args):
215
+ """Process a single audio file with the given parameters"""
216
+ import os.path
217
+ import librosa
218
+ import soundfile
219
+
220
+ try:
221
+ logger.info(f"Loading audio file: {audio_file}")
222
+ audio, sr = librosa.load(str(audio_file), sr=None, mono=False)
223
+
224
+ slicer = Slicer(
225
+ sr=sr,
226
+ threshold=args.db_thresh,
227
+ min_length=args.min_length,
228
+ min_interval=args.min_interval,
229
+ hop_size=args.hop_size,
230
+ max_sil_kept=args.max_sil_kept
231
+ )
232
+
233
+ # Get non-silence chunks with their positions
234
+ chunks_with_pos = slicer.slice(audio)
235
+
236
+ file_basename = os.path.basename(str(audio_file)).rsplit('.', maxsplit=1)[0]
237
+ logger.info(f"Saving {len(chunks_with_pos)} non-silence audio chunks from {file_basename}...")
238
+
239
+ for i, (pos, chunk) in enumerate(chunks_with_pos):
240
+ if len(chunk.shape) > 1:
241
+ chunk = chunk.T
242
+
243
+ output_file = os.path.join(out_dir, f'{file_basename}_{i}_pos_{pos}.wav')
244
+ soundfile.write(output_file, chunk, sr)
245
+
246
+ logger.info(f"Finished processing {audio_file}")
247
+ except Exception as e:
248
+ logger.error(f"Error processing {audio_file}: {str(e)}")
249
+
250
+
251
+ if __name__ == '__main__':
252
+ main()
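For completeness, the entry point above is driven from the command line roughly as follows (paths are placeholders; only the positional audio argument is required, and the remaining flags fall back to the argparse defaults shown above): python slicer.py vocals.wav --out sliced/ --db_thresh -40 --min_interval 300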