Commit 898b100 · prismleong committed: init
Parent(s): 5c61d84
- README.md +3 -3
- app.py +398 -0
- infer.py +493 -0
- pretrained/content-vec-best/.gitattributes +34 -0
- pretrained/content-vec-best/.gitignore +1 -0
- pretrained/content-vec-best/README.md +33 -0
- pretrained/content-vec-best/config.json +71 -0
- pretrained/content-vec-best/convert.py +150 -0
- pretrained/download.py +12 -0
- pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/NOTICE.txt +87 -0
- pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/NOTICE.zh-CN.txt +85 -0
- pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/config.json +56 -0
- pretrained/rmvpe/.gitkeep +0 -0
- requirements.txt +28 -0
- rift_svc/__init__.py +3 -0
- rift_svc/dataset.py +139 -0
- rift_svc/dit.py +227 -0
- rift_svc/feature_extractors.py +144 -0
- rift_svc/lightning_module.py +389 -0
- rift_svc/metrics.py +71 -0
- rift_svc/modules.py +261 -0
- rift_svc/nsf_hifigan/__init__.py +2 -0
- rift_svc/nsf_hifigan/env.py +15 -0
- rift_svc/nsf_hifigan/models.py +427 -0
- rift_svc/nsf_hifigan/nvSTFT.py +124 -0
- rift_svc/nsf_hifigan/utils.py +67 -0
- rift_svc/nsf_hifigan/vocoder.py +123 -0
- rift_svc/optim.py +103 -0
- rift_svc/rf.py +215 -0
- rift_svc/rmvpe/__init__.py +5 -0
- rift_svc/rmvpe/constants.py +9 -0
- rift_svc/rmvpe/deepunet.py +189 -0
- rift_svc/rmvpe/inference.py +51 -0
- rift_svc/rmvpe/model.py +60 -0
- rift_svc/rmvpe/seq.py +20 -0
- rift_svc/rmvpe/spec.py +66 -0
- rift_svc/rmvpe/utils.py +142 -0
- rift_svc/utils.py +364 -0
- slicer.py +252 -0
README.md CHANGED

```diff
@@ -1,6 +1,6 @@
 ---
-title: RIFT
-emoji:
+title: RIFT-SVC (七海Nanami demo)
+emoji: 🎵
 colorFrom: red
 colorTo: yellow
 sdk: gradio
@@ -11,4 +11,4 @@ license: cc-by-nc-sa-4.0
 short_description: https://github.com/Pur1zumu/RIFT-SVC
 ---
 
-
+
```
app.py ADDED
@@ -0,0 +1,398 @@

```python
import numpy as np
import torch
import torchaudio
import gradio as gr
import tempfile
import gc
import traceback
from slicer import Slicer

from infer import (
    load_models,
    load_audio,
    apply_fade,
    process_segment
)

# Global variables for models
global svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg, device
svc_model = vocoder = rmvpe = hubert = rms_extractor = spk2idx = dataset_cfg = None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set default model path
DEFAULT_MODEL_PATH = "pretrained/dit-768-12_nanami.ckpt"

# Maximum audio duration in seconds to avoid memory issues
MAX_AUDIO_DURATION = 300  # 5 minutes

def initialize_models(model_path=DEFAULT_MODEL_PATH):
    global svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg

    # Clean up memory before loading models
    if svc_model is not None:
        del svc_model
        del vocoder
        del rmvpe
        del hubert
        del rms_extractor
        torch.cuda.empty_cache()
        gc.collect()

    try:
        svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg = load_models(model_path, device)
        available_speakers = list(spk2idx.keys())
        return available_speakers, f"✅ 模型加载成功!可用说话人: {', '.join(available_speakers)}"
    except Exception as e:
        error_trace = traceback.format_exc()
        return [], f"❌ 加载模型出错: {str(e)}\n\n详细信息: {error_trace}"

def check_audio_length(audio_path, max_duration=MAX_AUDIO_DURATION):
    """Check if audio file is too long to process safely"""
    try:
        info = torchaudio.info(audio_path)
        duration = info.num_frames / info.sample_rate
        return duration <= max_duration, duration
    except Exception:
        # If we can't determine the length, we'll try to process it anyway
        return True, 0

def process_with_progress(
    progress=gr.Progress(),
    input_audio=None,
    speaker=None,
    key_shift=0,
    infer_steps=32,
    robust_f0=0,
    # Advanced CFG parameters
    ds_cfg_strength=0.05,
    spk_cfg_strength=1.0,
    skip_cfg_strength=0.0,
    cfg_skip_layers=6,
    cfg_rescale=0.7,
    cvec_downsample_rate=2,
    # Slicer parameters
    slicer_threshold=-30.0,
    slicer_min_length=3000,
    slicer_min_interval=100,
    slicer_hop_size=10,
    slicer_max_sil_kept=200
):
    global svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg

    # Fixed target loudness value
    target_loudness = -18.0

    # Fixed audio parameters
    restore_loudness = True
    fade_duration = 20.0
    sliced_inference = False

    # Input validation
    if input_audio is None:
        return None, "❌ 错误: 未提供输入音频。"

    if svc_model is None:
        return None, "❌ 错误: 模型未加载。请重新加载页面或检查模型路径。"

    if speaker is None or speaker not in spk2idx:
        return None, f"❌ 错误: 无效的说话人选择。可用说话人: {', '.join(spk2idx.keys())}"

    # Check audio length to avoid memory issues
    is_safe_length, duration = check_audio_length(input_audio)
    if not is_safe_length:
        return None, f"❌ 错误: 音频过长 ({duration:.1f} 秒)。允许的最大时长为 {MAX_AUDIO_DURATION} 秒。"

    # Process the audio
    try:
        # Update status message
        progress(0, desc="处理中: 加载音频...")

        # Convert speaker name to ID
        speaker_id = spk2idx[speaker]

        # Get config from loaded model
        hop_length = 512
        sample_rate = 44100

        # Load audio
        audio = load_audio(input_audio, sample_rate)

        # Initialize Slicer
        slicer = Slicer(
            sr=sample_rate,
            threshold=slicer_threshold,
            min_length=slicer_min_length,
            min_interval=slicer_min_interval,
            hop_size=slicer_hop_size,
            max_sil_kept=slicer_max_sil_kept
        )

        progress(0.1, desc="处理中: 切分音频...")
        # Slice the input audio
        segments_with_pos = slicer.slice(audio)

        if not segments_with_pos:
            return None, "❌ 错误: 在输入文件中未找到有效的音频片段。"

        # Calculate fade size in samples
        fade_samples = int(fade_duration * sample_rate / 1000)

        # Process segments
        result_audio = np.zeros(len(audio) + fade_samples)  # Extra space for potential overlap

        progress(0.2, desc="处理中: 开始转换...")

        with torch.no_grad():
            for i, (start_sample, chunk) in enumerate(segments_with_pos):
                segment_progress = 0.2 + (0.7 * (i / len(segments_with_pos)))
                progress(segment_progress, desc=f"处理中: 片段 {i+1}/{len(segments_with_pos)}")

                # Process the segment
                audio_out = process_segment(
                    chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
                    speaker_id, sample_rate, hop_length, device,
                    key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
                    skip_cfg_strength, cfg_skip_layers, cfg_rescale,
                    cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
                    robust_f0
                )

                # Ensure consistent length
                expected_length = len(chunk)
                if len(audio_out) > expected_length:
                    audio_out = audio_out[:expected_length]
                elif len(audio_out) < expected_length:
                    audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')

                # Apply fades
                if i > 0:  # Not first segment
                    audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
                    result_audio[start_sample:start_sample + fade_samples] *= \
                        np.linspace(1, 0, fade_samples)  # Fade out previous

                if i < len(segments_with_pos) - 1:  # Not last segment
                    audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out

                # Add to result
                result_audio[start_sample:start_sample + len(audio_out)] += audio_out

                # Clean up memory after each segment
                torch.cuda.empty_cache()

        progress(0.9, desc="处理中: 完成音频...")
        # Trim any extra padding
        result_audio = result_audio[:len(audio)]

        # Create a temporary file to save the result
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            output_path = temp_file.name

        # Save output
        torchaudio.save(output_path, torch.from_numpy(result_audio).unsqueeze(0).float(), sample_rate)

        progress(1.0, desc="处理完成!")
        return (sample_rate, result_audio), f"✅ 转换完成! 已转换为 **{speaker}** 并调整 **{key_shift}** 个半音。"

    except RuntimeError as e:
        # Handle CUDA out of memory errors
        if "CUDA out of memory" in str(e):
            # Clean up memory
            torch.cuda.empty_cache()
            gc.collect()

            return None, "❌ 错误: 内存不足。请尝试更短的音频文件或减少推理步骤。"
        else:
            return None, f"❌ 转换过程中出错: {str(e)}"
    except Exception as e:
        error_trace = traceback.format_exc()
        return None, f"❌ 转换过程中出错: {str(e)}\n\n详细信息: {error_trace}"
    finally:
        # Clean up memory
        torch.cuda.empty_cache()
        gc.collect()

def create_ui():
    # CSS for better styling
    css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .container {
        max-width: 1200px;
        margin: auto;
    }
    .footer {
        margin-top: 20px;
        text-align: center;
        font-size: 0.9em;
        color: #666;
    }
    .title {
        text-align: center;
        margin-bottom: 10px;
    }
    .subtitle {
        text-align: center;
        margin-bottom: 20px;
        color: #666;
    }
    .button-primary {
        background-color: #5460DE !important;
    }
    .output-message {
        margin-top: 10px;
        padding: 10px;
        border-radius: 4px;
        background-color: #f8f9fa;
        border-left: 4px solid #5460DE;
    }
    .error-message {
        color: #d62828;
        font-weight: bold;
    }
    .success-message {
        color: #588157;
        font-weight: bold;
    }
    .info-box {
        background-color: #f8f9fa;
        border-left: 4px solid #5460DE;
        padding: 10px;
        margin: 10px 0;
        border-radius: 4px;
    }
    """

    # Initialize models
    available_speakers, init_message = initialize_models()

    with gr.Blocks(css=css, theme=gr.themes.Soft(), title="RIFT-SVC 声音转换") as app:
        gr.HTML("""
        <div class="title">
            <h1>🎤 RIFT-SVC 歌声音色转换 (七海Nanami demo)</h1>
        </div>
        <div class="subtitle">
            <h3>使用 RIFT-SVC 模型将歌声或语音转换为七海Nanami的音色</h3>
        </div>
        <div class="info-box">
            <p>📝 <strong>注意:</strong> 为获得最佳效果,请使用背景噪音较少的干净音频。最大音频长度为5分钟。</p>
        </div>
        <div class="info-box">
            <p>🔗 <strong>想要微调自己的说话人?</strong> 请访问 <a href="https://github.com/Pur1zumu/RIFT-SVC" target="_blank">RIFT-SVC GitHub 仓库</a> 获取完整的训练和微调指南。</p>
        </div>
        """)

        with gr.Row():
            # Left column (input parameters)
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### 📥 输入")
                    model_path = gr.Textbox(label="模型路径", value=DEFAULT_MODEL_PATH, interactive=True)
                    input_audio = gr.Audio(label="输入音频文件", type="filepath", elem_id="input_audio")
                    reload_btn = gr.Button("🔄 重新加载模型", elem_id="reload_btn")

                with gr.Accordion("⚙️ 基本参数", open=True):
                    speaker = gr.Dropdown(choices=available_speakers, label="目标说话人", interactive=True, elem_id="speaker")
                    key_shift = gr.Slider(minimum=-12, maximum=12, step=1, value=0, label="音调调整(半音)", elem_id="key_shift")
                    infer_steps = gr.Slider(minimum=8, maximum=64, step=1, value=32, label="推理步数", elem_id="infer_steps",
                                            info="更低的值 = 更快但质量较低,更高的值 = 更慢但质量更好")
                    robust_f0 = gr.Radio(choices=[0, 1, 2], value=0, label="音高滤波",
                                         info="0=无,1=轻度过滤,2=强力过滤(有助于解决断音/破音问题)",
                                         elem_id="robust_f0")

                with gr.Accordion("🔬 高级CFG参数", open=True):
                    ds_cfg_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.05,
                                                label="内容向量引导强度",
                                                info="更高的值可以改善内容保留和咬字清晰度。过高会用力过猛。",
                                                elem_id="ds_cfg_strength")
                    spk_cfg_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0,
                                                 label="说话人引导强度",
                                                 info="更高的值可以增强说话人相似度。过高可能导致音色失真。",
                                                 elem_id="spk_cfg_strength")
                    skip_cfg_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0,
                                                  label="层引导强度(实验性功能)",
                                                  info="增强指定层的特征渲染。效果取决于目标层的功能。",
                                                  elem_id="skip_cfg_strength")
                    cfg_skip_layers = gr.Number(value=6, label="CFG跳过层(实验性功能)", precision=0,
                                                info="目标增强层下标",
                                                elem_id="cfg_skip_layers")
                    cfg_rescale = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.7,
                                            label="CFG重缩放因子",
                                            info="约束整体引导强度。当引导效果过于强烈时调高该值。",
                                            elem_id="cfg_rescale")
                    cvec_downsample_rate = gr.Radio(choices=[1, 2, 4, 8], value=2,
                                                    label="用于反向引导的内容向量下采样率",
                                                    info="更高的值(可能)可以提高内容清晰度。",
                                                    elem_id="cvec_downsample_rate")

                with gr.Accordion("✂️ 切片参数", open=False):
                    slicer_threshold = gr.Slider(minimum=-60.0, maximum=-20.0, step=0.1, value=-30.0,
                                                 label="阈值 (dB)",
                                                 info="静音检测阈值",
                                                 elem_id="slicer_threshold")
                    slicer_min_length = gr.Slider(minimum=1000, maximum=10000, step=100, value=3000,
                                                  label="最小长度 (毫秒)",
                                                  info="最小片段长度",
                                                  elem_id="slicer_min_length")
                    slicer_min_interval = gr.Slider(minimum=10, maximum=500, step=10, value=100,
                                                    label="最小静音间隔 (毫秒)",
                                                    info="片段之间的最小静音间隔",
                                                    elem_id="slicer_min_interval")
                    slicer_hop_size = gr.Slider(minimum=1, maximum=50, step=1, value=10,
                                                label="跳跃大小 (毫秒)",
                                                info="分析窗口跳跃大小",
                                                elem_id="slicer_hop_size")
                    slicer_max_sil_kept = gr.Slider(minimum=50, maximum=10000, step=10, value=200,
                                                    label="最大保留静音 (毫秒)",
                                                    info="边界处保留的最大静音",
                                                    elem_id="slicer_max_sil_kept")


            # Right column (output)
            with gr.Column(scale=1):
                convert_btn = gr.Button("🎵 转换声音", variant="primary", elem_id="convert_btn")
                gr.Markdown("### 📤 输出")
                output_audio = gr.Audio(label="转换后的音频", elem_id="output_audio", autoplay=False, show_share_button=False)
                output_message = gr.Markdown(init_message, elem_id="output_message", elem_classes="output-message")

                gr.HTML("""
                <div class="info-box">
                    <h4>🔍 快速提示</h4>
                    <ul>
                        <li><strong>音调调整:</strong> 以半音为单位上调或下调音高。</li>
                        <li><strong>推理步骤:</strong> 步骤越多 = 质量越好但速度越慢。</li>
                        <li><strong>音高滤波:</strong> 有助于提高具有挑战性的音频中的音高稳定性。</li>
                        <li><strong>CFG参数:</strong> 调整转换质量和音色。</li>
                    </ul>
                </div>
                """)

        # Define button click events
        reload_btn.click(
            fn=initialize_models,
            inputs=[model_path],
            outputs=[speaker, output_message]
        )

        # Updated convert button click event
        convert_btn.click(
            fn=lambda: "⏳ 处理中... 请稍候。",
            inputs=None,
            outputs=output_message,
            queue=False
        ).then(
            fn=process_with_progress,
            inputs=[
                input_audio, speaker, key_shift, infer_steps, robust_f0,
                ds_cfg_strength, spk_cfg_strength, skip_cfg_strength, cfg_skip_layers, cfg_rescale, cvec_downsample_rate,
                slicer_threshold, slicer_min_length, slicer_min_interval, slicer_hop_size, slicer_max_sil_kept
            ],
            outputs=[output_audio, output_message],
            show_progress_on=output_audio
        )

    return app

if __name__ == "__main__":
    app = create_ui()
    app.launch()
```
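Note: the stitching loop in `process_with_progress` overlap-adds converted segments, fading each new chunk in while fading the previously written tail out. A minimal, self-contained sketch of that blending scheme (simplified to linear fades on both sides; names are illustrative, not from the commit):

```python
import numpy as np

def overlap_add(total_len, segments, fade_samples):
    """Blend (start, chunk) pairs with linear cross-fades at the joins."""
    out = np.zeros(total_len + fade_samples)  # extra space, as in app.py
    for i, (start, chunk) in enumerate(segments):
        chunk = chunk.copy()
        if i > 0:
            # Fade the incoming chunk in and the already-written tail out.
            chunk[:fade_samples] *= np.linspace(0, 1, fade_samples)
            out[start:start + fade_samples] *= np.linspace(1, 0, fade_samples)
        out[start:start + len(chunk)] += chunk
    return out[:total_len]

# Two constant segments overlapping by 100 samples:
segments = [(0, np.ones(1000)), (900, np.ones(1000))]
blended = overlap_add(1900, segments, fade_samples=100)
assert np.allclose(blended[900:1000], 1.0)  # the linear ramps sum back to unity
```

With linear ramps the two fades sum to exactly 1 across the overlap, which keeps the joins click-free; app.py instead fades in with a half-Hanning window (`apply_fade`), trading exact unity for smoother ramp endpoints.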
infer.py ADDED
@@ -0,0 +1,493 @@

```python
import click
import librosa
import numpy as np
import pyloudnorm as pyln
import torch
import torchaudio
from pathlib import Path
from tqdm import tqdm

from rift_svc import DiT, RF
from rift_svc.feature_extractors import HubertModelWithFinalProj, RMSExtractor, get_mel_spectrogram
from rift_svc.nsf_hifigan import NsfHifiGAN
from rift_svc.rmvpe import RMVPE
from rift_svc.utils import linear_interpolate_tensor, post_process_f0, f0_ensemble, f0_ensemble_light, get_f0_pw, get_f0_pm
from slicer import Slicer


torch.set_grad_enabled(False)


def extract_state_dict(ckpt):
    state_dict = ckpt['state_dict']
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('model.'):
            new_k = k.replace('model.', '')
            new_state_dict[new_k] = v
    spk2idx = ckpt['hyper_parameters']['cfg']['spk2idx']
    model_cfg = ckpt['hyper_parameters']['cfg']['model']
    dataset_cfg = ckpt['hyper_parameters']['cfg']['dataset']
    return new_state_dict, spk2idx, model_cfg, dataset_cfg


def load_models(model_path, device):
    """Load all required models and return them"""
    click.echo("Loading models...")

    # Load the conversion model
    ckpt = torch.load(model_path, map_location='cpu')
    state_dict, spk2idx, dit_cfg, dataset_cfg = extract_state_dict(ckpt)

    transformer = DiT(num_speaker=len(spk2idx), **dit_cfg)
    svc_model = RF(transformer=transformer)
    svc_model.load_state_dict(state_dict)
    svc_model = svc_model.to(device)
    svc_model.eval()

    # Load additional models
    vocoder = NsfHifiGAN('pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt').to(device)
    rmvpe = RMVPE(model_path="pretrained/rmvpe/model.pt", hop_length=160, device=device)
    hubert = HubertModelWithFinalProj.from_pretrained("pretrained/content-vec-best").to(device)
    rms_extractor = RMSExtractor().to(device)

    return svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg


def load_audio(file_path, target_sr):
    """Load and preprocess audio file"""
    click.echo("Loading audio...")
    audio, sr = torchaudio.load(file_path)
    if sr != target_sr:
        audio = torchaudio.functional.resample(audio, sr, target_sr)

    if len(audio.shape) > 1:
        audio = audio.mean(dim=0, keepdim=True)

    return audio.numpy().squeeze()


def apply_fade(audio, fade_samples, fade_in=True):
    """Apply fade in/out using half of a Hanning window"""
    fade_window = np.hanning(fade_samples * 2)
    if fade_in:
        # Rising half of the window at the head
        audio[:fade_samples] *= fade_window[:fade_samples]
    else:
        # Falling half of the window at the tail for a fade-out
        audio[-fade_samples:] *= fade_window[fade_samples:]
    return audio


def extract_features(audio_segment, sample_rate, hop_length, rmvpe, hubert, rms_extractor,
                     device, key_shift=0, ds_cfg_strength=0.0, cvec_downsample_rate=2, target_loudness=-18.0,
                     robust_f0=0):
    """Extract all required features from an audio segment"""
    # Normalize input segment
    meter = pyln.Meter(sample_rate, block_size=0.1)
    original_loudness = meter.integrated_loudness(audio_segment)
    normalized_audio = pyln.normalize.loudness(audio_segment, original_loudness, target_loudness)

    # Handle potential clipping
    max_amp = np.max(np.abs(normalized_audio))
    if max_amp > 1.0:
        normalized_audio = normalized_audio * (0.99 / max_amp)

    audio_tensor = torch.from_numpy(normalized_audio).float().unsqueeze(0).to(device)
    audio_16khz = torch.from_numpy(librosa.resample(normalized_audio, orig_sr=sample_rate, target_sr=16000)).float().unsqueeze(0).to(device)

    # Extract mel spectrogram
    mel = get_mel_spectrogram(
        audio_tensor,
        sampling_rate=sample_rate,
        n_fft=2048,
        num_mels=128,
        hop_size=512,
        win_size=2048,
        fmin=40,
        fmax=16000
    ).transpose(1, 2)

    # Extract content vector
    cvec = hubert(audio_16khz)["last_hidden_state"].squeeze(0)
    cvec = linear_interpolate_tensor(cvec, mel.shape[1])[None, :]

    # Create bad_cvec (downsampled) for classifier-free guidance
    if ds_cfg_strength > 0:
        cvec_ds = cvec.clone()
        # Downsample and then interpolate back, similar to dataset.py
        cvec_ds = cvec_ds[0, ::2, :]  # Take every other frame
        cvec_ds = linear_interpolate_tensor(cvec_ds, cvec_ds.shape[0]//cvec_downsample_rate)
        cvec_ds = linear_interpolate_tensor(cvec_ds, mel.shape[1])[None, :]
    else:
        cvec_ds = None

    # Extract f0
    if robust_f0 > 0:
        # Parameters for F0 extraction
        time_step = hop_length / sample_rate
        f0_min = 40
        f0_max = 1100

        # Extract F0 using multiple methods
        rmvpe_f0 = rmvpe.infer_from_audio(audio_tensor, sample_rate=sample_rate, device=device)
        rmvpe_f0 = post_process_f0(rmvpe_f0, sample_rate, hop_length, mel.shape[1], silence_front=0.0, cut_last=False)
        pw_f0 = get_f0_pw(normalized_audio, sample_rate, time_step, f0_min, f0_max)
        pmac_f0 = get_f0_pm(normalized_audio, sample_rate, time_step, f0_min, f0_max)

        if robust_f0 == 1:
            # Level 1: Light ensemble that preserves expressiveness
            rms_np = rms_extractor(audio_tensor).squeeze().cpu().numpy()
            f0 = f0_ensemble_light(rmvpe_f0, pw_f0, pmac_f0, rms=rms_np)
        else:
            # Level 2: Strong ensemble with more filtering
            f0 = f0_ensemble(rmvpe_f0, pw_f0, pmac_f0)
    else:
        # Level 0: Use only RMVPE for F0 extraction (original method)
        f0 = rmvpe.infer_from_audio(audio_tensor, sample_rate=sample_rate, device=device)
        f0 = post_process_f0(f0, sample_rate, hop_length, mel.shape[1], silence_front=0.0, cut_last=False)

    if key_shift != 0:
        f0 = f0 * 2 ** (key_shift / 12)
    f0 = torch.from_numpy(f0).float().to(device)[None, :]

    # Extract RMS
    rms = rms_extractor(audio_tensor)

    return mel, cvec, cvec_ds, f0, rms, original_loudness


def run_inference(
    model, mel, cvec, f0, rms, cvec_ds, spk_id,
    infer_steps, ds_cfg_strength, spk_cfg_strength,
    skip_cfg_strength, cfg_skip_layers, cfg_rescale,
    sliced_inference=False
):
    """Run the actual inference through the model"""
    if sliced_inference:
        # Use sliced inference for long segments
        sliced_len = 256
        mel_crossfade_len = 8  # Number of frames to crossfade in mel domain

        # If the segment is shorter than one slice, just process it directly
        if mel.shape[1] <= sliced_len:
            mel_out, _ = model.sample(
                src_mel=mel,
                spk_id=spk_id,
                f0=f0,
                rms=rms,
                cvec=cvec,
                steps=infer_steps,
                bad_cvec=cvec_ds,
                ds_cfg_strength=ds_cfg_strength,
                spk_cfg_strength=spk_cfg_strength,
                skip_cfg_strength=skip_cfg_strength,
                cfg_skip_layers=cfg_skip_layers,
                cfg_rescale=cfg_rescale,
            )
            return mel_out

        # Create a tensor to hold the full output with crossfading
        full_mel_out = torch.zeros_like(mel)

        # Process each slice
        for i in range(0, mel.shape[1], sliced_len - mel_crossfade_len):
            # Determine slice boundaries
            start_idx = i
            end_idx = min(i + sliced_len, mel.shape[1])

            # Skip if we're at the end
            if start_idx >= mel.shape[1]:
                break

            # Extract slices for this window
            mel_slice = mel[:, start_idx:end_idx, :]
            cvec_slice = cvec[:, start_idx:end_idx, :]
            f0_slice = f0[:, start_idx:end_idx]
            rms_slice = rms[:, start_idx:end_idx]

            # Slice the bad_cvec if it exists
            cvec_ds_slice = None
            if cvec_ds is not None:
                cvec_ds_slice = cvec_ds[:, start_idx:end_idx, :]

            # Process with model
            mel_out_slice, _ = model.sample(
                src_mel=mel_slice,
                spk_id=spk_id,
                f0=f0_slice,
                rms=rms_slice,
                cvec=cvec_slice,
                steps=infer_steps,
                bad_cvec=cvec_ds_slice,
                ds_cfg_strength=ds_cfg_strength,
                spk_cfg_strength=spk_cfg_strength,
                skip_cfg_strength=skip_cfg_strength,
                cfg_skip_layers=cfg_skip_layers,
                cfg_rescale=cfg_rescale,
            )

            # Create crossfade weights
            slice_len = end_idx - start_idx

            # Apply different strategies depending on position
            if i == 0:  # First slice
                # No crossfade at the beginning
                weights = torch.ones((1, slice_len, 1), device=mel.device)
                if i + sliced_len < mel.shape[1]:  # If not the last slice too
                    # Fade out at the end, using the minimum of slice_len and mel_crossfade_len
                    actual_crossfade_len = min(mel_crossfade_len, slice_len)
                    if actual_crossfade_len > 0:  # Only apply if we have space
                        fade_out = torch.linspace(1, 0, actual_crossfade_len, device=mel.device)
                        weights[:, -actual_crossfade_len:, :] = fade_out.view(1, -1, 1)
            elif end_idx >= mel.shape[1]:  # Last slice
                # Fade in at the beginning, using the minimum of slice_len and mel_crossfade_len
                weights = torch.ones((1, slice_len, 1), device=mel.device)
                actual_crossfade_len = min(mel_crossfade_len, slice_len)
                if actual_crossfade_len > 0:  # Only apply if we have space
                    fade_in = torch.linspace(0, 1, actual_crossfade_len, device=mel.device)
                    weights[:, :actual_crossfade_len, :] = fade_in.view(1, -1, 1)
            else:  # Middle slices
                # Crossfade both sides, handling the case where slice_len < 2*mel_crossfade_len
                weights = torch.ones((1, slice_len, 1), device=mel.device)

                # Determine the actual crossfade length (might be shorter for small slices)
                actual_crossfade_len = min(mel_crossfade_len, slice_len // 2)
                if actual_crossfade_len > 0:
                    fade_in = torch.linspace(0, 1, actual_crossfade_len, device=mel.device)
                    fade_out = torch.linspace(1, 0, actual_crossfade_len, device=mel.device)
                    weights[:, :actual_crossfade_len, :] = fade_in.view(1, -1, 1)
                    weights[:, -actual_crossfade_len:, :] = fade_out.view(1, -1, 1)

            # Apply weights to current slice output
            mel_out_slice = mel_out_slice * weights

            # Add to the appropriate region of the output
            full_mel_out[:, start_idx:end_idx, :] += mel_out_slice

        # Return the full crossfaded output
        mel_out = full_mel_out
    else:
        # Process the entire segment at once
        mel_out, _ = model.sample(
            src_mel=mel,
            spk_id=spk_id,
            f0=f0,
            rms=rms,
            cvec=cvec,
            steps=infer_steps,
            bad_cvec=cvec_ds,
            ds_cfg_strength=ds_cfg_strength,
            spk_cfg_strength=spk_cfg_strength,
            skip_cfg_strength=skip_cfg_strength,
            cfg_skip_layers=cfg_skip_layers,
            cfg_rescale=cfg_rescale,
        )

    return mel_out


def generate_audio(vocoder, mel_out, f0, original_loudness=None, restore_loudness=True):
    """Generate audio from mel spectrogram using vocoder"""
    audio_out = vocoder(mel_out.transpose(1, 2), f0)
    audio_out = audio_out.squeeze().cpu().numpy()

    if restore_loudness and original_loudness is not None:
        # Restore original loudness
        meter = pyln.Meter(44100, block_size=0.1)  # Using default sample rate for vocoder
        audio_out_loudness = meter.integrated_loudness(audio_out)
        audio_out = pyln.normalize.loudness(audio_out, audio_out_loudness, original_loudness)

        # Handle clipping
        max_amp = np.max(np.abs(audio_out))
        if max_amp > 1.0:
            audio_out = audio_out * (0.99 / max_amp)

    return audio_out


def process_segment(
    audio_segment,
    svc_model, vocoder, rmvpe, hubert, rms_extractor,
    speaker_id, sample_rate, hop_length, device,
    key_shift=0,
    infer_steps=32,
    ds_cfg_strength=0.0,
    spk_cfg_strength=0.0,
    skip_cfg_strength=0.0,
    cfg_skip_layers=None,
    cfg_rescale=0.7,
    cvec_downsample_rate=2,
    target_loudness=-18.0,
    restore_loudness=True,
    sliced_inference=False,
    robust_f0=0
):
    """Process a single audio segment and return the converted audio"""
    # Extract features
    mel, cvec, cvec_ds, f0, rms, original_loudness = extract_features(
        audio_segment, sample_rate, hop_length, rmvpe, hubert, rms_extractor,
        device, key_shift, ds_cfg_strength, cvec_downsample_rate, target_loudness,
        robust_f0
    )

    # Prepare speaker ID
    spk_id = torch.LongTensor([speaker_id]).to(device)

    # Run inference
    mel_out = run_inference(
        svc_model, mel, cvec, f0, rms, cvec_ds, spk_id,
        infer_steps, ds_cfg_strength, spk_cfg_strength,
        skip_cfg_strength, cfg_skip_layers, cfg_rescale,
        sliced_inference
    )

    # Generate audio
    audio_out = generate_audio(
        vocoder, mel_out, f0,
        original_loudness if restore_loudness else None,
        restore_loudness
    )

    return audio_out


@click.command()
@click.option('--model', type=click.Path(exists=True), required=True, help='Path to model checkpoint')
@click.option('--input', type=click.Path(exists=True), required=True, help='Input audio file')
@click.option('--output', type=click.Path(), required=True, help='Output audio file')
@click.option('--speaker', type=str, required=True, help='Target speaker')
@click.option('--key-shift', type=int, default=0, help='Pitch shift in semitones')
@click.option('--device', type=str, default=None, help='Device to use (cuda/cpu)')
@click.option('--infer-steps', type=int, default=32, help='Number of inference steps')
@click.option('--ds-cfg-strength', type=float, default=0.0, help='Downsampled content vector guidance strength')
@click.option('--spk-cfg-strength', type=float, default=0.0, help='Speaker guidance strength')
@click.option('--skip-cfg-strength', type=float, default=0.0, help='Skip layer guidance strength')
@click.option('--cfg-skip-layers', type=int, default=None, help='Layer to skip for classifier-free guidance')
@click.option('--cfg-rescale', type=float, default=0.7, help='Classifier-free guidance rescale factor')
@click.option('--cvec-downsample-rate', type=int, default=2, help='Downsampling rate for bad_cvec creation')
@click.option('--target-loudness', type=float, default=-18.0, help='Target loudness in LUFS for normalization')
@click.option('--restore-loudness', default=True, help='Restore loudness to original')
@click.option('--fade-duration', type=float, default=20.0, help='Fade duration in milliseconds')
@click.option('--sliced-inference', is_flag=True, default=False, help='Use sliced inference for processing long segments')
@click.option('--robust-f0', type=int, default=0, help='Level of robust f0 filtering (0=none, 1=light, 2=aggressive)')
@click.option('--slicer-threshold', type=float, default=-35.0, help='Threshold for audio slicing in dB')
@click.option('--slicer-min-length', type=int, default=3000, help='Minimum length of audio segments in milliseconds')
@click.option('--slicer-min-interval', type=int, default=100, help='Minimum interval between audio segments in milliseconds')
@click.option('--slicer-hop-size', type=int, default=10, help='Hop size for audio slicing in milliseconds')
@click.option('--slicer-max-sil-kept', type=int, default=300, help='Maximum silence kept in milliseconds')
def main(
    model,
    input,
    output,
    speaker,
    key_shift,
    device,
    infer_steps,
    ds_cfg_strength,
    spk_cfg_strength,
    skip_cfg_strength,
    cfg_skip_layers,
    cfg_rescale,
    cvec_downsample_rate,
    target_loudness,
    restore_loudness,
    fade_duration,
    sliced_inference,
    robust_f0,
    slicer_threshold,
    slicer_min_length,
    slicer_min_interval,
    slicer_hop_size,
    slicer_max_sil_kept
):
    """Convert the voice in an audio file to a target speaker."""

    # Setup device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # Load models
    svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg = load_models(model, device)

    try:
        speaker_id = spk2idx[speaker]
    except KeyError:
        raise ValueError(f"Speaker {speaker} not found in the model's speaker list; valid speakers are {list(spk2idx.keys())}")

    # Get config from loaded model
    hop_length = 512
    sample_rate = 44100

    # Load audio
    audio = load_audio(input, sample_rate)

    # Initialize Slicer
    slicer = Slicer(
        sr=sample_rate,
        threshold=slicer_threshold,
        min_length=slicer_min_length,
        min_interval=slicer_min_interval,
        hop_size=slicer_hop_size,
        max_sil_kept=slicer_max_sil_kept
    )

    # Step (1): Use slicer to segment the input audio and get positions
    click.echo("Slicing audio...")
    segments_with_pos = slicer.slice(audio)  # Returns a list of (start_pos, chunk)

    if restore_loudness:
        click.echo("Will restore loudness to original")

    # Calculate fade size in samples
    fade_samples = int(fade_duration * sample_rate / 1000)

    # Process segments
    click.echo(f"Processing {len(segments_with_pos)} segments...")
    result_audio = np.zeros(len(audio) + fade_samples)  # Extra space for potential overlap

    with torch.no_grad():
        for idx, (start_sample, chunk) in enumerate(tqdm(segments_with_pos)):

            # Process the segment
            audio_out = process_segment(
                chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
                speaker_id, sample_rate, hop_length, device,
                key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
                skip_cfg_strength, cfg_skip_layers, cfg_rescale,
                cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
                robust_f0
            )

            # Ensure consistent length
            expected_length = len(chunk)
            if len(audio_out) > expected_length:
                audio_out = audio_out[:expected_length]
            elif len(audio_out) < expected_length:
                audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')

            # Apply fades
            if idx > 0:  # Not first segment
                audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
                result_audio[start_sample:start_sample + fade_samples] *= \
                    np.linspace(1, 0, fade_samples)  # Fade out previous

            if idx < len(segments_with_pos) - 1:  # Not last segment
                audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out

            # Add to result
            result_audio[start_sample:start_sample + len(audio_out)] += audio_out

    # Trim any extra padding
    result_audio = result_audio[:len(audio)]

    # Save output
    click.echo("Saving output...")
    output_path = Path(output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(output, torch.from_numpy(result_audio).unsqueeze(0).float(), sample_rate)
    click.echo("Done!")


if __name__ == '__main__':
    main()
```
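Note: `extract_features` applies key shift multiplicatively in frequency, `f0 * 2 ** (key_shift / 12)`, so each semitone scales f0 by the twelfth root of 2. A quick worked check of that formula (a sketch, not from the commit):

```python
import numpy as np

f0 = np.array([220.0, 440.0])  # A3 and A4, in Hz
for key_shift in (-12, 7, 12):
    print(key_shift, np.round(f0 * 2 ** (key_shift / 12), 2))
# -12 -> [110. 220.]          one octave down
#   7 -> [329.63 659.26]      up a perfect fifth (A -> E)
#  12 -> [440. 880.]          one octave up
```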
pretrained/content-vec-best/.gitattributes ADDED
@@ -0,0 +1,34 @@

```text
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
```
pretrained/content-vec-best/.gitignore ADDED
@@ -0,0 +1 @@

```text
content-vec-best-legacy-500.pt
```
pretrained/content-vec-best/README.md ADDED
@@ -0,0 +1,33 @@

````markdown
---
license: mit
---

# Content Vec Best
Official Repo: [ContentVec](https://github.com/auspicious3000/contentvec)
This repo brings the fairseq ContentVec model to HuggingFace Transformers.

## How to use
To use this model, you need to define

```python
class HubertModelWithFinalProj(HubertModel):
    def __init__(self, config):
        super().__init__(config)

        # The final projection layer is only used for backward compatibility.
        # Following https://github.com/auspicious3000/contentvec/issues/6,
        # removing this layer is necessary to achieve the desired outcome.
        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
```

and then load the model with

```python
model = HubertModelWithFinalProj.from_pretrained("lengyue233/content-vec-best")

x = model(audio)["last_hidden_state"]
```

## How to convert
You need to download the ContentVec_legacy model from the official repo, and then run

```bash
python convert.py
```
````
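As a usage sketch of the snippet above (assumed details: a local `input.wav`, plus the mono 16 kHz resample that infer.py also performs before calling HuBERT):

```python
import torch
import torchaudio
from torch import nn
from transformers import HubertModel

class HubertModelWithFinalProj(HubertModel):
    def __init__(self, config):
        super().__init__(config)
        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)

audio, sr = torchaudio.load("input.wav")                   # hypothetical input file
audio = audio.mean(dim=0, keepdim=True)                    # downmix to mono
audio = torchaudio.functional.resample(audio, sr, 16000)   # HuBERT expects 16 kHz

model = HubertModelWithFinalProj.from_pretrained("lengyue233/content-vec-best").eval()
with torch.no_grad():
    cvec = model(audio)["last_hidden_state"]  # (1, frames, 768), ~50 frames per second
```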
pretrained/content-vec-best/config.json ADDED
@@ -0,0 +1,71 @@

```json
{
  "activation_dropout": 0.1,
  "apply_spec_augment": true,
  "architectures": [
    "HubertModelWithFinalProj"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "conv_bias": false,
  "conv_dim": [512, 512, 512, 512, 512, 512, 512],
  "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
  "conv_stride": [5, 2, 2, 2, 2, 2, 2],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "do_stable_layer_norm": false,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_norm": "group",
  "feat_proj_dropout": 0.0,
  "feat_proj_layer_norm": true,
  "final_dropout": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
  "mask_time_prob": 0.05,
  "model_type": "hubert",
  "num_attention_heads": 12,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "torch_dtype": "float32",
  "transformers_version": "4.27.3",
  "use_weighted_layer_sum": false,
  "vocab_size": 32
}
```
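Note: the `conv_stride` list above fixes the encoder's frame rate; the strides multiply out to 320 samples per output frame, i.e. 50 content-vector frames per second at 16 kHz input. A quick computation (a sketch, not part of the commit):

```python
import math

conv_stride = [5, 2, 2, 2, 2, 2, 2]        # from config.json above
total_stride = math.prod(conv_stride)       # 320 samples per output frame
print(total_stride, 16000 / total_stride)   # 320 50.0
```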
pretrained/content-vec-best/convert.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from transformers import HubertConfig, HubertModel
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
# Ignore fairseq's logger
|
| 7 |
+
logging.getLogger("fairseq").setLevel(logging.WARNING)
|
| 8 |
+
logging.getLogger("torch.distributed.nn.jit.instantiator").setLevel(logging.WARNING)
|
| 9 |
+
|
| 10 |
+
from fairseq import checkpoint_utils
|
| 11 |
+
|
| 12 |
+
+models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
+    ["content-vec-best-legacy-500.pt"], suffix=""
+)
+model = models[0]
+model.eval()
+
+
+class HubertModelWithFinalProj(HubertModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
+
+
+# Default Config
+hubert = HubertModelWithFinalProj(HubertConfig())
+
+# huggingface: fairseq
+mapping = {
+    "masked_spec_embed": "mask_emb",
+    "encoder.layer_norm.bias": "encoder.layer_norm.bias",
+    "encoder.layer_norm.weight": "encoder.layer_norm.weight",
+    "encoder.pos_conv_embed.conv.bias": "encoder.pos_conv.0.bias",
+    "encoder.pos_conv_embed.conv.weight_g": "encoder.pos_conv.0.weight_g",
+    "encoder.pos_conv_embed.conv.weight_v": "encoder.pos_conv.0.weight_v",
+    "feature_projection.layer_norm.bias": "layer_norm.bias",
+    "feature_projection.layer_norm.weight": "layer_norm.weight",
+    "feature_projection.projection.bias": "post_extract_proj.bias",
+    "feature_projection.projection.weight": "post_extract_proj.weight",
+    "final_proj.bias": "final_proj.bias",
+    "final_proj.weight": "final_proj.weight",
+}
+
+# Convert encoder
+for layer in range(12):
+    for j in ["q", "k", "v"]:
+        mapping[
+            f"encoder.layers.{layer}.attention.{j}_proj.weight"
+        ] = f"encoder.layers.{layer}.self_attn.{j}_proj.weight"
+        mapping[
+            f"encoder.layers.{layer}.attention.{j}_proj.bias"
+        ] = f"encoder.layers.{layer}.self_attn.{j}_proj.bias"
+
+    mapping[
+        f"encoder.layers.{layer}.final_layer_norm.bias"
+    ] = f"encoder.layers.{layer}.final_layer_norm.bias"
+    mapping[
+        f"encoder.layers.{layer}.final_layer_norm.weight"
+    ] = f"encoder.layers.{layer}.final_layer_norm.weight"
+
+    mapping[
+        f"encoder.layers.{layer}.layer_norm.bias"
+    ] = f"encoder.layers.{layer}.self_attn_layer_norm.bias"
+    mapping[
+        f"encoder.layers.{layer}.layer_norm.weight"
+    ] = f"encoder.layers.{layer}.self_attn_layer_norm.weight"
+
+    mapping[
+        f"encoder.layers.{layer}.attention.out_proj.bias"
+    ] = f"encoder.layers.{layer}.self_attn.out_proj.bias"
+    mapping[
+        f"encoder.layers.{layer}.attention.out_proj.weight"
+    ] = f"encoder.layers.{layer}.self_attn.out_proj.weight"
+
+    mapping[
+        f"encoder.layers.{layer}.feed_forward.intermediate_dense.bias"
+    ] = f"encoder.layers.{layer}.fc1.bias"
+    mapping[
+        f"encoder.layers.{layer}.feed_forward.intermediate_dense.weight"
+    ] = f"encoder.layers.{layer}.fc1.weight"
+
+    mapping[
+        f"encoder.layers.{layer}.feed_forward.output_dense.bias"
+    ] = f"encoder.layers.{layer}.fc2.bias"
+    mapping[
+        f"encoder.layers.{layer}.feed_forward.output_dense.weight"
+    ] = f"encoder.layers.{layer}.fc2.weight"
+
+# Convert Conv Layers
+for layer in range(7):
+    mapping[
+        f"feature_extractor.conv_layers.{layer}.conv.weight"
+    ] = f"feature_extractor.conv_layers.{layer}.0.weight"
+
+    if layer != 0:
+        continue
+
+    mapping[
+        f"feature_extractor.conv_layers.{layer}.layer_norm.weight"
+    ] = f"feature_extractor.conv_layers.{layer}.2.weight"
+    mapping[
+        f"feature_extractor.conv_layers.{layer}.layer_norm.bias"
+    ] = f"feature_extractor.conv_layers.{layer}.2.bias"
+
+hf_keys = set(hubert.state_dict().keys())
+fair_keys = set(model.state_dict().keys())
+
+hf_keys -= set(mapping.keys())
+fair_keys -= set(mapping.values())
+
+for i, j in zip(sorted(hf_keys), sorted(fair_keys)):
+    print(i, j)
+
+print(hf_keys, fair_keys)
+print(len(hf_keys), len(fair_keys))
+
+# try loading the weights
+new_state_dict = {}
+for k, v in mapping.items():
+    new_state_dict[k] = model.state_dict()[v]
+
+x = hubert.load_state_dict(new_state_dict, strict=False)
+print(x)
+hubert.eval()
+
+with torch.no_grad():
+    new_input = torch.randn(1, 16384)
+
+    result1 = hubert(new_input, output_hidden_states=True)["hidden_states"][9]
+    result1 = hubert.final_proj(result1)
+
+    result2 = model.extract_features(
+        **{
+            "source": new_input,
+            "padding_mask": torch.zeros(1, 16384, dtype=torch.bool),
+            # "features_only": True,
+            "output_layer": 9,
+        }
+    )[0]
+    result2 = model.final_proj(result2)
+
+    assert torch.allclose(result1, result2, atol=1e-3)
+
+print("Sanity check passed")
+
+# Save huggingface model
+hubert.save_pretrained(".")
+print("Saved model")
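For orientation, here is a minimal sketch (not part of the commit) of using the converted checkpoint downstream, mirroring the sanity check in convert.py. It assumes `pretrained/download.py` below has already fetched the weights into `pretrained/content-vec-best`.

```python
# Illustrative sketch: load the converted HF-format ContentVec and extract
# layer-9 features, exactly as the sanity check in convert.py does.
import torch
from rift_svc.feature_extractors import HubertModelWithFinalProj

hubert = HubertModelWithFinalProj.from_pretrained("pretrained/content-vec-best")
hubert.eval()

wav_16k = torch.randn(1, 16384)  # placeholder 16 kHz audio, as in the check above
with torch.no_grad():
    hidden = hubert(wav_16k, output_hidden_states=True)["hidden_states"][9]
    features = hubert.final_proj(hidden)  # (1, frames, classifier_proj_size)
```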
pretrained/download.py
ADDED
@@ -0,0 +1,12 @@
+from huggingface_hub import snapshot_download
+
+
+if __name__ == "__main__":
+    model_path = snapshot_download(
+        repo_id="Pur1zumu/RIFT-SVC-modules",
+        local_dir='pretrained',
+        local_dir_use_symlinks=False,  # Don't use symlinks
+        local_files_only=False,  # Allow downloading new files
+        ignore_patterns=["*.git*"],  # Ignore git-related files
+        resume_download=True  # Resume interrupted downloads
+    )
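This script is intended to be run once before launching the app (`python pretrained/download.py`): it mirrors the `Pur1zumu/RIFT-SVC-modules` repository into `pretrained/`, which is where the rest of the code looks for the vocoder checkpoint, the ContentVec weights, and the RMVPE pitch-extractor weights (the empty `pretrained/rmvpe/` directory below is its target).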
pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/NOTICE.txt
ADDED
@@ -0,0 +1,87 @@
+--- DiffSinger Community Vocoder ---
+
+ARCHITECTURE: NSF-HiFiGAN
+RELEASE DATE: 2024-02-19
+
+HYPER PARAMETERS:
+ - 44100 sample rate
+ - 128 mel bins
+ - 512 hop size
+ - 2048 window size
+ - fmin at 40Hz
+ - fmax at 16000Hz
+
+
+NOTICE:
+
+All model weights in the [DiffSinger Community Vocoder Project](https://openvpi.github.io/vocoders/), including
+model weights in this directory, are provided by the [OpenVPI Team](https://github.com/openvpi/), under the
+[Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.
+
+
+ACKNOWLEDGEMENTS:
+
+Training data of this vocoder is provided and permitted by the following organizations, societies and individuals:
+
+孙飒 https://www.qfssr.cn
+赤松_Akamatsu https://www.zhibin.club
+乐威 https://www.zhibin.club
+伯添 https://space.bilibili.com/24087011
+雲宇光 https://space.bilibili.com/660675050
+橙子言 https://space.bilibili.com/318486464
+人衣大人 https://space.bilibili.com/2270344
+玖蝶 https://space.bilibili.com/676771003
+Yuuko
+白夜零BYL https://space.bilibili.com/1605040503
+嗷天 https://space.bilibili.com/5675252
+洛泠羽 https://space.bilibili.com/347373318
+灰条纹的灰猫君 https://space.bilibili.com/2083633
+幽寂 https://space.bilibili.com/478860
+恶魔王女 https://space.bilibili.com/2475098
+AlexYHX 芮晴
+绮萱 https://y.qq.com/n/ryqq/singer/003HjD6H4aZn1K
+诗芸 https://y.qq.com/n/ryqq/singer/0005NInj142zm0
+汐蕾 https://y.qq.com/n/ryqq/singer/0023cWMH1Bq1PJ
+1262917464
+炜阳
+叶卡yolka
+幸の夏 https://space.bilibili.com/1017297686
+暮色未量 https://space.bilibili.com/272904686
+晓寞sama https://space.bilibili.com/3463394
+没头绪的节操君
+串串BunC https://space.bilibili.com/95817834
+落雨 https://space.bilibili.com/1292427
+长尾巴的翎艾 https://space.bilibili.com/1638666
+声闻计划 https://space.bilibili.com/392812269
+唐家大小姐 http://5sing.kugou.com/palmusic/default.html
+不伊子
+芸青岩 https://space.bilibili.com/35236775
+妖橙 https://space.bilibili.com/161975631
+双桨 https://space.bilibili.com/13245483
+灵滅 https://space.bilibili.com/276988145
+AlexYHX https://space.bilibili.com/13303439
+祁唱 https://space.bilibili.com/11256670
+早稻叽 https://space.bilibili.com/1950658
+
+The following public datasets are used:
+
+Opencpop https://wenet.org.cn/opencpop/
+CCMUSIC https://ccmusic-database.github.io/index.html
+SingingVoiceDataset http://isophonics.net/SingingVoiceDataset
+
+Training machines are provided by:
+
+花儿不哭 https://space.bilibili.com/5760446
+
+
+TERMS OF REDISTRIBUTIONS:
+
+1. Do not sell this vocoder, or charge any fees from redistributing it, as prohibited by
+   the license.
+2. Include a copy of the CC BY-NC-SA 4.0 license, or a link referring to it.
+3. Include a copy of this notice, or any other notices informing that this vocoder is
+   provided by the OpenVPI Team, that this vocoder is licensed under CC BY-NC-SA 4.0, and
+   with a complete acknowledgement list as shown above.
+4. If you fine-tuned or modified the weights, leave a notice about what has been changed.
+5. (Optional) Leave a link to the official release page of the vocoder, and tell users
+   that other versions and future updates of this vocoder can be obtained from the website.
pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/NOTICE.zh-CN.txt
ADDED
@@ -0,0 +1,85 @@
+--- DiffSinger 社区声码器 ---
+
+架构:NSF-HiFiGAN
+发布日期:2024-02-19
+
+超参数:
+ - 44100 sample rate
+ - 128 mel bins
+ - 512 hop size
+ - 2048 window size
+ - fmin at 40Hz
+ - fmax at 16000Hz
+
+
+注意事项:
+
+[DiffSinger 社区声码器企划](https://openvpi.github.io/vocoders/) 中的所有模型权重,
+包括此目录下的模型权重,均由 [OpenVPI Team](https://github.com/openvpi/) 提供,并基于
+[Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/)
+进行许可。
+
+
+致谢:
+
+此声码器的训练数据由以下组织、社团和个人提供并许可:
+
+孙飒 https://www.qfssr.cn
+赤松_Akamatsu https://www.zhibin.club
+乐威 https://www.zhibin.club
+伯添 https://space.bilibili.com/24087011
+雲宇光 https://space.bilibili.com/660675050
+橙子言 https://space.bilibili.com/318486464
+人衣大人 https://space.bilibili.com/2270344
+玖蝶 https://space.bilibili.com/676771003
+Yuuko
+白夜零BYL https://space.bilibili.com/1605040503
+嗷天 https://space.bilibili.com/5675252
+洛泠羽 https://space.bilibili.com/347373318
+灰条纹的灰猫君 https://space.bilibili.com/2083633
+幽寂 https://space.bilibili.com/478860
+恶魔王女 https://space.bilibili.com/2475098
+芮晴
+绮萱 https://y.qq.com/n/ryqq/singer/003HjD6H4aZn1K
+诗芸 https://y.qq.com/n/ryqq/singer/0005NInj142zm0
+汐蕾 https://y.qq.com/n/ryqq/singer/0023cWMH1Bq1PJ
+1262917464
+炜阳
+叶卡yolka
+幸の夏 https://space.bilibili.com/1017297686
+暮色未量 https://space.bilibili.com/272904686
+晓寞sama https://space.bilibili.com/3463394
+没头绪的节操君
+串串BunC https://space.bilibili.com/95817834
+落雨 https://space.bilibili.com/1292427
+长尾巴的翎艾 https://space.bilibili.com/1638666
+声闻计划 https://space.bilibili.com/392812269
+唐家大小姐 http://5sing.kugou.com/palmusic/default.html
+不伊子
+芸青岩 https://space.bilibili.com/35236775
+妖橙 https://space.bilibili.com/161975631
+双桨 https://space.bilibili.com/13245483
+灵滅 https://space.bilibili.com/276988145
+AlexYHX https://space.bilibili.com/13303439
+祁唱 https://space.bilibili.com/11256670
+早稻叽 https://space.bilibili.com/1950658
+
+使用了以下公开数据集:
+
+Opencpop https://wenet.org.cn/opencpop/
+CCMUSIC https://ccmusic-database.github.io/index.html
+SingingVoiceDataset http://isophonics.net/SingingVoiceDataset
+
+训练算力的提供者如下:
+
+花儿不哭 https://space.bilibili.com/5760446
+
+
+二次分发条款:
+
+1. 请勿售卖此声码器或从其二次分发过程中收取任何费用,因为此类行为受到许可证的禁止。
+2. 请在二次分发文件中包含一份 CC BY-NC-SA 4.0 许可证的副本或指向该许可证的链接。
+3. 请在二次分发文件中包含这份声明,或以其他形式声明此声码器由 OpenVPI Team 提供并基于 CC BY-NC-SA 4.0 许可,
+   并附带上述完整的致谢名单。
+4. 如果您微调或修改了权重,请留下一份关于其受到了何种修改的说明。
+5.(可选)留下一份指向此声码器的官方发布页面的链接,并告知使用者可从该网站获取此声码器的其他版本和未来的更新。
pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/config.json
ADDED
@@ -0,0 +1,56 @@
+{
+    "discriminator_periods": [
+        3,
+        5,
+        7,
+        11,
+        17,
+        23,
+        37
+    ],
+    "resblock": "1",
+    "resblock_dilation_sizes": [
+        [
+            1,
+            3,
+            5
+        ],
+        [
+            1,
+            3,
+            5
+        ],
+        [
+            1,
+            3,
+            5
+        ]
+    ],
+    "resblock_kernel_sizes": [
+        3,
+        7,
+        11
+    ],
+    "upsample_initial_channel": 512,
+    "upsample_kernel_sizes": [
+        16,
+        16,
+        4,
+        4,
+        4
+    ],
+    "upsample_rates": [
+        8,
+        8,
+        2,
+        2,
+        2
+    ],
+    "sampling_rate": 44100,
+    "num_mels": 128,
+    "hop_size": 512,
+    "n_fft": 2048,
+    "win_size": 2048,
+    "fmin": 40,
+    "fmax": 16000
+}
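One internal consistency worth noting: the generator's `upsample_rates` multiply out to 8·8·2·2·2 = 512, which equals `hop_size`, so each 128-bin mel frame is rendered to exactly 512 audio samples at 44.1 kHz. A quick illustrative check:

```python
# Illustrative check: the vocoder's total upsampling factor must equal the
# analysis hop size so that one mel frame maps to exactly 512 output samples.
import json
import math

with open("pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/config.json") as f:
    cfg = json.load(f)

assert math.prod(cfg["upsample_rates"]) == cfg["hop_size"]  # 8*8*2*2*2 == 512
```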
pretrained/rmvpe/.gitkeep
ADDED
File without changes
requirements.txt
ADDED
@@ -0,0 +1,28 @@
+click
+einops
+gradio
+huggingface_hub
+hydra-core
+jaxtyping
+librosa
+matplotlib
+numpy
+omegaconf
+Pillow
+praat-parselmouth
+pyloudnorm
+PyYAML
+pytorch_lightning
+resampy
+schedulefree
+scipy
+soundfile
+tensorboard
+thop
+torch
+torchaudio
+torchdiffeq
+tqdm
+transformers
+wandb
+x_transformers
rift_svc/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from rift_svc.rf import RF
+from rift_svc.dit import DiT
+from rift_svc.lightning_module import RIFTSVCLightningModule
rift_svc/dataset.py
ADDED
@@ -0,0 +1,139 @@
+import json
+import os
+import random
+from functools import partial
+from typing import Literal
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset
+
+from rift_svc.utils import linear_interpolate_tensor, nearest_interpolate_tensor
+
+pt_load = partial(torch.load, weights_only=True, map_location='cpu', mmap=True)
+
+
+class SVCDataset(Dataset):
+    def __init__(
+        self,
+        data_dir: str,
+        meta_info_path: str,
+        max_frame_len = 256,
+        split = "train",
+        use_cvec_downsampled: bool = False,
+        cvec_downsample_rate: int = 2,
+    ):
+        self.data_dir = data_dir
+        self.max_frame_len = max_frame_len
+
+        with open(meta_info_path, 'r', encoding='utf-8') as f:
+            meta = json.load(f)
+
+        speakers = meta["speakers"]
+        self.num_speakers = len(speakers)
+        self.spk2idx = {spk: idx for idx, spk in enumerate(speakers)}
+        self.split = split
+        self.samples = meta[f"{split}_audios"]
+        self.use_cvec_downsampled = use_cvec_downsampled
+        self.cvec_downsample_rate = cvec_downsample_rate
+
+    def get_frame_len(self, index):
+        return self.samples[index]['frame_len']
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, index):
+
+        sample = self.samples[index]
+        spk = sample['speaker']
+        path = os.path.join(self.data_dir, spk, sample['file_name'])
+        spk_id = torch.LongTensor([self.spk2idx[spk]])  # [1]
+
+        mel = pt_load(path + ".mel.pt").squeeze(0).T
+        rms = pt_load(path + ".rms.pt").squeeze(0)
+        f0 = pt_load(path + ".f0.pt").squeeze(0)
+        cvec = pt_load(path + ".cvec.pt").squeeze(0)
+
+        cvec = linear_interpolate_tensor(cvec, mel.shape[0])
+        if self.use_cvec_downsampled:
+            cvec_ds = cvec[::2, :]
+            cvec_ds = linear_interpolate_tensor(cvec_ds, cvec_ds.shape[0]//self.cvec_downsample_rate)
+            cvec_ds = linear_interpolate_tensor(cvec_ds, mel.shape[0])
+
+        frame_len = mel.shape[0]
+
+        if frame_len > self.max_frame_len:
+            if self.split == "train":
+                # Keep trying until we find a good segment or hit max attempts
+                max_attempts = 10
+                attempt = 0
+                while attempt < max_attempts:
+                    start = random.randint(0, frame_len - self.max_frame_len)
+                    end = start + self.max_frame_len
+                    f0_segment = f0[start:end]
+                    # Check if more than 90% of f0 values are 0
+                    zero_ratio = (f0_segment == 0).float().mean().item()
+                    if zero_ratio < 0.9:  # Found a good segment
+                        break
+                    attempt += 1
+            else:
+                start = 0
+                end = start + self.max_frame_len
+            mel = mel[start:end]
+            rms = rms[start:end]
+            f0 = f0[start:end]
+            cvec = cvec[start:end]
+            if self.use_cvec_downsampled:
+                cvec_ds = cvec_ds[start:end]
+            frame_len = self.max_frame_len
+
+        result = dict(
+            spk_id = spk_id,
+            mel = mel,
+            rms = rms,
+            f0 = f0,
+            cvec = cvec,
+            frame_len = frame_len
+        )
+
+        if self.use_cvec_downsampled:
+            result['cvec_ds'] = cvec_ds
+
+        return result
+
+
+def collate_fn(batch):
+    spk_ids = [item['spk_id'] for item in batch]
+    mels = [item['mel'] for item in batch]
+    rmss = [item['rms'] for item in batch]
+    f0s = [item['f0'] for item in batch]
+    cvecs = [item['cvec'] for item in batch]
+    if 'cvec_ds' in batch[0]:
+        cvecs_ds = [item['cvec_ds'] for item in batch]
+
+    frame_lens = [item['frame_len'] for item in batch]
+
+    # Pad sequences to max length
+    mels_padded = pad_sequence(mels, batch_first=True)
+    rmss_padded = pad_sequence(rmss, batch_first=True)
+    f0s_padded = pad_sequence(f0s, batch_first=True)
+    cvecs_padded = pad_sequence(cvecs, batch_first=True)
+    if 'cvec_ds' in batch[0]:
+        cvecs_ds_padded = pad_sequence(cvecs_ds, batch_first=True)
+
+    spk_ids = torch.cat(spk_ids)
+    frame_len = torch.tensor(frame_lens)
+
+    result = {
+        'spk_id': spk_ids,
+        'mel': mels_padded,
+        'rms': rmss_padded,
+        'f0': f0s_padded,
+        'cvec': cvecs_padded,
+        'frame_len': frame_len
+    }
+
+    if 'cvec_ds' in batch[0]:
+        result['cvec_ds'] = cvecs_ds_padded
+
+    return result
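For context, a hypothetical sketch of how this dataset plugs into a DataLoader; the `data/finetune` directory and `meta_info.json` file name are placeholders, while the meta layout (`speakers`, `train_audios`) follows the code above.

```python
# Illustrative only: wiring SVCDataset and collate_fn into a DataLoader.
from torch.utils.data import DataLoader
from rift_svc.dataset import SVCDataset, collate_fn

train_set = SVCDataset(
    data_dir="data/finetune",                      # hypothetical preprocessed dir
    meta_info_path="data/finetune/meta_info.json", # hypothetical meta file
    max_frame_len=256,
    split="train",
)
loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=collate_fn)
batch = next(iter(loader))
print(batch["mel"].shape, batch["f0"].shape, batch["spk_id"].shape)
```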
rift_svc/dit.py
ADDED
@@ -0,0 +1,227 @@
+import math
+from typing import Union, List
+
+from einops import repeat
+from jaxtyping import Bool, Float, Int
+import torch
+from torch import nn
+import torch.nn.functional as F
+from x_transformers.x_transformers import RotaryEmbedding
+
+from rift_svc.modules import (
+    AdaLayerNormZero_Final,
+    DiTBlock,
+    TimestepEmbedding,
+    LoRALinear,
+)
+
+# Conditional embedding for f0, rms, cvec
+class CondEmbedding(nn.Module):
+    def __init__(self, cvec_dim: int, cond_dim: int):
+        super().__init__()
+        self.cvec_dim = cvec_dim
+        self.cond_dim = cond_dim
+
+        self.f0_embed = nn.Linear(1, cond_dim)
+        self.rms_embed = nn.Linear(1, cond_dim)
+        self.cvec_embed = nn.Linear(cvec_dim, cond_dim)
+        self.out = nn.Linear(cond_dim, cond_dim)
+
+        self.ln_cvec = nn.LayerNorm(cond_dim, elementwise_affine=False, eps=1e-6)
+        self.ln = nn.LayerNorm(cond_dim, elementwise_affine=True, eps=1e-6)
+
+
+    def forward(
+        self,
+        f0: Float[torch.Tensor, "b n"],
+        rms: Float[torch.Tensor, "b n"],
+        cvec: Float[torch.Tensor, "b n d"],
+    ):
+        if f0.ndim == 2:
+            f0 = f0.unsqueeze(-1)
+        if rms.ndim == 2:
+            rms = rms.unsqueeze(-1)
+
+        f0_embed = self.f0_embed(f0 / 1200)
+        rms_embed = self.rms_embed(rms)
+        cvec_embed = self.ln_cvec(self.cvec_embed(cvec))
+
+        cond = f0_embed + rms_embed + cvec_embed
+        cond = self.ln(self.out(cond))
+        return cond
+
+
+# noised input audio and context mixing embedding
+class InputEmbedding(nn.Module):
+    def __init__(self, mel_dim: int, out_dim: int):
+        super().__init__()
+        self.mel_embed = nn.Linear(mel_dim, out_dim)
+        self.proj = nn.Linear(2 * out_dim, out_dim)
+        self.ln = nn.LayerNorm(out_dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x: Float[torch.Tensor, "b n d1"], cond_embed: Float[torch.Tensor, "b n d2"]):
+        x = self.mel_embed(x)
+        x = torch.cat((x, cond_embed), dim = -1)
+        x = self.proj(x)
+        x = self.ln(x)
+        return x
+
+
+# backbone using DiT blocks
+class DiT(nn.Module):
+    def __init__(self,
+                 dim: int, depth: int, head_dim: int = 64, dropout: float = 0.0, ff_mult: int = 4,
+                 n_mel_channels: int = 128, num_speaker: int = 1, cvec_dim: int = 768,
+                 kernel_size: int = 31, zero_null_spk: bool = False,
+                 init_std: float = 1):
+        super().__init__()
+
+        self.num_speaker = num_speaker
+        self.spk_embed = nn.Embedding(num_speaker, dim)
+        self.null_spk_embed = nn.Embedding(1, dim)
+        self.tembed = TimestepEmbedding(dim)
+        self.cond_embed = CondEmbedding(cvec_dim, dim)
+        self.input_embed = InputEmbedding(n_mel_channels, dim)
+
+        self.rotary_embed = RotaryEmbedding(head_dim)
+
+        self.dim = dim
+        self.depth = depth
+        self.transformer_blocks = nn.ModuleList(
+            [
+                DiTBlock(
+                    dim = dim,
+                    head_dim = head_dim,
+                    ff_mult = ff_mult,
+                    dropout = dropout,
+                    kernel_size = kernel_size,
+                )
+                for _ in range(depth)
+            ]
+        )
+
+        self.norm_out = AdaLayerNormZero_Final(dim)
+        self.output = nn.Linear(dim, n_mel_channels)
+
+        self.init_std = init_std
+        self.apply(self._init_weights)
+        for block in self.transformer_blocks:
+            torch.nn.init.constant_(block.attn_norm.proj.weight, 0)
+            torch.nn.init.constant_(block.attn_norm.proj.bias, 0)
+
+        torch.nn.init.constant_(self.norm_out.proj.weight, 0)
+        torch.nn.init.constant_(self.norm_out.proj.bias, 0)
+        torch.nn.init.constant_(self.output.weight, 0)
+        torch.nn.init.constant_(self.output.bias, 0)
+
+        if zero_null_spk:
+            self.null_spk_embed.weight.data.zero_()
+            self.null_spk_embed.requires_grad = False
+
+    def _init_weights(self, module: nn.Module):
+        if isinstance(module, nn.Linear):
+            fan_out, fan_in = module.weight.shape
+            # Spectral parameterization from the [paper](https://arxiv.org/abs/2310.17813).
+            init_std = (self.init_std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in))
+            torch.nn.init.normal_(module.weight, mean=0.0, std=init_std)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Conv1d):
+            # weight shape: (out_channels, in_channels/groups, kernel_size)
+            fan_out = module.weight.shape[0]  # out_channels
+            fan_in = module.weight.shape[1] * module.weight.shape[2]  # (in_channels/groups) * kernel_size
+            init_std = (self.init_std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in))
+            torch.nn.init.normal_(module.weight, mean=0.0, std=init_std)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=self.init_std/math.sqrt(self.dim))
+
+
+    def forward(
+        self,
+        x: Float[torch.Tensor, "b n d1"],  # noised input mel
+        spk: Int[torch.Tensor, "b"],  # speaker
+        f0: Float[torch.Tensor, "b n"],
+        rms: Float[torch.Tensor, "b n"],
+        cvec: Float[torch.Tensor, "b n d2"],
+        time: Float[torch.Tensor, "b"],  # time step
+        drop_speaker: Union[bool, Bool[torch.Tensor, "b"]] = False,
+        mask: Bool[torch.Tensor, "b n"] | None = None,
+        skip_layers: Union[int, List[int], None] = None,
+    ):
+        batch, seq_len = x.shape[0], x.shape[1]
+        if time.ndim == 0:
+            time = repeat(time, ' -> b', b = batch)
+
+        if isinstance(drop_speaker, bool):
+            drop_speaker = torch.full((batch,), drop_speaker, dtype=torch.bool, device=x.device)
+
+        spk_embeds = self.spk_embed(spk)
+        null_spk_embeds = self.null_spk_embed(torch.zeros_like(spk, dtype=torch.long))
+        spk_embeds = torch.where(drop_speaker.unsqueeze(-1), null_spk_embeds, spk_embeds)
+
+        t = self.tembed(time)
+        t = t + spk_embeds
+
+        cond_embed = self.cond_embed(f0, rms, cvec)
+        x = self.input_embed(x, cond_embed)
+
+        rope = self.rotary_embed.forward_from_seq_len(seq_len)
+
+        if skip_layers is not None:
+            if isinstance(skip_layers, int):
+                skip_layers = [skip_layers]
+
+        for i, block in enumerate(self.transformer_blocks):
+            if skip_layers is not None and i in skip_layers:
+                continue
+            x = block(x, t, mask = mask, rope = rope)
+
+        x = self.norm_out(x, t)
+        output = self.output(x)
+
+        return output
+
+
+    def apply_lora(self, rank, alpha):
+        for n, p in self.named_parameters():
+            p.requires_grad = False
+        self.spk_embed.weight.requires_grad = True
+        # Apply LoRA to k_proj and v_proj in each attention block
+        for block in self.transformer_blocks:
+            block.attn.k_proj = LoRALinear(block.attn.k_proj, rank, alpha)
+            block.attn.v_proj = LoRALinear(block.attn.v_proj, rank, alpha)
+
+
+    def merge_lora(self):
+        # Iterate over each transformer block in the DiT backbone
+        for block in self.transformer_blocks:
+            # Merge for k_proj if it is a LoRALinear instance
+            if isinstance(block.attn.k_proj, LoRALinear):
+                with torch.no_grad():
+                    # Compute delta update: B @ A^T
+                    delta = block.attn.k_proj.B @ block.attn.k_proj.A.T
+                    # The underlying linear layer has weight of shape (out_features, in_features)
+                    # and its forward computes x @ weight.T; delta = B @ A^T already has
+                    # that (out_features, in_features) shape, so it is added directly:
+                    block.attn.k_proj.linear.weight.add_(delta)
+                # Replace the LoRALinear module with the merged linear layer
+                block.attn.k_proj = block.attn.k_proj.linear
+
+            # Merge for v_proj in the same way
+            if isinstance(block.attn.v_proj, LoRALinear):
+                with torch.no_grad():
+                    delta = block.attn.v_proj.B @ block.attn.v_proj.A.T
+                    block.attn.v_proj.linear.weight.add_(delta)
+                block.attn.v_proj = block.attn.v_proj.linear
+
+
+    def freeze_adaln_and_tembed(self):
+        for p in self.tembed.parameters():
+            p.requires_grad = False
+        for p in self.norm_out.parameters():
+            p.requires_grad = False
+        for block in self.transformer_blocks:
+            for p in block.attn_norm.parameters():
+                p.requires_grad = False
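A minimal shape sanity check for the backbone (illustrative; the dim/depth here are arbitrary, not the released model's configuration). The output has the same shape as the noised mel input, as a rectified-flow velocity field requires.

```python
# Illustrative: the DiT maps a noised mel plus conditioning back to mel shape.
import torch
from rift_svc.dit import DiT

net = DiT(dim=256, depth=4, num_speaker=2, cvec_dim=768, n_mel_channels=128)
b, n = 2, 128
out = net(
    x=torch.randn(b, n, 128),              # noised mel
    spk=torch.zeros(b, dtype=torch.long),  # speaker ids
    f0=torch.rand(b, n) * 440,             # pitch in Hz
    rms=torch.rand(b, n),                  # loudness envelope
    cvec=torch.randn(b, n, 768),           # content vectors
    time=torch.rand(b),                    # flow time step in [0, 1]
)
print(out.shape)  # torch.Size([2, 128, 128]) — same shape as the input mel
```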
rift_svc/feature_extractors.py
ADDED
@@ -0,0 +1,144 @@
+import torch
+from torch import nn
+from jaxtyping import Float
+from librosa.filters import mel as librosa_mel_fn
+from transformers import HubertModel
+
+
+def dynamic_range_compression_torch(
+    x: Float[torch.Tensor, "n_mels mel_len"],
+    C: float = 1,
+    clip_val: float = 1e-5
+) -> Float[torch.Tensor, "n_mels mel_len"]:
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def spectral_normalize_torch(
+    magnitudes: Float[torch.Tensor, "n_mels mel_len"]
+) -> Float[torch.Tensor, "n_mels mel_len"]:
+    return dynamic_range_compression_torch(magnitudes)
+
+
+mel_basis_cache = {}
+hann_window_cache = {}
+
+
+def get_mel_spectrogram(
+    y: Float[torch.Tensor, "n"],
+    n_fft: int = 2048,
+    num_mels: int = 128,
+    sampling_rate: int = 44100,
+    hop_size: int = 512,
+    win_size: int = 2048,
+    fmin: int = 40,
+    fmax: int | None = 16000,
+    center: bool = False,
+) -> Float[torch.Tensor, "n_mels mel_len"]:
+    """
+    Calculate the mel spectrogram of an input signal.
+    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
+
+    Args:
+        y (torch.Tensor): Input signal with shape (n,).
+        n_fft (int, optional): FFT size. Defaults to 2048.
+        num_mels (int, optional): Number of mel bins. Defaults to 128.
+        sampling_rate (int, optional): Sampling rate of the input signal. Defaults to 44100.
+        hop_size (int, optional): Hop size for STFT. Defaults to 512.
+        win_size (int, optional): Window size for STFT. Defaults to 2048.
+        fmin (int, optional): Minimum frequency for mel filterbank. Defaults to 40.
+        fmax (int | None, optional): Maximum frequency for mel filterbank. If None, librosa uses sr/2.0. Defaults to 16000.
+        center (bool, optional): Whether to pad the input to center the frames. Defaults to False.
+
+    Returns:
+        torch.Tensor: Mel spectrogram with shape (n_mels, mel_len).
+    """
+    if torch.min(y) < -1.0:
+        print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
+    if torch.max(y) > 1.0:
+        print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
+
+    device = y.device
+    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
+
+    if key not in mel_basis_cache:
+        mel = librosa_mel_fn(
+            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+        )
+        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
+        hann_window_cache[key] = torch.hann_window(win_size).to(device)
+
+    mel_basis = mel_basis_cache[key]
+    hann_window = hann_window_cache[key]
+
+    padding = (n_fft - hop_size) // 2
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1), (padding, padding), mode="reflect"
+    ).squeeze(1)
+
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window,
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=True,
+    )
+    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
+
+    mel_spec = torch.matmul(mel_basis, spec)
+    mel_spec = spectral_normalize_torch(mel_spec)
+
+    return mel_spec
+
+
+class RMSExtractor(nn.Module):
+    def __init__(self, hop_length=512, window_length=2048):
+        """
+        Initializes the RMSExtractor with the specified hop_length.
+
+        Args:
+            hop_length (int): Number of samples between successive frames.
+            window_length (int): Analysis window size per frame.
+        """
+        super(RMSExtractor, self).__init__()
+        self.hop_length = hop_length
+        self.window_length = window_length
+
+    def forward(self, inp):
+        """
+        Extracts RMS energy from the input audio tensor.
+
+        Args:
+            inp (Tensor): Audio tensor of shape (batch, samples).
+
+        Returns:
+            Tensor: RMS energy tensor of shape (batch, frames).
+        """
+        # Square the audio signal
+        audio_squared = inp ** 2
+
+        # Use the same padding as mel spectrogram
+        padding = (self.window_length - self.hop_length) // 2
+        audio_padded = torch.nn.functional.pad(
+            audio_squared, (padding, padding), mode='reflect'
+        )
+
+        # Unfold to create frames with window_length instead of hop_length
+        frames = audio_padded.unfold(1, self.window_length, self.hop_length)  # Shape: (batch, frames, window_length)
+
+        # Compute mean energy per frame
+        mean_energy = frames.mean(dim=-1)  # Shape: (batch, frames)
+
+        # Compute RMS by taking square root
+        rms = torch.sqrt(mean_energy)  # Shape: (batch, frames)
+
+        return rms
+
+
+class HubertModelWithFinalProj(HubertModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
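A small illustrative check (not part of the commit) that the two extractors stay frame-aligned: with the shared 512-sample hop and matched reflect padding, the mel spectrogram and RMS envelope of the same clip have the same number of frames, which the dataset code above relies on.

```python
# Illustrative: mel and RMS features of one second of 44.1 kHz audio align.
import torch
from rift_svc.feature_extractors import get_mel_spectrogram, RMSExtractor

wav = torch.randn(1, 44100).clamp(-1, 1)  # (batch, samples), placeholder audio
mel = get_mel_spectrogram(wav)            # (1, 128, frames)
rms = RMSExtractor()(wav)                 # (1, frames)
assert mel.shape[-1] == rms.shape[-1]
print(mel.shape, rms.shape)
```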
rift_svc/lightning_module.py
ADDED
@@ -0,0 +1,389 @@
+import gc
+import os
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchaudio
+import wandb
+from functools import partial
+import inspect
+
+from pytorch_lightning import LightningModule
+
+from rift_svc.metrics import mcd, psnr, si_snr
+from rift_svc.feature_extractors import get_mel_spectrogram
+from rift_svc.nsf_hifigan import NsfHifiGAN
+from rift_svc.utils import draw_mel_specs, l2_grad_norm
+
+
+class RIFTSVCLightningModule(LightningModule):
+    def __init__(
+        self,
+        model,
+        optimizer,
+        cfg,
+        lr_scheduler=None,
+    ):
+        super().__init__()
+        self.model = model
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.cfg = cfg
+        self.eval_sample_steps = cfg['training']['eval_sample_steps']
+        self.model.sample = partial(
+            self.model.sample,
+            steps=self.eval_sample_steps,
+        )
+        self.log_media_per_steps = cfg['training']['log_media_per_steps']
+        self.drop_spk_prob = cfg['training']['drop_spk_prob']
+
+        self.vocoder = None
+        self.save_hyperparameters(ignore=['model', 'optimizer', 'vocoder'])
+
+    def configure_optimizers(self):
+        if self.lr_scheduler is None:
+            return self.optimizer
+        return {
+            "optimizer": self.optimizer,
+            "lr_scheduler": {
+                "scheduler": self.lr_scheduler,
+                "interval": "step",
+            }
+        }
+
+    def training_step(self, batch, batch_idx):
+        mel = batch['mel']
+        spk_id = batch['spk_id']
+        f0 = batch['f0']
+        rms = batch['rms']
+        cvec = batch['cvec']
+        frame_len = batch['frame_len']
+
+        drop_speaker = False
+        if self.drop_spk_prob > 0:
+            batch_size = spk_id.shape[0]
+            num_drop = int(batch_size * self.drop_spk_prob)
+            drop_speaker = torch.zeros(batch_size, dtype=torch.bool, device=spk_id.device)
+            drop_speaker[:num_drop] = True
+            # Randomly shuffle the drop mask
+            drop_speaker = drop_speaker[torch.randperm(batch_size)]
+
+        loss, _ = self.model(
+            mel,
+            spk_id=spk_id,
+            f0=f0,
+            rms=rms,
+            cvec=cvec,
+            drop_speaker=drop_speaker,
+            frame_len=frame_len,
+        )
+
+        # Log metrics - compatible with both loggers
+        self._log_scalar("train/loss", loss.item(), prog_bar=True)
+
+        return loss
+
+    def on_validation_start(self):
+        if hasattr(self.optimizer, 'eval'):
+            self.optimizer.eval()
+        if not self.trainer.is_global_zero:
+            return
+
+        if self.vocoder is None:
+            self.vocoder = NsfHifiGAN(
+                'pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt').to(self.device)
+        else:
+            self.vocoder = self.vocoder.to(self.device)
+
+        self.mcd = []
+        self.si_snr = []
+        self.psnr = []
+        self.mse = []
+
+
+    def on_validation_end(self, log=True):
+        if hasattr(self.optimizer, 'eval'):
+            self.optimizer.train()
+        if not self.trainer.is_global_zero:
+            return
+
+        if self.vocoder is not None:
+            self.vocoder = self.vocoder.cpu()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        metrics = {
+            'val/mcd': np.mean(self.mcd),
+            'val/si_snr': np.mean(self.si_snr),
+            'val/psnr': np.mean(self.psnr),
+            'val/mse': np.mean(self.mse)
+        }
+
+        if log:
+            # Log metrics - compatible with both loggers
+            for metric_name, metric_value in metrics.items():
+                self._log_scalar(metric_name, metric_value)
+
+
+    def validation_step(self, batch, batch_idx, log=True):
+        """
+        Process validation step and log metrics and media.
+
+        Args:
+            batch: Input batch
+            batch_idx: Batch index
+            log: Whether to log or not
+        """
+        # Skip if not the main process or logging is disabled
+        if not self.trainer.is_global_zero:
+            return
+
+        # Get step and interval info
+        global_step = self.global_step
+        log_media_every_n_steps = self.log_media_every_n_steps
+
+        # Extract input data
+        spk_id = batch['spk_id']
+        mel_gt = batch['mel']
+        rms = batch['rms']
+        f0 = batch['f0']
+        cvec = batch['cvec']
+        frame_len = batch['frame_len']
+        cvec_ds = batch.get('cvec_ds', None)
+
+        # Generate output
+        mel_gen, _ = self.model.sample(
+            src_mel=mel_gt,
+            spk_id=spk_id,
+            f0=f0,
+            rms=rms,
+            cvec=cvec,
+            frame_len=frame_len,
+            bad_cvec=cvec_ds,
+        )
+        mel_gen = mel_gen.float()
+        mel_gt = mel_gt.float()
+
+        # Process each sample in the batch
+        for i in range(mel_gen.shape[0]):
+            sample_idx = batch_idx * mel_gen.shape[0] + i
+
+            # Generate audio using vocoder
+            wav_gen = self.vocoder(mel_gen[i:i+1, :frame_len[i], :].transpose(1, 2), f0[i:i+1, :frame_len[i]])
+            wav_gt = self.vocoder(mel_gt[i:i+1, :frame_len[i], :].transpose(1, 2), f0[i:i+1, :frame_len[i]])
+            wav_gen = wav_gen.squeeze(0)
+            wav_gt = wav_gt.squeeze(0)
+
+            # Generate mel spectrograms
+            mel_gen_i = get_mel_spectrogram(wav_gen).transpose(1, 2)
+            mel_gt_i = get_mel_spectrogram(wav_gt).transpose(1, 2)
+
+            # Clip values to valid range
+            mel_min, mel_max = self.model.mel_min, self.model.mel_max
+            mel_gen_i = torch.clip(mel_gen_i, min=mel_min, max=mel_max)
+            mel_gt_i = torch.clip(mel_gt_i, min=mel_min, max=mel_max)
+
+            # Calculate metrics
+            self.mcd.append(mcd(mel_gen_i, mel_gt_i).cpu().item())
+            self.si_snr.append(si_snr(mel_gen_i, mel_gt_i).cpu().item())
+            self.psnr.append(psnr(mel_gen_i, mel_gt_i).cpu().item())
+            self.mse.append(F.mse_loss(mel_gen_i, mel_gt_i).cpu().item())
+
+            if log:
+                # Create cache directory if it doesn't exist
+                os.makedirs('.cache', exist_ok=True)
+
+                # Log generated audio at specified intervals
+                if global_step % log_media_every_n_steps == 0:
+                    audio_path = f".cache/spk-{spk_id[i].item()}_{sample_idx}_gen.wav"
+                    torchaudio.save(audio_path, wav_gen.cpu().to(torch.float32), 44100)
+                    self._log_audio(self.logger, f"val-audio/spk-{spk_id[i].item()}_{sample_idx}-gen", audio_path, global_step)
+
+                # Log ground truth audio only at the first step
+                if global_step == 0:
+                    gt_audio_path = f".cache/spk-{spk_id[i].item()}_{sample_idx}_gt.wav"
+                    torchaudio.save(gt_audio_path, wav_gt.cpu().to(torch.float32), 44100)
+                    self._log_audio(self.logger, f"val-audio/spk-{spk_id[i].item()}_{sample_idx}-gt", gt_audio_path, global_step)
+
+                # Log mel spectrograms at specified intervals
+                if global_step % log_media_every_n_steps == 0:
+                    # Create mel spectrogram visualization
+                    data_gt = mel_gt_i.squeeze().T.cpu().numpy()
+                    data_gen = mel_gen_i.squeeze().T.cpu().numpy()
+                    data_abs_diff = data_gen - data_gt
+                    cache_path = f".cache/{sample_idx}_mel.jpg"
+                    draw_mel_specs(data_gt, data_gen, data_abs_diff, cache_path)
+                    self._log_image(self.logger, f"val-mel/{sample_idx}_mel", cache_path, global_step)
+
+    def on_test_start(self):
+        self.on_validation_start()
+
+    def on_test_end(self):
+        self.on_validation_end(log=False)
+
+    def test_step(self, batch, batch_idx):
+        self.validation_step(batch, batch_idx, log=False)
+
+    def on_before_optimizer_step(self, optimizer):
+        # Calculate gradient norm
+        norm = l2_grad_norm(self.model)
+
+        # Log gradient norm
+        self._log_scalar("train/grad_norm", norm)
+
+    @property
+    def global_step(self):
+        return self.trainer.global_step
+
+    @property
+    def log_media_every_n_steps(self):
+        if self.log_media_per_steps is not None:
+            return self.log_media_per_steps
+        if self.save_every_n_steps is None:
+            return self.trainer.val_check_interval
+        return self.save_every_n_steps
+
+    @property
+    def save_every_n_steps(self):
+        for callback in self.trainer.callbacks:
+            if hasattr(callback, '_every_n_train_steps'):
+                return callback._every_n_train_steps
+        return None
+
+    @property
+    def is_using_wandb(self):
+        """
+        Check if WandB logger is being used.
+
+        Returns:
+            bool: True if WandB logger is being used, False otherwise
+        """
+        from pytorch_lightning.loggers import WandbLogger
+        if isinstance(self.logger, WandbLogger):
+            return True
+        return False
+
+    @property
+    def is_using_tensorboard(self):
+        """
+        Check if TensorBoard logger is being used.
+
+        Returns:
+            bool: True if TensorBoard logger is being used, False otherwise
+        """
+        from pytorch_lightning.loggers import TensorBoardLogger
+        if isinstance(self.logger, TensorBoardLogger):
+            return True
+        return False
+
+    @property
+    def logger_type(self):
+        """
+        Get a string representation of the logger type.
+
+        Returns:
+            str: 'wandb', 'tensorboard', or 'unknown'
+        """
+        if self.is_using_wandb:
+            return 'wandb'
+        elif self.is_using_tensorboard:
+            return 'tensorboard'
+        else:
+            return 'unknown'
+
+    def state_dict(self, *args, **kwargs):
+        # Temporarily store vocoder
+        vocoder = self.vocoder
+        self.vocoder = None
+
+        # Get state dict without vocoder
+        state = super().state_dict(*args, **kwargs)
+
+        # Restore vocoder
+        self.vocoder = vocoder
+        return state
+
+    # Add helper methods for logging with different logger types
+    def _log_scalar(self, name, value, step=None, **kwargs):
+        """
+        Log a scalar value to the appropriate logger.
+
+        Args:
+            name: Name of the metric
+            value: Value of the metric
+            step: Step value (defaults to current global step if None)
+            **kwargs: Additional arguments to pass to the logger
+        """
+        if step is None:
+            step = self.global_step
+
+        # Special handling for on_validation_end or on_test_end
+        # Get the caller function name to determine if we're in on_validation_end
+        caller_frame = inspect.currentframe().f_back
+        caller_function = caller_frame.f_code.co_name
+
+        if caller_function in ['on_validation_end', 'on_test_end']:
+            # Use logger.experiment directly as self.log() is not allowed in these hooks
+            if self.is_using_wandb:
+                self.logger.experiment.log({name: value}, step=step)
+            elif self.is_using_tensorboard:
+                self.logger.experiment.add_scalar(name, value, step)
+            # Add other logger types here if needed
+        else:
+            # Use PyTorch Lightning's built-in logging system for scalars
+            # This handles different logger types automatically
+            self.log(name, value, **kwargs)
+
+    def _log_audio(self, logger, name, file_path, step):
+        """
+        Log audio to the appropriate logger.
+
+        Args:
+            logger: The logger instance
+            name: Name of the audio
+            file_path: Path to the audio file
+            step: Step value
+        """
+        try:
+            if hasattr(logger, 'experiment') and hasattr(logger.experiment, 'log'):
+                # WandbLogger
+                import wandb
+                logger.experiment.log({
+                    name: wandb.Audio(file_path, sample_rate=44100)
+                }, step=step)
+            elif hasattr(logger, 'experiment') and hasattr(logger.experiment, 'add_audio'):
+                # TensorBoardLogger
+                import soundfile as sf
+                audio, sample_rate = sf.read(file_path)
+                logger.experiment.add_audio(name, audio, step, sample_rate=44100)
+        except Exception as e:
+            print(f"Warning: Failed to log audio {name}: {e}")
+
+    def _log_image(self, logger, name, file_path, step):
+        """
+        Log an image to the appropriate logger.
+
+        Args:
+            logger: The logger instance
+            name: Name of the image
+            file_path: Path to the image file
+            step: Step value
+        """
+        try:
+            if hasattr(logger, 'experiment') and hasattr(logger.experiment, 'log'):
+                # WandbLogger
+                import wandb
+                logger.experiment.log({
+                    name: wandb.Image(file_path)
+                }, step=step)
+            elif hasattr(logger, 'experiment') and hasattr(logger.experiment, 'add_image'):
+                # TensorBoardLogger
+                import PIL.Image
+                import numpy as np
+                import torch
+                image = PIL.Image.open(file_path)
+                image_array = np.array(image)
+                image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)  # HWC to CHW
+                logger.experiment.add_image(name, image_tensor, step)
+        except Exception as e:
+            print(f"Warning: Failed to log image {name}: {e}")
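A rough wiring sketch for training with this module, heavily hedged: `RF(dit)` stands in for the rectified-flow wrapper from rift_svc/rf.py, whose exact constructor is not shown in this excerpt, and the cfg keys are just the ones `__init__` above reads.

```python
# Illustrative only: assemble model, optimizer and Trainer. RF(dit) is an
# assumed constructor — check rift_svc/rf.py for the actual signature.
import torch
from pytorch_lightning import Trainer
from rift_svc import DiT, RF, RIFTSVCLightningModule

dit = DiT(dim=768, depth=12, num_speaker=1)
rf = RF(dit)  # assumed: rectified-flow wrapper exposing forward() and sample()
optim = torch.optim.AdamW(rf.parameters(), lr=1e-4)
cfg = {"training": {"eval_sample_steps": 32,
                    "log_media_per_steps": 1000,
                    "drop_spk_prob": 0.2}}

module = RIFTSVCLightningModule(model=rf, optimizer=optim, cfg=cfg)
trainer = Trainer(max_steps=100, accelerator="auto")
# trainer.fit(module, train_dataloaders=loader)  # loader from the dataset sketch above
```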
rift_svc/metrics.py
ADDED
@@ -0,0 +1,71 @@
+import torch
+
+
+def psnr(estimated, target, max_val=None):
+    """Calculate Peak Signal-to-Noise Ratio (PSNR)
+    Args:
+        estimated (torch.Tensor): Estimated mel spectrogram [B, len, n_mel]
+        target (torch.Tensor): Target mel spectrogram [B, len, n_mel]
+        max_val (float): Maximum value of the signal. If None, uses max of target
+    Returns:
+        torch.Tensor: PSNR value in dB [B]
+    """
+    if max_val is None:
+        # Use the maximum absolute value between both tensors
+        max_val = max(torch.abs(target).max(), torch.abs(estimated).max())
+
+    # Ensure max_val is not zero
+    max_val = max(max_val, torch.finfo(target.dtype).eps)
+
+    mse = torch.mean((estimated - target) ** 2, dim=(1, 2))
+    # Add eps to avoid log of zero
+    eps = torch.finfo(target.dtype).eps
+    psnr = 20 * torch.log10(max_val + eps) - 10 * torch.log10(mse + eps)
+    return psnr
+
+def si_snr(estimated, target, eps=1e-8):
+    """Calculate Scale-Invariant Signal-to-Noise Ratio (SI-SNR)
+    Args:
+        estimated (torch.Tensor): Estimated mel spectrogram [B, len, n_mel]
+        target (torch.Tensor): Target mel spectrogram [B, len, n_mel]
+        eps (float): Small value to avoid division by zero
+    Returns:
+        torch.Tensor: SI-SNR value in dB [B]
+    """
+    # Flatten the mel dimension
+    estimated = estimated.reshape(estimated.shape[0], -1)
+    target = target.reshape(target.shape[0], -1)
+
+    # Zero-mean normalization
+    estimated = estimated - torch.mean(estimated, dim=1, keepdim=True)
+    target = target - torch.mean(target, dim=1, keepdim=True)
+
+    # SI-SNR
+    alpha = torch.sum(estimated * target, dim=1, keepdim=True) / (
+        torch.sum(target ** 2, dim=1, keepdim=True) + eps)
+    target_scaled = alpha * target
+
+    si_snr = 10 * torch.log10(
+        torch.sum(target_scaled ** 2, dim=1) /
+        (torch.sum((estimated - target_scaled) ** 2, dim=1) + eps) + eps
+    )
+    return si_snr
+
+def mcd(estimated, target):
+    """Calculate Mel-Cepstral Distortion (MCD)
+    Args:
+        estimated (torch.Tensor): Estimated mel spectrogram [B, len, n_mel]
+        target (torch.Tensor): Target mel spectrogram [B, len, n_mel]
+    Returns:
+        torch.Tensor: MCD value [B], averaged over time steps
+    """
+    # Convert to log scale
+    estimated = torch.log10(torch.clamp(estimated, min=1e-8))
+    target = torch.log10(torch.clamp(target, min=1e-8))
+
+    # Calculate MCD
+    diff = estimated - target
+    mcd = torch.sqrt(2 * torch.sum(diff ** 2, dim=2))  # [B, len]
+    # Average over time dimension
+    mcd = mcd.mean(dim=1)  # [B]
+    return mcd
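A quick sanity check of the three metrics on random mel-shaped tensors (a minimal sketch; the `rift_svc.metrics` import path follows this repo's layout, and the shapes are illustrative):

    import torch
    from rift_svc.metrics import psnr, si_snr, mcd

    target = torch.randn(2, 100, 128)                  # [B, len, n_mel]
    estimated = target + 0.1 * torch.randn_like(target)

    print(psnr(estimated, target))                     # per-sample PSNR in dB, shape [2]
    print(si_snr(estimated, target))                   # per-sample SI-SNR in dB, shape [2]
    print(mcd(estimated.abs(), target.abs()))          # mcd log10-compresses its inputs, so keep them positive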
rift_svc/modules.py
ADDED
@@ -0,0 +1,261 @@
+import math
+
+from einops import rearrange
+from jaxtyping import Float, Bool
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from x_transformers.x_transformers import apply_rotary_pos_emb
+
+
+class LoRALinear(nn.Module):
+    def __init__(self, linear, rank, alpha):
+        super().__init__()
+        self.linear = linear
+        self.rank = rank
+        self.alpha = alpha
+        self.scale = alpha / math.sqrt(rank)
+        in_features = linear.in_features
+        out_features = linear.out_features
+        self.A = nn.Parameter(torch.zeros(in_features, rank))
+        self.B = nn.Parameter(torch.zeros(out_features, rank))
+        # Initialize LoRA parameters
+        nn.init.normal_(self.A, mean=0, std=math.sqrt(self.rank) / self.linear.in_features)
+        nn.init.zeros_(self.B)
+        # Freeze original linear layer parameters
+        self.linear.weight.requires_grad = False
+        if self.linear.bias is not None:
+            self.linear.bias.requires_grad = False
+
+    def forward(self, x):
+        original_out = self.linear(x)
+        lora_out = (x @ self.A) @ self.B.T
+        return original_out + lora_out * self.scale
+
+
+# AdaLayerNormZero
+# return with modulated x for attn input, and params for later mlp modulation
+class AdaLayerNormZero(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.proj = nn.Linear(dim, dim * 6)
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x, emb=None):
+        emb = self.proj(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
+
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+# AdaLayerNormZero for final layer
+# return only with modulated x for attn input, because there is no more mlp modulation
+class AdaLayerNormZero_Final(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.proj = nn.Linear(dim, dim * 2)
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x, emb):
+        emb = self.proj(self.silu(emb))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+# ReLU^2
+class ReLU2(nn.Module):
+    def forward(self, x):
+        return F.relu(x, inplace=True).square()
+
+# FeedForward
+class ConvMLP(nn.Module):
+    def __init__(self, dim: int, dim_out: int | None = None, mult: float = 4, dropout: float = 0.0, kernel_size: int = 7):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+
+        #self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
+        self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim)
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.activation = ReLU2()
+        self.dropout = nn.Dropout(dropout)
+        self.mlp_proj = nn.Linear(dim, inner_dim)
+        self.mlp_out = nn.Linear(inner_dim, dim_out)
+
+    def forward(self, x):
+        x = x.permute(0, 2, 1)
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = self.mlp_proj(x)
+        x = self.activation(x)
+        x = self.dropout(x)
+        x = self.mlp_out(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        head_dim: int = 64,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.dim = dim
+        assert dim % head_dim == 0
+        self.head_dim = head_dim
+        self.num_heads = int(dim // head_dim)
+        self.inner_dim = dim
+        self.dropout = dropout
+        self.scale = 1 / dim
+
+        self.q_proj = nn.Linear(dim, self.inner_dim)
+        self.k_proj = nn.Linear(dim, self.inner_dim)
+        self.v_proj = nn.Linear(dim, self.inner_dim)
+
+        self.norm_q = nn.LayerNorm(self.head_dim, elementwise_affine=False, eps=1e-6)
+        self.norm_k = nn.LayerNorm(self.head_dim, elementwise_affine=False, eps=1e-6)
+
+        self.attn_out = nn.Linear(self.inner_dim, dim)
+        self.attn_dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        x: Float[torch.Tensor, "b n d"],
+        mask: Bool[torch.Tensor, "b n"] | None = None,
+        rope=None,
+    ) -> Float[torch.Tensor, "b n d"]:
+        batch_size = x.shape[0]
+
+        # projections
+        query = self.q_proj(x)
+        key = self.k_proj(x)
+        value = self.v_proj(x)
+
+        # apply rotary position embedding
+        if rope is not None:
+            freqs, xpos_scale = rope
+            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
+
+            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
+            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
+
+        # attention
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.num_heads
+        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
+
+        query = self.norm_q(query)
+        key = self.norm_k(key)
+
+        # mask
+        if mask is not None:
+            attn_mask = mask
+            attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
+            attn_mask = attn_mask.expand(batch_size, self.num_heads, query.shape[-2], key.shape[-2])
+        else:
+            attn_mask = None
+
+        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False, scale=self.scale)
+        x = x.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
+        x = x.to(query.dtype)
+
+        # linear proj and dropout
+        x = self.attn_out(x)
+        x = self.attn_dropout(x)
+
+        if mask is not None:
+            mask = rearrange(mask, 'b n -> b n 1')
+            x = x.masked_fill(~mask, 0.)
+
+        return x
+
+
+# DiT Block
+class DiTBlock(nn.Module):
+
+    def __init__(
+        self, dim: int, head_dim: int, ff_mult: float = 4,
+        dropout: float = 0.0, kernel_size: int = 31):
+        super().__init__()
+
+        self.attn_norm = AdaLayerNormZero(dim)
+        self.attn = Attention(
+            dim=dim,
+            head_dim=head_dim,
+            dropout=dropout,
+        )
+
+        self.mlp_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.mlp = ConvMLP(dim=dim, mult=ff_mult, dropout=dropout, kernel_size=kernel_size)
+
+    def forward(
+        self,
+        x: Float[torch.Tensor, "b n d"],
+        t: Float[torch.Tensor, "b d"],
+        mask: Bool[torch.Tensor, "b n"] | None = None,
+        rope: Float[torch.Tensor, "b d"] | None = None,
+    ) -> Float[torch.Tensor, "b n d"]:
+        # pre-norm & modulation for attention input
+        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
+
+        # attention
+        attn_output = self.attn(x=norm, mask=mask, rope=rope)
+
+        # process attention output for input x
+        x = x + gate_msa.unsqueeze(1) * attn_output
+
+        norm = self.mlp_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        mlp_output = self.mlp(norm)
+        x = x + gate_mlp.unsqueeze(1) * mlp_output
+
+        return x
+
+
+# sinusoidal position embedding
+class SinusPositionEmbedding(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x: Float[torch.Tensor, "b"], scale: float = 1000) -> Float[torch.Tensor, "b d"]:
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+# time step conditioning embedding
+class TimestepEmbedding(nn.Module):
+    def __init__(self, dim: int, freq_embed_dim: int = 256):
+        super().__init__()
+        self.time2emb = SinusPositionEmbedding(freq_embed_dim)
+        self.time_emb = nn.Linear(freq_embed_dim, dim)
+        self.act = nn.SiLU()
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, timestep: Float[torch.Tensor, "b"]) -> Float[torch.Tensor, "b d"]:
+        time = self.time2emb(timestep)
+        time = self.time_emb(time)
+        time = self.act(time)
+        time = self.proj(time)
+        return time
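To illustrate how `LoRALinear` wraps an existing projection (a hedged sketch; the dimensions are made up and this is not how the training scripts wire it up):

    import torch
    from torch import nn
    from rift_svc.modules import LoRALinear

    base = nn.Linear(512, 512)
    lora = LoRALinear(base, rank=16, alpha=16)

    x = torch.randn(2, 100, 512)
    y = lora(x)  # frozen base output plus the scaled low-rank update (x @ A) @ B.T
    print([n for n, p in lora.named_parameters() if p.requires_grad])  # ['A', 'B']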
rift_svc/nsf_hifigan/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .nvSTFT import STFT
+from .vocoder import Vocoder, NsfHifiGAN
rift_svc/nsf_hifigan/env.py
ADDED
@@ -0,0 +1,15 @@
+import os
+import shutil
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def build_env(config, config_name, path):
+    t_path = os.path.join(path, config_name)
+    if config != t_path:
+        os.makedirs(path, exist_ok=True)
+        shutil.copyfile(config, os.path.join(path, config_name))
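`AttrDict` just aliases attribute access to dictionary keys, which is how the HiFi-GAN `config.json` is consumed downstream. A two-line illustration with hypothetical keys:

    h = AttrDict({"sampling_rate": 44100, "hop_size": 512})
    assert h.sampling_rate == h["sampling_rate"] == 44100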
rift_svc/nsf_hifigan/models.py
ADDED
@@ -0,0 +1,427 @@
+import os
+import json
+from .env import AttrDict
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from .utils import init_weights, get_padding
+
+LRELU_SLOPE = 0.1
+
+
+def load_model(model_path, device='cuda'):
+    h = load_config(model_path)
+
+    generator = Generator(h).to(device)
+
+    cp_dict = torch.load(model_path, map_location=device, weights_only=True)
+    generator.load_state_dict(cp_dict['generator'])
+    generator.eval()
+    generator.remove_weight_norm()
+    del cp_dict
+    return generator, h
+
+def load_config(model_path):
+    config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
+    with open(config_file) as f:
+        data = f.read()
+
+    json_config = json.loads(data)
+    h = AttrDict(json_config)
+    return h
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.h = h
+        self.convs1 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                               padding=get_padding(kernel_size, dilation[2])))
+        ])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1)))
+        ])
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.h = h
+        self.convs = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1])))
+        ])
+        self.convs.apply(init_weights)
+
+    def forward(self, x):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class SineGen(torch.nn.Module):
+    """ Definition of sine generator
+    SineGen(samp_rate, harmonic_num = 0,
+            sine_amp = 0.1, noise_std = 0.003,
+            voiced_threshold = 0,
+            flag_for_pulse=False)
+    samp_rate: sampling rate in Hz
+    harmonic_num: number of harmonic overtones (default 0)
+    sine_amp: amplitude of sine-waveform (default 0.1)
+    noise_std: std of Gaussian noise (default 0.003)
+    voiced_threshold: F0 threshold for U/V classification (default 0)
+    flag_for_pulse: this SineGen is used inside PulseGen (default False)
+    Note: when flag_for_pulse is True, the first time step of a voiced
+    segment is always sin(np.pi) or cos(0)
+    """
+
+    def __init__(self, samp_rate, harmonic_num=0,
+                 sine_amp=0.1, noise_std=0.003,
+                 voiced_threshold=0):
+        super(SineGen, self).__init__()
+        self.sine_amp = sine_amp
+        self.noise_std = noise_std
+        self.harmonic_num = harmonic_num
+        self.dim = self.harmonic_num + 1
+        self.sampling_rate = samp_rate
+        self.voiced_threshold = voiced_threshold
+
+    def _f02uv(self, f0):
+        # generate uv signal
+        uv = torch.ones_like(f0)
+        uv = uv * (f0 > self.voiced_threshold)
+        return uv
+
+    def _f02sine(self, f0, upp):
+        """ f0: (batchsize, length, dim)
+        where dim indicates fundamental tone and overtones
+        """
+        rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, device=f0.device)
+        rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5
+        rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0)
+        rad += F.pad(rad_acc, (0, 0, 1, -1))
+        rad = rad.reshape(f0.shape[0], -1, 1)
+        rad = torch.multiply(rad, torch.arange(1, self.dim + 1, device=f0.device).reshape(1, 1, -1))
+        rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
+        rand_ini[..., 0] = 0
+        rad += rand_ini
+        sines = torch.sin(2 * np.pi * rad)
+        return sines
+
+    @torch.no_grad()
+    def forward(self, f0, upp):
+        """ sine_tensor, uv = forward(f0)
+        input F0: tensor(batchsize=1, length, dim=1)
+        f0 for unvoiced steps should be 0
+        output sine_tensor: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
+        """
+        f0 = f0.unsqueeze(-1)
+        sine_waves = self._f02sine(f0, upp) * self.sine_amp
+        uv = (f0 > self.voiced_threshold).float()
+        uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
+        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+        noise = noise_amp * torch.randn_like(sine_waves)
+        sine_waves = sine_waves * uv + noise
+        return sine_waves
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+    """ SourceModule for hn-nsf
+    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0)
+    sampling_rate: sampling_rate in Hz
+    harmonic_num: number of harmonics above F0 (default: 0)
+    sine_amp: amplitude of sine source signal (default: 0.1)
+    add_noise_std: std of additive Gaussian noise (default: 0.003)
+        note that amplitude of noise in unvoiced is decided
+        by sine_amp
+    voiced_threshold: threshold to set U/V given F0 (default: 0)
+    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+    F0_sampled (batchsize, length, 1)
+    Sine_source (batchsize, length, 1)
+    noise_source (batchsize, length 1)
+    uv (batchsize, length, 1)
+    """
+
+    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0):
+        super(SourceModuleHnNSF, self).__init__()
+
+        self.sine_amp = sine_amp
+        self.noise_std = add_noise_std
+
+        # to produce sine waveforms
+        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+                                 sine_amp, add_noise_std, voiced_threshod)
+
+        # to merge source harmonics into a single excitation
+        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+        self.l_tanh = torch.nn.Tanh()
+
+    def forward(self, x, upp):
+        sine_wavs = self.l_sin_gen(x, upp)
+        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+        return sine_merge
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, h):
+        super(Generator, self).__init__()
+        self.h = h
+        self.num_kernels = len(h.resblock_kernel_sizes)
+        self.num_upsamples = len(h.upsample_rates)
+        self.m_source = SourceModuleHnNSF(
+            sampling_rate=h.sampling_rate,
+            harmonic_num=8
+        )
+        self.noise_convs = nn.ModuleList()
+        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
+        resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+            c_cur = h.upsample_initial_channel // (2 ** (i + 1))
+            self.ups.append(weight_norm(
+                ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
+                                k, u, padding=(k - u) // 2)))
+            if i + 1 < len(h.upsample_rates):
+                stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
+                self.noise_convs.append(Conv1d(
+                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
+            else:
+                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+        self.resblocks = nn.ModuleList()
+        ch = h.upsample_initial_channel
+        for i in range(len(self.ups)):
+            ch //= 2
+            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                self.resblocks.append(resblock(h, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+        self.upp = int(np.prod(h.upsample_rates))
+
+    def forward(self, x, f0):
+        har_source = self.m_source(f0, self.upp).transpose(1, 2)
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            x_source = self.noise_convs[i](har_source)
+            x = x + x_source
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+
+
+class DiscriminatorP(torch.nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList([
+            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+        ])
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+    def __init__(self, periods=None):
+        super(MultiPeriodDiscriminator, self).__init__()
+        self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
+        self.discriminators = nn.ModuleList()
+        for period in self.periods:
+            self.discriminators.append(DiscriminatorP(period))
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorS(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList([
+            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+        ])
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiScaleDiscriminator(torch.nn.Module):
+    def __init__(self):
+        super(MultiScaleDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            DiscriminatorS(use_spectral_norm=True),
+            DiscriminatorS(),
+            DiscriminatorS(),
+        ])
+        self.meanpools = nn.ModuleList([
+            AvgPool1d(4, 2, padding=2),
+            AvgPool1d(4, 2, padding=2)
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+def feature_loss(fmap_r, fmap_g):
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            loss += torch.mean(torch.abs(rl - gl))
+
+    return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    loss = 0
+    r_losses = []
+    g_losses = []
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        r_loss = torch.mean((1 - dr) ** 2)
+        g_loss = torch.mean(dg ** 2)
+        loss += (r_loss + g_loss)
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+    loss = 0
+    gen_losses = []
+    for dg in disc_outputs:
+        l = torch.mean((1 - dg) ** 2)
+        gen_losses.append(l)
+        loss += l
+
+    return loss, gen_losses
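As a rough sketch of what the harmonic-plus-noise source module produces (shapes only; `upp` is the total upsampling factor `prod(upsample_rates)`, e.g. 512 for the 44.1 kHz / hop-512 checkpoint shipped under `pretrained/`):

    import torch

    src = SourceModuleHnNSF(sampling_rate=44100, harmonic_num=8)
    f0 = torch.full((1, 100), 220.0)   # frame-level F0 in Hz, [B, n_frames]
    excitation = src(f0, upp=512)      # tanh-merged harmonics, [B, n_frames * upp, 1]
    print(excitation.shape)            # torch.Size([1, 51200, 1])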
rift_svc/nsf_hifigan/nvSTFT.py
ADDED
@@ -0,0 +1,124 @@
+import math
+import os
+os.environ["LRU_CACHE_CAPACITY"] = "3"
+import random
+import torch
+import torch.utils.data
+import numpy as np
+import librosa
+from librosa.util import normalize
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+import soundfile as sf
+import torch.nn.functional as F
+
+def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
+    sampling_rate = None
+    try:
+        data, sampling_rate = sf.read(full_path, always_2d=True)
+    except Exception as ex:
+        print(f"'{full_path}' failed to load.\nException:")
+        print(ex)
+        if return_empty_on_exception:
+            return [], sampling_rate or target_sr or 48000
+        else:
+            raise Exception(ex)
+
+    if len(data.shape) > 1:
+        data = data[:, 0]
+        assert len(data) > 2  # check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
+
+    if np.issubdtype(data.dtype, np.integer):  # if audio data is type int
+        max_mag = -np.iinfo(data.dtype).min  # maximum magnitude = min possible value of intXX
+    else:  # if audio data is type fp32
+        max_mag = max(np.amax(data), -np.amin(data))
+        max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0)  # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
+
+    data = torch.FloatTensor(data.astype(np.float32))/max_mag
+
+    if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:  # resample will crash with inf/NaN inputs; return an empty array instead of raising
+        return [], sampling_rate or target_sr or 48000
+    if target_sr is not None and sampling_rate != target_sr:
+        data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
+        sampling_rate = target_sr
+
+    return data, sampling_rate
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+class STFT():
+    def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
+        self.target_sr = sr
+
+        self.n_mels = n_mels
+        self.n_fft = n_fft
+        self.win_size = win_size
+        self.hop_length = hop_length
+        self.fmin = fmin
+        self.fmax = fmax
+        self.clip_val = clip_val
+        self.mel_basis = {}
+        self.hann_window = {}
+
+    def get_mel(self, y, keyshift=0, speed=1, center=False):
+        sampling_rate = self.target_sr
+        n_mels = self.n_mels
+        n_fft = self.n_fft
+        win_size = self.win_size
+        hop_length = self.hop_length
+        fmin = self.fmin
+        fmax = self.fmax
+        clip_val = self.clip_val
+
+        factor = 2 ** (keyshift / 12)
+        n_fft_new = int(np.round(n_fft * factor))
+        win_size_new = int(np.round(win_size * factor))
+        hop_length_new = int(np.round(hop_length * speed))
+
+        mel_basis_key = str(fmax)+'_'+str(y.device)
+        if mel_basis_key not in self.mel_basis:
+            mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
+            self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
+
+        keyshift_key = str(keyshift)+'_'+str(y.device)
+        if keyshift_key not in self.hann_window:
+            self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
+
+        pad_left = (win_size_new - hop_length_new) // 2
+        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
+        if pad_right < y.size(-1):
+            mode = 'reflect'
+        else:
+            mode = 'constant'
+        y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode)
+        y = y.squeeze(1)
+
+        spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key],
+                          center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+        spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
+        if keyshift != 0:
+            size = n_fft // 2 + 1
+            resize = spec.size(1)
+            if resize < size:
+                spec = F.pad(spec, (0, 0, 0, size-resize))
+            spec = spec[:, :size, :] * win_size / win_size_new
+        spec = torch.matmul(self.mel_basis[mel_basis_key], spec)
+        spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
+        return spec
+
+    def __call__(self, audiopath):
+        audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
+        spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
+        return spect
+
+stft = STFT()
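A minimal usage sketch for the mel extractor (the parameter values below are illustrative; in this repo the vocoder class constructs `STFT` from the checkpoint's own `config.json` instead):

    import torch

    stft = STFT(sr=44100, n_mels=128, n_fft=2048, win_size=2048,
                hop_length=512, fmin=40, fmax=16000)
    audio = torch.randn(1, 44100)      # one second of audio, [B, samples]
    mel = stft.get_mel(audio)          # log-compressed mel, [B, n_mels, n_frames]
    print(mel.shape)                   # about torch.Size([1, 128, 86]) for these settings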
rift_svc/nsf_hifigan/utils.py
ADDED
@@ -0,0 +1,67 @@
+import glob
+import os
+import matplotlib
+import torch
+from torch.nn.utils import weight_norm
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+
+
+def plot_spectrogram(spectrogram):
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                   interpolation='none')
+    plt.colorbar(im, ax=ax)
+
+    fig.canvas.draw()
+    plt.close()
+
+    return fig
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        weight_norm(m)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size*dilation - dilation)/2)
+
+
+def load_checkpoint(filepath, device):
+    assert os.path.isfile(filepath)
+    print("Loading '{}'".format(filepath))
+    checkpoint_dict = torch.load(filepath, map_location=device)
+    print("Complete.")
+    return checkpoint_dict
+
+
+def save_checkpoint(filepath, obj):
+    print("Saving checkpoint to {}".format(filepath))
+    torch.save(obj, filepath)
+    print("Complete.")
+
+
+def del_old_checkpoints(cp_dir, prefix, n_models=2):
+    pattern = os.path.join(cp_dir, prefix + '????????')
+    cp_list = glob.glob(pattern)  # get checkpoint paths
+    cp_list = sorted(cp_list)  # sort by iter
+    if len(cp_list) > n_models:  # if more than n_models models are found
+        for cp in cp_list[:-n_models]:  # delete the oldest models other than the latest n_models
+            open(cp, 'w').close()  # empty file contents
+            os.unlink(cp)  # delete file (move to trash when using Colab)
+
+
+def scan_checkpoint(cp_dir, prefix):
+    pattern = os.path.join(cp_dir, prefix + '????????')
+    cp_list = glob.glob(pattern)
+    if len(cp_list) == 0:
+        return None
+    return sorted(cp_list)[-1]
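Note that `get_padding` is simply the "same"-length padding for a stride-1 dilated convolution, `(k*d - d) / 2`:

    assert get_padding(3, dilation=1) == 1
    assert get_padding(3, dilation=5) == 5   # matches the dilated ResBlock convs above
    assert get_padding(5, dilation=1) == 2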
rift_svc/nsf_hifigan/vocoder.py
ADDED
@@ -0,0 +1,123 @@
+import os
+import yaml
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from .nvSTFT import STFT
+from .models import load_model, load_config
+from torchaudio.transforms import Resample
+from jaxtyping import Float
+
+
+class DotDict(dict):
+    def __getattr__(*args):
+        val = dict.get(*args)
+        return DotDict(val) if type(val) is dict else val
+
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+
+def load_model_vocoder(
+        model_path,
+        device='cpu'):
+    config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
+    with open(config_file, "r") as config:
+        args = yaml.safe_load(config)
+    args = DotDict(args)
+
+    # load vocoder
+    vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device)
+
+    return vocoder, args
+
+
+class Vocoder:
+    def __init__(self, vocoder_type, vocoder_ckpt, device=None):
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.device = device
+
+        if vocoder_type == 'nsf-hifigan':
+            self.vocoder = NsfHifiGAN(vocoder_ckpt, device=device)
+        elif vocoder_type == 'nsf-hifigan-log10':
+            self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device=device)
+        else:
+            raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")
+
+        self.resample_kernel = {}
+        self.vocoder_sample_rate = self.vocoder.sample_rate()
+        self.vocoder_hop_size = self.vocoder.hop_size()
+        self.dimension = self.vocoder.dimension()
+
+    def extract(self, audio, sample_rate=0, keyshift=0):
+
+        # resample
+        if sample_rate == self.vocoder_sample_rate or sample_rate == 0:
+            audio_res = audio
+        else:
+            key_str = str(sample_rate)
+            if key_str not in self.resample_kernel:
+                self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width=128).to(self.device)
+            audio_res = self.resample_kernel[key_str](audio)
+
+        # extract
+        mel = self.vocoder.extract(audio_res, keyshift=keyshift)  # B, n_frames, bins
+        return mel
+
+    def infer(self, mel, f0):
+        f0 = f0[:, :mel.size(1), 0]  # B, n_frames
+        audio = self.vocoder(mel, f0)
+        return audio
+
+
+class NsfHifiGAN(torch.nn.Module):
+    def __init__(self, model_path, device=None):
+        super().__init__()
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.device = device
+        self.model_path = model_path
+        self.model = None
+        self.h = load_config(model_path)
+        self.stft = STFT(
+            self.h.sampling_rate,
+            self.h.num_mels,
+            self.h.n_fft,
+            self.h.win_size,
+            self.h.hop_size,
+            self.h.fmin,
+            self.h.fmax)
+
+    def sample_rate(self):
+        return self.h.sampling_rate
+
+    def hop_size(self):
+        return self.h.hop_size
+
+    def dimension(self):
+        return self.h.num_mels
+
+    def extract(self, audio, keyshift=0):
+        mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2)  # B, n_frames, bins
+        return mel
+
+    def forward(self, mel: Float[torch.Tensor, "batch bins n_frames"], f0: Float[torch.Tensor, "batch n_frames"]):
+        if self.model is None:
+            print('| Load HifiGAN: ', self.model_path)
+            self.model, self.h = load_model(self.model_path, device=self.device)
+        with torch.no_grad():
+            audio = self.model(mel, f0)
+            return audio
+
+
+class NsfHifiGANLog10(NsfHifiGAN):
+    def forward(self, mel, f0):
+        if self.model is None:
+            print('| Load HifiGAN: ', self.model_path)
+            self.model, self.h = load_model(self.model_path, device=self.device)
+        with torch.no_grad():
+            c = 0.434294 * mel.transpose(1, 2)
+            audio = self.model(c, f0)
+            return audio
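Putting the pieces together, a hedged end-to-end sketch (the checkpoint path points into the `pretrained/` folder from this commit, but the exact model file name inside it is an assumption; `load_config` only needs the `config.json` next to it):

    import torch

    vocoder = NsfHifiGAN('pretrained/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt')  # file name assumed
    audio = torch.randn(1, 44100)
    mel = vocoder.extract(audio)                 # [B, n_frames, n_mels]
    f0 = torch.full((1, mel.size(1)), 220.0)     # [B, n_frames]
    wav = vocoder(mel.transpose(1, 2), f0)       # forward expects [B, n_mels, n_frames]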
rift_svc/optim.py
ADDED
@@ -0,0 +1,103 @@
+from schedulefree import AdamWScheduleFree
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import _LRScheduler
+import math
+
+
+def get_optimizer(
+        optimizer_type, model, lr, betas, weight_decay, warmup_steps,
+        lora_training=False, **kwargs):
+    from collections import defaultdict
+    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
+    if not lora_training:
+        specp_decay_params = defaultdict(list)
+        specp_decay_lr = {}
+        decay_params = []
+        nodecay_params = []
+        for n, p in param_dict.items():
+            if p.dim() >= 2:
+                if n.endswith('out.weight') or n.endswith('proj.weight'):
+                    fan_out, fan_in = p.shape[-2:]
+                    fan_ratio = fan_out / fan_in
+                    specp_decay_params[f"specp_decay_{fan_ratio:.2f}"].append(p)
+                    specp_decay_lr[f"specp_decay_{fan_ratio:.2f}"] = lr * fan_ratio
+                else:
+                    decay_params.append(p)
+            else:
+                nodecay_params.append(p)
+
+        optim_groups = [
+            {'params': decay_params, 'weight_decay': weight_decay, 'lr': lr},
+            {'params': nodecay_params, 'weight_decay': 0.0, 'lr': lr}
+        ] + [
+            {'params': params, 'weight_decay': weight_decay, 'lr': specp_decay_lr[group_name]}
+            for group_name, params in specp_decay_params.items()
+        ]
+    else:
+        lora_a_or_spk_embed_params = []
+        lora_b_params = []
+        for n, p in param_dict.items():
+            if n.endswith('.A.weight') or n.endswith('.spk_embed.weight'):
+                lora_a_or_spk_embed_params.append(p)
+            elif n.endswith('.B.weight'):
+                lora_b_params.append(p)
+        dim = model.transformer.dim
+        rank = model.transformer.transformer_blocks[0].attn.k_proj.rank
+        optim_groups = [
+            {'params': lora_a_or_spk_embed_params, 'weight_decay': weight_decay, 'lr': lr},
+            {'params': lora_b_params, 'weight_decay': weight_decay, 'lr': lr*math.sqrt(dim/rank)}
+        ]
+    if optimizer_type == 'adamwsf':
+        optimizer = AdamWScheduleFree(optim_groups, betas=betas, warmup_steps=warmup_steps)
+        return optimizer, None
+    elif optimizer_type == 'adamw':
+        optimizer = AdamW(optim_groups, betas=betas, weight_decay=weight_decay)
+        max_steps = kwargs['max_steps']
+        min_lr = kwargs.get('min_lr', 0.0)
+        lr_scheduler = LinearWarmupDecayLR(optimizer, warmup_steps, max_steps, min_lr=min_lr)
+        return optimizer, lr_scheduler
+    else:
+        raise ValueError(f"Invalid optimizer type: {optimizer_type}")
+
+
+class LinearWarmupDecayLR(_LRScheduler):
+    """
+    Linear learning rate scheduler with warmup and minimum lr.
+
+    During warmup, the LR increases linearly from 0 to the base LR.
+    After warmup, the LR decays linearly from the base LR down to min_lr.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        warmup_steps (int): Number of steps to linearly increase LR.
+        total_steps (int): Total number of steps for training (warmup + decay).
+        min_lr (float): Minimum learning rate after decay.
+        last_epoch (int): The index of last epoch. Default: -1.
+    """
+
+    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0.0, last_epoch=-1):
+        if total_steps <= warmup_steps:
+            raise ValueError(
+                "Total steps must be larger than warmup_steps for decay to happen."
+            )
+        self.warmup_steps = warmup_steps
+        self.total_steps = total_steps
+        self.min_lr = min_lr
+        super(LinearWarmupDecayLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        """Compute learning rate using linear warmup and then linear decay."""
+        # Note: self.last_epoch is incremented by the base _LRScheduler.step() before calling get_lr().
+        if self.last_epoch < self.warmup_steps:
+            # Warmup phase: increase linearly from 0 (or a small value) to base_lr.
+            return [
+                base_lr * float(self.last_epoch + 1) / float(self.warmup_steps)
+                for base_lr in self.base_lrs
+            ]
+        else:
+            # Decay phase: decrease linearly from base_lr to min_lr.
+            progress = float(self.last_epoch - self.warmup_steps) / float(self.total_steps - self.warmup_steps)
+            return [
+                max(base_lr * (1.0 - progress) + self.min_lr * progress, self.min_lr)
+                for base_lr in self.base_lrs
+            ]
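A small numerical check of the warmup/decay shape (a sketch with toy values; any optimizer works):

    import torch
    from torch.optim import AdamW

    opt = AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
    sched = LinearWarmupDecayLR(opt, warmup_steps=10, total_steps=100, min_lr=0.1)
    lrs = []
    for _ in range(100):
        opt.step()
        sched.step()
        lrs.append(sched.get_last_lr()[0])
    # lrs climbs linearly to 1.0 over the first 10 steps, then decays linearly toward 0.1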
rift_svc/rf.py
ADDED
@@ -0,0 +1,215 @@
+from typing import Union, List, Literal
+from jaxtyping import Bool
+import torch
+from torch import nn
+import torch.nn.functional as F
+import math
+from torchdiffeq import odeint
+
+from einops import rearrange
+
+from rift_svc.utils import (
+    exists,
+    lens_to_mask,
+)
+
+
+def sample_time(time_schedule: Literal['uniform', 'lognorm'], size: int, device: torch.device):
+    if time_schedule == 'uniform':
+        t = torch.rand((size,), device=device)
+    elif time_schedule == 'lognorm':
+        # stratified sampling of normals
+        # first stratified sample from uniform
+        quantiles = torch.linspace(0, 1, size + 1).to(device)
+        z = quantiles[:-1] + torch.rand((size,)).to(device) / size
+        # now transform to normal
+        z = torch.erfinv(2 * z - 1) * math.sqrt(2)
+        t = torch.sigmoid(z)
+    return t
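The `lognorm` branch above is a stratified logit-normal sampler: `torch.erfinv(2*u - 1) * sqrt(2)` is the standard normal quantile function, so the stratified uniforms become stratified standard normals, and the sigmoid maps them back into (0, 1) with most mass near t = 0.5. A quick illustrative check (not part of the file):

    t = sample_time('lognorm', size=8, device=torch.device('cpu'))
    print(t)   # eight increasing values in (0, 1), denser around 0.5 than a uniform draw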
+
+
+class RF(nn.Module):
+    def __init__(
+        self,
+        transformer: nn.Module,
+        time_schedule: Literal['uniform', 'lognorm'] = 'lognorm',
+        odeint_kwargs: dict = dict(
+            method='euler'
+        ),
+    ):
+        super().__init__()
+
+        self.transformer = transformer
+        dim = transformer.dim
+        self.dim = dim
+
+        # Sampling related parameters
+        self.odeint_kwargs = odeint_kwargs
+        self.time_schedule = time_schedule
+
+        self.mel_min = -12
+        self.mel_max = 2
+
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @torch.no_grad()
+    def sample(
+        self,
+        src_mel: torch.Tensor,  # [b n d]
+        spk_id: torch.Tensor,  # [b]
+        f0: torch.Tensor,  # [b n]
+        rms: torch.Tensor,  # [b n]
+        cvec: torch.Tensor,  # [b n d]
+        frame_len: torch.Tensor | None = None,  # [b]
+        steps: int = 32,
+        bad_cvec: torch.Tensor | None = None,
+        ds_cfg_strength: float = 0.0,
+        spk_cfg_strength: float = 0.0,
+        skip_cfg_strength: float = 0.0,
+        cfg_skip_layers: Union[int, List[int], None] = None,
+        cfg_rescale: float = 0.7,
+    ):
+        self.eval()
+
+        batch, mel_seq_len, num_mel_channels = src_mel.shape
+        device = src_mel.device
+
+        if not exists(frame_len):
+            frame_len = torch.full((batch,), mel_seq_len, device=device)
+
+        mask = lens_to_mask(frame_len)
+
+        # Define the ODE function
+        def fn(t, x):
+            pred = self.transformer(
+                x=x,
+                spk=spk_id,
+                f0=f0,
+                rms=rms,
+                cvec=cvec,
+                time=t,
+                mask=mask
+            )
+            cfg_flag = (ds_cfg_strength > 1e-5) or (skip_cfg_strength > 1e-5) or (spk_cfg_strength > 1e-5)
+            if cfg_rescale > 1e-5 and cfg_flag:
+                std_pred = pred.std()
+
+            if ds_cfg_strength > 1e-5:
+                assert exists(bad_cvec), "bad_cvec is required when cfg_strength is greater than 0"
+                bad_cvec_pred = self.transformer(
+                    x=x,
+                    spk=spk_id,
+                    f0=f0,
+                    rms=rms,
+                    cvec=bad_cvec,
+                    time=t,
+                    mask=mask,
+                    skip_layers=cfg_skip_layers
+                )
+
+                pred = pred + (pred - bad_cvec_pred) * ds_cfg_strength
+
+            if skip_cfg_strength > 1e-5:
+                skip_pred = self.transformer(
+                    x=x,
+                    spk=spk_id,
+                    f0=f0,
+                    rms=rms,
+                    cvec=cvec,
+                    time=t,
+                    mask=mask,
+                    skip_layers=cfg_skip_layers
+                )
+
+                pred = pred + (pred - skip_pred) * skip_cfg_strength
+
+            if spk_cfg_strength > 1e-5:
+                null_spk_pred = self.transformer(
+                    x=x,
+                    spk=spk_id,
+                    f0=f0,
+                    rms=rms,
+                    cvec=cvec,
+                    time=t,
+                    mask=mask,
+                    drop_speaker=True
+                )
+
+                pred = pred + (pred - null_spk_pred) * spk_cfg_strength
+
+            if cfg_rescale > 1e-5 and cfg_flag:
+                std_cfg = pred.std()
+                pred_rescaled = pred * (std_pred / std_cfg)
+                pred = cfg_rescale * pred_rescaled + (1 - cfg_rescale) * pred
+
+            return pred
+
+        # Noise input
+        y0 = torch.randn(batch, mel_seq_len, num_mel_channels, device=self.device)
|
| 152 |
+
# mask out the padded tokens
|
| 153 |
+
y0 = y0.masked_fill(~mask.unsqueeze(-1), 0)
|
| 154 |
+
|
| 155 |
+
t_start = 0
|
| 156 |
+
t = torch.linspace(t_start, 1, steps, device=self.device)
|
| 157 |
+
|
| 158 |
+
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
| 159 |
+
|
| 160 |
+
sampled = trajectory[-1]
|
| 161 |
+
out = self.denorm_mel(sampled)
|
| 162 |
+
out = torch.where(mask.unsqueeze(-1), out, src_mel)
|
| 163 |
+
|
| 164 |
+
return out, trajectory
|
| 165 |
+
|
| 166 |
+
def forward(
|
| 167 |
+
self,
|
| 168 |
+
mel: torch.Tensor, # mel
|
| 169 |
+
spk_id: torch.Tensor, # [b]
|
| 170 |
+
f0: torch.Tensor, # [b n]
|
| 171 |
+
rms: torch.Tensor, # [b n]
|
| 172 |
+
cvec: torch.Tensor, # [b n d]
|
| 173 |
+
frame_len: torch.Tensor | None = None,
|
| 174 |
+
drop_speaker: Union[bool, Bool[torch.Tensor, "b"]] = False,
|
| 175 |
+
):
|
| 176 |
+
batch, seq_len, dtype, device = *mel.shape[:2], mel.dtype, self.device
|
| 177 |
+
|
| 178 |
+
# Handle lengths and masks
|
| 179 |
+
if not exists(frame_len):
|
| 180 |
+
frame_len = torch.full((batch,), seq_len, device=device)
|
| 181 |
+
|
| 182 |
+
mask = lens_to_mask(frame_len, length=seq_len) # Typically padded to max length in batch
|
| 183 |
+
|
| 184 |
+
x1 = self.norm_mel(mel)
|
| 185 |
+
x0 = torch.randn_like(x1)
|
| 186 |
+
|
| 187 |
+
# uniform time steps sampling
|
| 188 |
+
time = sample_time(self.time_schedule, batch, self.device)
|
| 189 |
+
|
| 190 |
+
t = rearrange(time, 'b -> b 1 1')
|
| 191 |
+
xt = (1 - t) * x0 + t * x1
|
| 192 |
+
flow = x1 - x0
|
| 193 |
+
|
| 194 |
+
pred = self.transformer(
|
| 195 |
+
x=xt,
|
| 196 |
+
spk=spk_id,
|
| 197 |
+
f0=f0,
|
| 198 |
+
rms=rms,
|
| 199 |
+
cvec=cvec,
|
| 200 |
+
time=time,
|
| 201 |
+
drop_speaker=drop_speaker,
|
| 202 |
+
mask=mask
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# Flow matching loss
|
| 206 |
+
loss = F.mse_loss(pred, flow, reduction='none')
|
| 207 |
+
loss = loss[mask]
|
| 208 |
+
|
| 209 |
+
return loss.mean(), pred
|
| 210 |
+
|
| 211 |
+
def norm_mel(self, mel: torch.Tensor):
|
| 212 |
+
return (mel - self.mel_min) / (self.mel_max - self.mel_min) * 2 - 1
|
| 213 |
+
|
| 214 |
+
def denorm_mel(self, mel: torch.Tensor):
|
| 215 |
+
return (mel + 1) / 2 * (self.mel_max - self.mel_min) + self.mel_min
|
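
With odeint_kwargs = dict(method='euler'), the torchdiffeq call in sample amounts to a fixed-step Euler integration of the learned flow from noise (t=0) to data (t=1). A hand-written equivalent, for intuition only:

    import torch

    def euler_integrate(fn, y0, steps):
        ts = torch.linspace(0, 1, steps)
        x = y0
        trajectory = [x]
        for i in range(steps - 1):
            dt = ts[i + 1] - ts[i]
            x = x + dt * fn(ts[i], x)  # step along the predicted velocity field
            trajectory.append(x)
        return torch.stack(trajectory)  # same contract as odeint: [steps, b, n, d]
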
rift_svc/rmvpe/__init__.py
ADDED
@@ -0,0 +1,5 @@
+from .constants import *
+from .model import E2E, E2E0
+from .utils import to_local_average_f0, to_viterbi_f0
+from .inference import RMVPE
+from .spec import MelSpectrogram
rift_svc/rmvpe/constants.py
ADDED
@@ -0,0 +1,9 @@
+SAMPLE_RATE = 16000
+
+N_CLASS = 360
+
+N_MELS = 128
+MEL_FMIN = 30
+MEL_FMAX = 8000
+WINDOW_LENGTH = 1024
+CONST = 1997.3794084376191
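
The 360 classes are 20-cent pitch bins, and CONST is the cents offset of bin 0. Using the decoding formula that appears in rmvpe/utils.py further down (cents = 20 * i + CONST, f0 = 10 * 2 ** (cents / 1200)), the bins span roughly 31.7 Hz to 2005.5 Hz:

    for i in (0, 359):
        cents = 20 * i + 1997.3794084376191
        print(i, round(10 * 2 ** (cents / 1200), 1))  # 0 -> ~31.7 Hz, 359 -> ~2005.5 Hz
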
rift_svc/rmvpe/deepunet.py
ADDED
@@ -0,0 +1,189 @@
+import torch
+import torch.nn as nn
+from .constants import N_MELS
+
+
+class ConvBlockRes(nn.Module):
+    def __init__(self, in_channels, out_channels, momentum=0.01):
+        super(ConvBlockRes, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels=in_channels,
+                      out_channels=out_channels,
+                      kernel_size=(3, 3),
+                      stride=(1, 1),
+                      padding=(1, 1),
+                      bias=False),
+            nn.BatchNorm2d(out_channels, momentum=momentum),
+            nn.ReLU(),
+
+            nn.Conv2d(in_channels=out_channels,
+                      out_channels=out_channels,
+                      kernel_size=(3, 3),
+                      stride=(1, 1),
+                      padding=(1, 1),
+                      bias=False),
+            nn.BatchNorm2d(out_channels, momentum=momentum),
+            nn.ReLU(),
+        )
+        if in_channels != out_channels:
+            # 1x1 projection so the residual addition matches channel counts
+            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
+            self.is_shortcut = True
+        else:
+            self.is_shortcut = False
+
+    def forward(self, x):
+        if self.is_shortcut:
+            return self.conv(x) + self.shortcut(x)
+        else:
+            return self.conv(x) + x
+
+
+class ResEncoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
+        super(ResEncoderBlock, self).__init__()
+        self.n_blocks = n_blocks
+        self.conv = nn.ModuleList()
+        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
+        for i in range(n_blocks - 1):
+            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
+        self.kernel_size = kernel_size
+        if self.kernel_size is not None:
+            self.pool = nn.AvgPool2d(kernel_size=kernel_size)
+
+    def forward(self, x):
+        for i in range(self.n_blocks):
+            x = self.conv[i](x)
+        if self.kernel_size is not None:
+            # return both the unpooled features (for skip connections) and the pooled output
+            return x, self.pool(x)
+        else:
+            return x
+
+
+class ResDecoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
+        super(ResDecoderBlock, self).__init__()
+        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
+        self.n_blocks = n_blocks
+        self.conv1 = nn.Sequential(
+            nn.ConvTranspose2d(in_channels=in_channels,
+                               out_channels=out_channels,
+                               kernel_size=(3, 3),
+                               stride=stride,
+                               padding=(1, 1),
+                               output_padding=out_padding,
+                               bias=False),
+            nn.BatchNorm2d(out_channels, momentum=momentum),
+            nn.ReLU(),
+        )
+        self.conv2 = nn.ModuleList()
+        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
+        for i in range(n_blocks - 1):
+            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
+
+    def forward(self, x, concat_tensor):
+        x = self.conv1(x)
+        x = torch.cat((x, concat_tensor), dim=1)
+        for i in range(self.n_blocks):
+            x = self.conv2[i](x)
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
+        super(Encoder, self).__init__()
+        self.n_encoders = n_encoders
+        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
+        self.layers = nn.ModuleList()
+        self.latent_channels = []
+        for i in range(self.n_encoders):
+            self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
+            self.latent_channels.append([out_channels, in_size])
+            in_channels = out_channels
+            out_channels *= 2
+            in_size //= 2
+        self.out_size = in_size
+        self.out_channel = out_channels
+
+    def forward(self, x):
+        concat_tensors = []
+        x = self.bn(x)
+        for i in range(self.n_encoders):
+            skip, x = self.layers[i](x)  # unpooled features feed the decoder's skip connections
+            concat_tensors.append(skip)
+        return x, concat_tensors
+
+
+class Intermediate(nn.Module):
+    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
+        super(Intermediate, self).__init__()
+        self.n_inters = n_inters
+        self.layers = nn.ModuleList()
+        self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
+        for i in range(self.n_inters - 1):
+            self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
+
+    def forward(self, x):
+        for i in range(self.n_inters):
+            x = self.layers[i](x)
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
+        super(Decoder, self).__init__()
+        self.layers = nn.ModuleList()
+        self.n_decoders = n_decoders
+        for i in range(self.n_decoders):
+            out_channels = in_channels // 2
+            self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
+            in_channels = out_channels
+
+    def forward(self, x, concat_tensors):
+        for i in range(self.n_decoders):
+            x = self.layers[i](x, concat_tensors[-1 - i])
+        return x
+
+
+class TimbreFilter(nn.Module):
+    def __init__(self, latent_rep_channels):
+        super(TimbreFilter, self).__init__()
+        self.layers = nn.ModuleList()
+        for latent_rep in latent_rep_channels:
+            self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0]))
+
+    def forward(self, x_tensors):
+        out_tensors = []
+        for i, layer in enumerate(self.layers):
+            out_tensors.append(layer(x_tensors[i]))
+        return out_tensors
+
+
+class DeepUnet(nn.Module):
+    def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
+        super(DeepUnet, self).__init__()
+        self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
+        self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
+        self.tf = TimbreFilter(self.encoder.latent_channels)
+        self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
+
+    def forward(self, x):
+        x, concat_tensors = self.encoder(x)
+        x = self.intermediate(x)
+        concat_tensors = self.tf(concat_tensors)
+        x = self.decoder(x, concat_tensors)
+        return x
+
+
+class DeepUnet0(nn.Module):
+    def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
+        super(DeepUnet0, self).__init__()
+        self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
+        self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
+        self.tf = TimbreFilter(self.encoder.latent_channels)  # constructed but not used in the forward pass
+        self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
+
+    def forward(self, x):
+        x, concat_tensors = self.encoder(x)
+        x = self.intermediate(x)
+        x = self.decoder(x, concat_tensors)
+        return x
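
A quick shape trace, using the hyperparameters that rmvpe/inference.py (further down) passes to E2E0 (n_blocks=4, kernel_size=(2, 2); layer counts at their defaults). Five (2, 2) average-poolings shrink both axes by 32x and the decoder's strided transposed convolutions restore them, which is why the time axis must be a multiple of 32:

    import torch
    from rift_svc.rmvpe.deepunet import DeepUnet0

    net = DeepUnet0(kernel_size=(2, 2), n_blocks=4)
    x = torch.randn(1, 1, 64, 128)  # [batch, channel, frames, n_mels], frames % 32 == 0
    print(net(x).shape)             # -> torch.Size([1, 16, 64, 128])
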
rift_svc/rmvpe/inference.py
ADDED
@@ -0,0 +1,51 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torchaudio.transforms import Resample
+from .constants import *
+from .model import E2E0, E2E
+from .spec import MelSpectrogram
+from .utils import to_local_average_f0, to_viterbi_f0
+
+class RMVPE:
+    def __init__(self, model_path, hop_length=160, device='cpu'):
+        self.resample_kernel = {}
+        model = E2E0(4, 1, (2, 2))
+        ckpt = torch.load(model_path, weights_only=True)
+        model.load_state_dict(ckpt['model'], strict=False)
+        model.eval()
+        self.model = model.to(device)
+        self.mel_extractor = MelSpectrogram(
+            N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX).to(device)
+
+    def mel2hidden(self, mel):
+        with torch.no_grad():
+            n_frames = mel.shape[-1]
+            # Pad the time axis to the next multiple of 32 so the U-Net's five 2x poolings divide evenly
+            mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant')
+            hidden = self.model(mel)
+        return hidden[:, :n_frames]
+
+    def decode(self, hidden, thred=0.03, use_viterbi=False):
+        if use_viterbi:
+            f0 = to_viterbi_f0(hidden, thred=thred)
+        else:
+            f0 = to_local_average_f0(hidden, thred=thred)
+        return f0
+
+    def infer_from_audio(self, audio, sample_rate=16000, device=None, thred=0.03, use_viterbi=False):
+        # audio is expected as a float tensor of shape [1, n_samples],
+        # e.g. torch.from_numpy(audio).float().unsqueeze(0).to(device)
+        if sample_rate == 16000:
+            audio_res = audio
+        else:
+            key_str = str(sample_rate)
+            if key_str not in self.resample_kernel:
+                self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
+            self.resample_kernel[key_str] = self.resample_kernel[key_str].to(device)
+            audio_res = self.resample_kernel[key_str](audio)
+
+        mel = self.mel_extractor(audio_res, center=True)
+        hidden = self.mel2hidden(mel)
+        f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi)
+        return f0
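
A hedged usage sketch. The checkpoint filename is an assumption (this commit only ships pretrained/rmvpe/.gitkeep; the weights are fetched separately), and infer_from_audio expects a float tensor of shape [1, n_samples] as noted in the comment above:

    import librosa
    import torch
    from rift_svc.rmvpe import RMVPE

    audio, sr = librosa.load('input.wav', sr=None, mono=True)
    audio = torch.from_numpy(audio).float().unsqueeze(0)

    rmvpe = RMVPE('pretrained/rmvpe/model.pt', hop_length=160, device='cpu')  # path assumed
    f0 = rmvpe.infer_from_audio(audio, sample_rate=sr)  # numpy array, one value per 10 ms frame
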
rift_svc/rmvpe/model.py
ADDED
@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+from .deepunet import DeepUnet, DeepUnet0
+from .constants import *
+from .spec import MelSpectrogram
+from .seq import BiGRU
+
+
+class E2E(nn.Module):
+    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
+                 en_out_channels=16):
+        super(E2E, self).__init__()
+        self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
+        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
+        if n_gru:
+            self.fc = nn.Sequential(
+                BiGRU(3 * N_MELS, 256, n_gru),
+                nn.Linear(512, N_CLASS),
+                nn.Dropout(0.25),
+                nn.Sigmoid()
+            )
+        else:
+            self.fc = nn.Sequential(
+                nn.Linear(3 * N_MELS, N_CLASS),
+                nn.Dropout(0.25),
+                nn.Sigmoid()
+            )
+
+    def forward(self, mel):
+        mel = mel.transpose(-1, -2).unsqueeze(1)
+        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
+        x = self.fc(x)
+        return x
+
+
+class E2E0(nn.Module):
+    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
+                 en_out_channels=16):
+        super(E2E0, self).__init__()
+        self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
+        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
+        if n_gru:
+            self.fc = nn.Sequential(
+                BiGRU(3 * N_MELS, 256, n_gru),
+                nn.Linear(512, N_CLASS),
+                nn.Dropout(0.25),
+                nn.Sigmoid()
+            )
+        else:
+            self.fc = nn.Sequential(
+                nn.Linear(3 * N_MELS, N_CLASS),
+                nn.Dropout(0.25),
+                nn.Sigmoid()
+            )
+
+    def forward(self, mel):
+        mel = mel.transpose(-1, -2).unsqueeze(1)
+        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
+        x = self.fc(x)
+        return x
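
Shape sanity check for E2E0 as instantiated in inference.py above (E2E0(4, 1, (2, 2))): the network maps a mel spectrogram to a per-frame salience distribution over the 360 pitch bins.

    import torch
    from rift_svc.rmvpe.model import E2E0

    model = E2E0(4, 1, (2, 2))
    mel = torch.randn(1, 128, 64)  # [batch, n_mels, n_frames]; n_frames % 32 == 0
    print(model(mel).shape)        # -> torch.Size([1, 64, 360]), values in (0, 1)
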
rift_svc/rmvpe/seq.py
ADDED
@@ -0,0 +1,20 @@
+import torch.nn as nn
+
+
+class BiGRU(nn.Module):
+    def __init__(self, input_features, hidden_features, num_layers):
+        super(BiGRU, self).__init__()
+        self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
+
+    def forward(self, x):
+        return self.gru(x)[0]
+
+
+class BiLSTM(nn.Module):
+    def __init__(self, input_features, hidden_features, num_layers):
+        super(BiLSTM, self).__init__()
+        self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
+
+    def forward(self, x):
+        return self.lstm(x)[0]
+
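
These are thin wrappers that keep only the output sequence and drop the final hidden state. Bidirectionality doubles the feature dimension, which is why model.py pairs BiGRU(3 * N_MELS, 256, n_gru) with nn.Linear(512, N_CLASS):

    import torch
    from rift_svc.rmvpe.seq import BiGRU

    gru = BiGRU(input_features=384, hidden_features=256, num_layers=1)
    y = gru(torch.randn(2, 100, 384))
    print(y.shape)  # torch.Size([2, 100, 512]): 256 features per direction
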
rift_svc/rmvpe/spec.py
ADDED
@@ -0,0 +1,66 @@
+import torch
+import numpy as np
+import torch.nn.functional as F
+from librosa.filters import mel
+
+class MelSpectrogram(torch.nn.Module):
+    def __init__(
+        self,
+        n_mel_channels,
+        sampling_rate,
+        win_length,
+        hop_length,
+        n_fft=None,
+        mel_fmin=0,
+        mel_fmax=None,
+        clamp=1e-5
+    ):
+        super().__init__()
+        n_fft = win_length if n_fft is None else n_fft
+        self.hann_window = {}
+        mel_basis = mel(
+            sr=sampling_rate,
+            n_fft=n_fft,
+            n_mels=n_mel_channels,
+            fmin=mel_fmin,
+            fmax=mel_fmax,
+            htk=True)
+        mel_basis = torch.from_numpy(mel_basis).float()
+        self.register_buffer("mel_basis", mel_basis)
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.sampling_rate = sampling_rate
+        self.n_mel_channels = n_mel_channels
+        self.clamp = clamp
+
+    def forward(self, audio, keyshift=0, speed=1, center=True):
+        # A nonzero keyshift stretches the analysis window by 2**(keyshift/12)
+        factor = 2 ** (keyshift / 12)
+        n_fft_new = int(np.round(self.n_fft * factor))
+        win_length_new = int(np.round(self.win_length * factor))
+        hop_length_new = int(np.round(self.hop_length * speed))
+
+        keyshift_key = str(keyshift) + '_' + str(audio.device)
+        if keyshift_key not in self.hann_window:
+            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
+
+        fft = torch.stft(
+            audio,
+            n_fft=n_fft_new,
+            hop_length=hop_length_new,
+            win_length=win_length_new,
+            window=self.hann_window[keyshift_key],
+            center=center,
+            return_complex=True)
+        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
+
+        if keyshift != 0:
+            size = self.n_fft // 2 + 1
+            resize = magnitude.size(1)
+            if resize < size:
+                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
+            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
+
+        mel_output = torch.matmul(self.mel_basis, magnitude)
+        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
+        return log_mel_spec
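
Usage sketch with the RMVPE constants (16 kHz, 128 mels, 1024-sample window, hop 160); with center=True, one second of audio yields 101 frames:

    import torch
    from rift_svc.rmvpe.spec import MelSpectrogram

    mel_extractor = MelSpectrogram(128, 16000, 1024, 160, None, 30, 8000)
    audio = torch.randn(1, 16000)            # one second of dummy audio
    mel = mel_extractor(audio, center=True)
    print(mel.shape)                         # torch.Size([1, 128, 101])
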
rift_svc/rmvpe/utils.py
ADDED
@@ -0,0 +1,142 @@
+import sys
+import numpy as np
+import librosa
+import torch
+from functools import reduce
+from .constants import *
+from torch.nn.modules.module import _addindent
+
+
+def cycle(iterable):
+    while True:
+        for item in iterable:
+            yield item
+
+
+def summary(model, file=sys.stdout):
+    def repr(model):
+        # We treat the extra repr like the sub-module, one item per line
+        extra_lines = []
+        extra_repr = model.extra_repr()
+        # empty string will be split into list ['']
+        if extra_repr:
+            extra_lines = extra_repr.split('\n')
+        child_lines = []
+        total_params = 0
+        for key, module in model._modules.items():
+            mod_str, num_params = repr(module)
+            mod_str = _addindent(mod_str, 2)
+            child_lines.append('(' + key + '): ' + mod_str)
+            total_params += num_params
+        lines = extra_lines + child_lines
+
+        for name, p in model._parameters.items():
+            if hasattr(p, 'shape'):
+                total_params += reduce(lambda x, y: x * y, p.shape)
+
+        main_str = model._get_name() + '('
+        if lines:
+            # simple one-liner info, which most builtin Modules will use
+            if len(extra_lines) == 1 and not child_lines:
+                main_str += extra_lines[0]
+            else:
+                main_str += '\n  ' + '\n  '.join(lines) + '\n'
+
+        main_str += ')'
+        if file is sys.stdout:
+            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
+        else:
+            main_str += ', {:,} params'.format(total_params)
+        return main_str, total_params
+
+    string, count = repr(model)
+    if file is not None:
+        if isinstance(file, str):
+            file = open(file, 'w')
+        print(string, file=file)
+        file.flush()
+
+    return count
+
+
+def to_local_average_cents(salience, center=None, thred=0.03):
+    """
+    Find the weighted average cents near the argmax bin.
+    """
+    if not hasattr(to_local_average_cents, 'cents_mapping'):
+        # the bin number-to-cents mapping
+        to_local_average_cents.cents_mapping = (
+            20 * np.arange(N_CLASS) + CONST)
+
+    if salience.ndim == 1:
+        if center is None:
+            center = int(np.argmax(salience))
+        start = max(0, center - 4)
+        end = min(len(salience), center + 5)
+        salience = salience[start:end]
+        product_sum = np.sum(
+            salience * to_local_average_cents.cents_mapping[start:end])
+        weight_sum = np.sum(salience)
+        return product_sum / weight_sum if np.max(salience) > thred else 0
+    if salience.ndim == 2:
+        return np.array([to_local_average_cents(salience[i, :], None, thred) for i in
+                         range(salience.shape[0])])
+
+    raise Exception("salience should be either a 1d or 2d ndarray")
+
+
+def to_viterbi_cents(salience, thred=0.03):
+    # Create viterbi transition matrix
+    if not hasattr(to_viterbi_cents, 'transition'):
+        xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS))
+        transition = np.maximum(30 - abs(xx - yy), 0)
+        transition = transition / transition.sum(axis=1, keepdims=True)
+        to_viterbi_cents.transition = transition
+
+    # Convert to probability
+    prob = salience.T
+    prob = prob / prob.sum(axis=0)
+
+    # Perform viterbi decoding
+    path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64)
+
+    return np.array([to_local_average_cents(salience[i, :], path[i], thred) for i in
+                     range(len(path))])
+
+
+def to_local_average_f0(hidden, center=None, thred=0.03):
+    idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :]  # [B=1, T=1, N]
+    idx_cents = idx * 20 + CONST  # [B=1, T=1, N]
+    if center is None:
+        center = torch.argmax(hidden, dim=2, keepdim=True)  # [B, T, 1]
+    start = torch.clip(center - 4, min=0)  # [B, T, 1]
+    end = torch.clip(center + 5, max=N_CLASS)  # [B, T, 1]
+    idx_mask = (idx >= start) & (idx < end)  # [B, T, N]
+    weights = hidden * idx_mask  # [B, T, N]
+    product_sum = torch.sum(weights * idx_cents, dim=2)  # [B, T]
+    weight_sum = torch.sum(weights, dim=2)  # [B, T]
+    cents = product_sum / (weight_sum + (weight_sum == 0))  # avoid dividing by zero, [B, T]
+    f0 = 10 * 2 ** (cents / 1200)
+    uv = hidden.max(dim=2)[0] < thred  # [B, T]
+    f0 = f0 * ~uv
+    return f0.squeeze(0).cpu().numpy()
+
+
+def to_viterbi_f0(hidden, thred=0.03):
+    # Create viterbi transition matrix (the cache attribute is shared with to_viterbi_cents)
+    if not hasattr(to_viterbi_cents, 'transition'):
+        xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS))
+        transition = np.maximum(30 - abs(xx - yy), 0)
+        transition = transition / transition.sum(axis=1, keepdims=True)
+        to_viterbi_cents.transition = transition
+
+    # Convert to probability
+    prob = hidden.squeeze(0).cpu().numpy()
+    prob = prob.T
+    prob = prob / prob.sum(axis=0)
+
+    # Perform viterbi decoding
+    path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64)
+    center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device)
+
+    return to_local_average_f0(hidden, center=center, thred=thred)
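
A decoding sanity check for to_local_average_f0: a salience peak centered on bin 180 should decode to 10 * 2 ** ((20 * 180 + CONST) / 1200), roughly 253.6 Hz, since symmetric weights leave the weighted-average cents at the peak:

    import torch
    from rift_svc.rmvpe.utils import to_local_average_f0

    hidden = torch.zeros(1, 1, 360)
    hidden[0, 0, 178:183] = torch.tensor([0.2, 0.6, 1.0, 0.6, 0.2])
    print(to_local_average_f0(hidden))  # ~[253.6]
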
rift_svc/utils.py
ADDED
@@ -0,0 +1,364 @@
+import io
+import os
+import random
+import time
+from typing import Any
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+from jaxtyping import Bool, Int
+from PIL import Image
+from pytorch_lightning.callbacks import TQDMProgressBar
+import parselmouth as pm
+import librosa
+import pyworld as pw
+
+
+def seed_everything(seed: int = 0):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+# helpers
+
+def exists(v: Any) -> bool:
+    return v is not None
+
+def default(v: Any, d: Any) -> Any:
+    return v if exists(v) else d
+
+
+def draw_mel_specs(gt: np.ndarray, gen: np.ndarray, diff: np.ndarray, cache_path: str):
+    vmin = min(gt.min(), gen.min())
+    vmax = max(gt.max(), gen.max())
+
+    # Create figure with space for colorbar
+    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 15), sharex=True, gridspec_kw={'hspace': 0})
+
+    # Plot all spectrograms with the same scale
+    im1 = ax1.imshow(gt, origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
+    ax1.set_ylabel('GT', fontsize=14)
+    ax1.set_xticks([])
+
+    im2 = ax2.imshow(gen, origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
+    ax2.set_ylabel('Gen', fontsize=14)
+    ax2.set_xticks([])
+
+    # Find symmetric limits for difference plot
+    diff_abs_max = max(abs(diff.min()), abs(diff.max()))
+
+    im3 = ax3.imshow(diff, origin='lower', aspect='auto',
+                     cmap='RdBu_r',  # Red-White-Blue colormap (reversed)
+                     vmin=-diff_abs_max, vmax=diff_abs_max)
+    ax3.set_ylabel('Diff', fontsize=14)
+
+    fig.colorbar(im1, ax=[ax1, ax2], location='right', label='Magnitude')
+    fig.colorbar(im3, ax=[ax3], location='right', label='Difference')
+
+    buf = io.BytesIO()
+    plt.savefig(buf, format='png', bbox_inches='tight')
+    plt.close()
+    buf.seek(0)
+
+    # Open with PIL and save as compressed JPEG
+    img = Image.open(buf)
+    img = img.convert('RGB')
+    img.save(cache_path, 'JPEG', quality=85, optimize=True)
+    buf.close()
+
+
+# tensor helpers
+
+def lens_to_mask(
+    t: Int[torch.Tensor, "b"],
+    length: int | None = None
+) -> Bool[torch.Tensor, "b n"]:
+
+    if not exists(length):
+        length = t.amax()
+
+    seq = torch.arange(length, device=t.device)
+    return seq < t[..., None]
+
+
+def l2_grad_norm(model: torch.nn.Module):
+    return torch.cat([p.grad.data.flatten() for p in model.parameters() if p.grad is not None]).norm(2)
+
+
+def nearest_interpolate_tensor(tensor, new_size):
+    # Add two dummy dimensions to make it [1, 1, n, d]
+    tensor = tensor.unsqueeze(0).unsqueeze(0)
+
+    # Interpolate along the length dimension with nearest-neighbour lookup
+    interpolated = F.interpolate(tensor, size=(new_size, tensor.shape[-1]), mode='nearest')
+
+    # Remove the dummy dimensions
+    interpolated = interpolated.squeeze(0).squeeze(0)
+
+    return interpolated
+
+
+def linear_interpolate_tensor(tensor, new_size):
+    # Assumes input tensor shape is [n, d]
+    # Rearrange tensor to shape [1, d, n] to prepare for linear interpolation
+    tensor = tensor.transpose(0, 1).unsqueeze(0)
+
+    # Interpolate along the length dimension (last dimension) using linear interpolation.
+    # align_corners=True preserves the boundary values; adjust this flag if needed.
+    interpolated = F.interpolate(tensor, size=new_size, mode='linear', align_corners=True)
+
+    # Restore the tensor to shape [new_size, d]
+    return interpolated.squeeze(0).transpose(0, 1)
+
+
+# f0 helpers
+
+def post_process_f0(f0, sample_rate, hop_length, n_frames, silence_front=0.0, cut_last=True):
+    """
+    Post-process the extracted f0 to align with Mel spectrogram frames.
+
+    Args:
+        f0 (numpy.ndarray): Extracted f0 array.
+        sample_rate (int): Sample rate of the audio.
+        hop_length (int): Hop length used during processing.
+        n_frames (int): Total number of frames (for alignment).
+        silence_front (float): Seconds of silence to remove from the front.
+        cut_last (bool): If True, drop the final frame of the aligned f0.
+
+    Returns:
+        numpy.ndarray: Processed f0 array aligned with Mel spectrogram frames.
+    """
+    # Calculate number of frames to skip based on silence_front
+    start_frame = int(silence_front * sample_rate / hop_length)
+    real_silence_front = start_frame * hop_length / sample_rate
+    # Assuming silence_front has been handled during RMVPE inference if needed
+
+    # Handle unvoiced frames by interpolation
+    uv = f0 == 0
+    if np.any(~uv):
+        f0_interp = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
+        f0[uv] = f0_interp
+    else:
+        # If no voiced frames, set all to zero
+        f0 = np.zeros_like(f0)
+
+    # Align with hop_length frames
+    origin_time = 0.01 * np.arange(len(f0))  # Placeholder: adjust based on RMVPE's frame timing
+    target_time = hop_length / sample_rate * np.arange(n_frames - start_frame)
+    f0 = np.interp(target_time, origin_time, f0)
+    uv = np.interp(target_time, origin_time, uv.astype(float)) > 0.5
+    f0[uv] = 0
+
+    # Pad the silence_front if needed
+    f0 = np.pad(f0, (start_frame, 0), mode='constant')
+
+    if cut_last:
+        return f0[:-1]
+    else:
+        return f0
+
+# pyworld
+def get_f0_pw(audio, sr, time_step, f0_min, f0_max):
+    pw_pre_f0, times = pw.dio(
+        audio.astype(np.double), sr,
+        f0_floor=f0_min, f0_ceil=f0_max,
+        frame_period=time_step * 1000)  # raw pitch extractor
+    pw_post_f0 = pw.stonemask(audio.astype(np.double), pw_pre_f0, times, sr)  # pitch refinement
+    pw_post_f0[pw_post_f0 == 0] = np.nan
+    pw_post_f0 = slide_nanmedian(pw_post_f0, 3)
+    return pw_post_f0
+
+# parselmouth
+def get_f0_pm(audio, sr, time_step, f0_min, f0_max):
+    pmac_pitch = pm.Sound(audio, sampling_frequency=sr).to_pitch_ac(
+        time_step=time_step, voicing_threshold=0.6,
+        pitch_floor=f0_min, pitch_ceiling=f0_max,
+        very_accurate=True, octave_jump_cost=0.5)
+    pmac_f0 = pmac_pitch.selected_array['frequency']
+    pmac_f0[pmac_f0 == 0] = np.nan
+    pmac_f0 = slide_nanmedian(pmac_f0, 3)
+    return pmac_f0
+
+from numba import njit
+@njit
+def slide_nanmedian(signals=np.array([]), win_length=3):
+    """Filters a sequence with a sliding median, ignoring nan values
+
+    Arguments
+        signals (numpy.ndarray (shape=(time)))
+            The signals to filter
+        win_length
+            The size of the analysis window
+
+    Returns
+        filtered (numpy.ndarray (shape=(time)))
+    """
+    # Output buffer
+    filtered = np.empty_like(signals)
+
+    # Loop over frames
+    for i in range(signals.shape[0]):
+
+        # Get analysis window bounds
+        start = max(0, i - win_length // 2)
+        end = min(signals.shape[0], i + win_length // 2 + 1)
+
+        # Apply filter to window
+        filtered[i] = np.nanmedian(signals[start:end])
+
+    return filtered
+
+
+def f0_ensemble(rmvpe_f0, pw_f0, pmac_f0):
+    trunc_len = len(rmvpe_f0)
+    pw_f0 = pw_f0[:trunc_len]
+    # pad pmac_f0 with NaNs to match
+    pmac_f0 = np.concatenate(
+        [pmac_f0, np.full(len(pw_f0) - len(pmac_f0), np.nan, dtype=pmac_f0.dtype)])
+
+    stack_f0 = np.stack([pw_f0, pmac_f0, rmvpe_f0], axis=0)
+
+    median_f0 = np.nanmedian(stack_f0, axis=0)
+    nan_nums = np.sum(np.isnan(stack_f0), axis=0)
+    median_f0[nan_nums >= 2] = np.nan
+
+    slide_median_f0 = slide_nanmedian(median_f0, 41)
+
+    # replace frames that deviate more than 96 cents-ish Hz from the sliding median
+    f0_dev = np.abs(median_f0 - slide_median_f0)
+    median_f0[f0_dev > 96] = slide_median_f0[f0_dev > 96]
+
+    # where only one extractor disagrees, keep the value closest to the sliding median
+    nan1_f0_min = np.nanmin(stack_f0[:, nan_nums == 1], axis=0)
+    nan1_f0_max = np.nanmax(stack_f0[:, nan_nums == 1], axis=0)
+
+    nan1_f0 = np.where(
+        np.abs(nan1_f0_min - slide_median_f0[nan_nums == 1]) < np.abs(nan1_f0_max - slide_median_f0[nan_nums == 1]),
+        nan1_f0_min, nan1_f0_max)
+    median_f0[nan_nums == 1] = nan1_f0
+
+    median_f0 = slide_nanmedian(median_f0, 3)
+    median_f0[nan_nums >= 2] = np.nan
+    median_f0[np.isnan(median_f0)] = 0
+
+    return median_f0
+
+
+def f0_ensemble_light(rmvpe_f0, pw_f0, pmac_f0, rms=None, rms_threshold=0.05):
+    """
+    A lighter version of the f0 ensemble that preserves RMVPE's expressiveness.
+    Only applies corrections when RMVPE shows abnormalities.
+
+    Args:
+        rmvpe_f0 (numpy.ndarray): F0 from RMVPE
+        pw_f0 (numpy.ndarray): F0 from WORLD
+        pmac_f0 (numpy.ndarray): F0 from Parselmouth
+        rms (numpy.ndarray, optional): RMS energy values, used to detect voiced segments
+        rms_threshold (float, optional): Threshold for RMS to consider a segment as voiced
+
+    Returns:
+        numpy.ndarray: Corrected F0 values
+    """
+    trunc_len = len(rmvpe_f0)
+    pw_f0 = pw_f0[:trunc_len]
+
+    # Pad pmac_f0 if needed
+    pmac_f0 = np.concatenate(
+        [pmac_f0, np.full(max(0, len(pw_f0) - len(pmac_f0)), np.nan, dtype=pmac_f0.dtype)])
+
+    # Create a copy of rmvpe_f0 to preserve most of its values
+    corrected_f0 = rmvpe_f0.copy()
+
+    # Stack all F0 values
+    stack_f0 = np.stack([pw_f0, pmac_f0, rmvpe_f0], axis=0)
+
+    # Count non-NaN values for each frame
+    valid_count = np.sum(~np.isnan(stack_f0), axis=0)
+
+    # Identify frames where RMVPE shows zero but other methods detect pitch
+    zero_rmvpe_mask = (rmvpe_f0 == 0)
+
+    # For frames where RMVPE is zero but at least one other method has a valid F0
+    # and there's voice activity (if RMS is provided)
+    other_methods_valid = ((~np.isnan(pw_f0) & (pw_f0 > 0)) |
+                           (~np.isnan(pmac_f0) & (pmac_f0 > 0)))
+
+    correction_mask = zero_rmvpe_mask & other_methods_valid
+
+    # If RMS is provided, only correct frames with voice activity
+    if rms is not None:
+        voice_activity = rms > rms_threshold
+        correction_mask = correction_mask & voice_activity
+
+    # For frames needing correction, use median of available values
+    if np.any(correction_mask):
+        # For each frame needing correction, calculate median of non-NaN values
+        for i in np.where(correction_mask)[0]:
+            valid_values = stack_f0[:, i][~np.isnan(stack_f0[:, i]) & (stack_f0[:, i] > 0)]
+            if len(valid_values) > 0:
+                corrected_f0[i] = np.median(valid_values)
+
+    # Handle any remaining NaN values
+    corrected_f0[np.isnan(corrected_f0)] = 0
+
+    return corrected_f0
+
+
+# progress bar helper
+
+class CustomProgressBar(TQDMProgressBar):
+    def __init__(self):
+        super().__init__()
+        self.start_time = None
+        self.step_start_time = None
+        self.total_steps = None
+
+    def on_train_start(self, trainer, pl_module):
+        super().on_train_start(trainer, pl_module)
+        self.start_time = time.time()
+        self.total_steps = trainer.max_steps
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
+
+        current_step = trainer.global_step
+        total_steps = self.total_steps
+
+        # Calculate elapsed time since training started
+        elapsed_time = time.time() - self.start_time
+
+        # Estimate average step time and remaining time
+        average_step_time = elapsed_time / current_step if current_step > 0 else 0
+        remaining_steps = total_steps - current_step
+        remaining_time = average_step_time * remaining_steps if total_steps > 0 else 0
+
+        # Format times with no leading zeros for hours
+        def format_time(seconds):
+            hours = int(seconds // 3600)
+            minutes = int((seconds % 3600) // 60)
+            seconds = int(seconds % 60)
+            return f"{hours}:{minutes:02d}:{seconds:02d}"
+
+        elapsed_time_str = format_time(elapsed_time)
+        remaining_time_str = format_time(remaining_time)
+
+        # Update the progress bar with loss, elapsed time, remaining time, and remaining steps
+        self.train_progress_bar.set_postfix({
+            "loss": f"{outputs['loss'].item():.4f}",
+            "elapsed_time": elapsed_time_str + "/" + remaining_time_str,
+            "remaining_steps": str(remaining_steps) + "/" + str(total_steps)
+        })
+
+
+# state dict helpers
+
+def load_state_dict(model, state_dict, strict=False):
+    """Load a state dict while handling the 'model.' prefix"""
+    if any(k.startswith('model.') for k in state_dict.keys()):
+        # Strip the 'model.' prefix only at the start of each key, so keys that
+        # merely contain the substring elsewhere are left untouched
+        state_dict = {k[len('model.'):] if k.startswith('model.') else k: v
+                      for k, v in state_dict.items()}
+    return model.load_state_dict(state_dict, strict=strict)
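
A small check of slide_nanmedian, the numba-compiled sliding median used by the pyworld/parselmouth extractors and the ensembles above: NaN gaps are ignored inside each window instead of propagating:

    import numpy as np
    from rift_svc.utils import slide_nanmedian

    x = np.array([100., np.nan, 102., 104., 106.])
    print(slide_nanmedian(x, 3))  # [100. 101. 103. 104. 105.]
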
slicer.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
import librosa
|
| 5 |
+
|
| 6 |
+
warnings.filterwarnings('ignore')
|
| 7 |
+
|
| 8 |
+
# Configure logging at the top of your slicer.py
|
| 9 |
+
logging.basicConfig(level=logging.INFO)
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Slicer:
|
| 14 |
+
def __init__(self,
|
| 15 |
+
sr: int,
|
| 16 |
+
threshold: float = -30.,
|
| 17 |
+
min_length: int = 3000,
|
| 18 |
+
min_interval: int = 100,
|
| 19 |
+
hop_size: int = 20,
|
| 20 |
+
max_sil_kept: int = 5000):
|
| 21 |
+
if not min_length >= min_interval >= hop_size:
|
| 22 |
+
raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
|
| 23 |
+
if not max_sil_kept >= hop_size:
|
| 24 |
+
raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
|
| 25 |
+
min_interval = sr * min_interval / 1000
|
| 26 |
+
self.sr = sr
|
| 27 |
+
self.threshold = 10 ** (threshold / 20.)
|
| 28 |
+
self.hop_size = round(sr * hop_size / 1000)
|
| 29 |
+
self.win_size = min(round(min_interval), 4 * self.hop_size)
|
| 30 |
+
self.min_length = round(sr * min_length / 1000 / self.hop_size)
|
| 31 |
+
self.min_interval = round(min_interval / self.hop_size)
|
| 32 |
+
self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _apply_slice(self, waveform, begin, end):
|
| 36 |
+
if len(waveform.shape) > 1:
|
| 37 |
+
return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
|
| 38 |
+
else:
|
| 39 |
+
return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def slice(self, waveform):
|
| 43 |
+
if len(waveform.shape) > 1:
|
| 44 |
+
samples = librosa.to_mono(waveform)
|
| 45 |
+
else:
|
| 46 |
+
samples = waveform
|
| 47 |
+
if samples.shape[0] <= self.min_length:
|
| 48 |
+
# Return the entire audio as a single chunk
|
| 49 |
+
return [(0, waveform)]
|
| 50 |
+
|
| 51 |
+
rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
| 52 |
+
sil_tags = []
|
| 53 |
+
silence_start = None
|
| 54 |
+
clip_start = 0
|
| 55 |
+
for i, rms in enumerate(rms_list):
|
| 56 |
+
# Keep looping while frame is silent.
|
| 57 |
+
if rms < self.threshold:
|
| 58 |
+
# Record start of silent frames.
|
| 59 |
+
if silence_start is None:
|
| 60 |
+
silence_start = i
|
| 61 |
+
continue
|
| 62 |
+
# Keep looping while frame is not silent and silence start has not been recorded.
|
| 63 |
+
if silence_start is None:
|
| 64 |
+
continue
|
| 65 |
+
# Clear recorded silence start if interval is not enough or clip is too short
|
| 66 |
+
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
| 67 |
+
need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
|
| 68 |
+
if not is_leading_silence and not need_slice_middle:
|
| 69 |
+
silence_start = None
|
| 70 |
+
continue
|
| 71 |
+
# Need slicing. Record the range of silent frames to be removed.
|
| 72 |
+
if i - silence_start <= self.max_sil_kept:
|
| 73 |
+
pos = rms_list[silence_start: i + 1].argmin() + silence_start
|
| 74 |
+
if silence_start == 0:
|
| 75 |
+
sil_tags.append((0, pos))
|
| 76 |
+
else:
|
| 77 |
+
sil_tags.append((pos, pos))
|
| 78 |
+
clip_start = pos
|
| 79 |
+
elif i - silence_start <= self.max_sil_kept * 2:
|
| 80 |
+
pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
|
| 81 |
+
pos += i - self.max_sil_kept
|
| 82 |
+
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
|
| 83 |
+
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
|
| 84 |
+
if silence_start == 0:
|
| 85 |
+
sil_tags.append((0, pos_r))
|
| 86 |
+
clip_start = pos_r
|
| 87 |
+
else:
|
| 88 |
+
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
|
| 89 |
+
clip_start = max(pos_r, pos)
|
| 90 |
+
else:
|
| 91 |
+
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
|
| 92 |
+
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
|
| 93 |
+
if silence_start == 0:
|
| 94 |
+
sil_tags.append((0, pos_r))
|
| 95 |
+
else:
|
| 96 |
+
sil_tags.append((pos_l, pos_r))
|
| 97 |
+
clip_start = pos_r
|
| 98 |
+
silence_start = None
|
| 99 |
+
|
| 100 |
+
# Deal with trailing silence.
|
| 101 |
+
total_frames = rms_list.shape[0]
|
| 102 |
+
if silence_start is not None and total_frames - silence_start >= self.min_interval:
|
| 103 |
+
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
| 104 |
+
pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
|
| 105 |
+
sil_tags.append((pos, total_frames + 1))
|
| 106 |
+

        # Apply and return slices.
        if len(sil_tags) == 0:
            # Return the entire audio as a single chunk if no silence was detected
            return [(0, waveform)]

        # Extract non-silence chunks
        non_silence_chunks = []

        # Add the first non-silence chunk if it exists
        if sil_tags[0][0] > 0:
            start_pos = 0
            end_frame = sil_tags[0][0]
            chunk = self._apply_slice(waveform, 0, end_frame)
            non_silence_chunks.append((start_pos, chunk))

        # Add middle non-silence chunks
        for i in range(1, len(sil_tags)):
            start_frame = sil_tags[i - 1][1]
            end_frame = sil_tags[i][0]
            if start_frame < end_frame:  # Only add if there's actual non-silence content
                start_pos = start_frame * self.hop_size
                chunk = self._apply_slice(waveform, start_frame, end_frame)
                non_silence_chunks.append((start_pos, chunk))

        # Add the last non-silence chunk if it exists
        # (waveform.shape[-1] is the sample count for both mono and multi-channel audio,
        # whereas len(waveform) would be the channel count for a (channels, samples) array)
        if sil_tags[-1][1] * self.hop_size < waveform.shape[-1]:
            start_frame = sil_tags[-1][1]
            start_pos = start_frame * self.hop_size
            chunk = self._apply_slice(waveform, start_frame, total_frames)
            non_silence_chunks.append((start_pos, chunk))
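
        # Each entry in non_silence_chunks is (start_pos, chunk): start_pos is the
        # chunk's offset from the start of the input, in samples at the original rate.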
        for i, (start_pos, chunk) in enumerate(non_silence_chunks):
            # Calculate start and end times in seconds
            start_time_sec = start_pos / self.sr
            num_samples = len(chunk) if len(chunk.shape) == 1 else chunk.shape[1]
            end_time_sec = start_time_sec + num_samples / self.sr
            duration_sec = end_time_sec - start_time_sec

            # Format start and end times as mm:ss
            start_min, start_sec = divmod(start_time_sec, 60)
            end_min, end_sec = divmod(end_time_sec, 60)

            # Log the information
            logger.info(f"Chunk {i}: Start={int(start_min):02d}:{start_sec:05.2f}, End={int(end_min):02d}:{end_sec:05.2f}, Duration={duration_sec:.2f}s")

        return non_silence_chunks


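# Illustrative library usage (not executed; the parameter values simply mirror
# the CLI defaults below and are not a recommendation):
#
#     import librosa
#     audio, sr = librosa.load('input.wav', sr=None, mono=False)
#     slicer = Slicer(sr=sr, threshold=-30, min_length=3000,
#                     min_interval=100, hop_size=20, max_sil_kept=5000)
#     for start_pos, chunk in slicer.slice(audio):
#         print(start_pos / sr, chunk.shape)
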
def main():
    import os.path
    from argparse import ArgumentParser
    from pathlib import Path

    parser = ArgumentParser()
    parser.add_argument('audio', type=str, help='The audio file or directory to be sliced')
    parser.add_argument('--out', type=str, help='Output directory of the sliced audio clips')
    parser.add_argument('--db_thresh', type=float, required=False, default=-30,
                        help='The dB threshold for silence detection')
    parser.add_argument('--min_length', type=int, required=False, default=3000,
                        help='The minimum milliseconds required for each sliced audio clip')
    parser.add_argument('--min_interval', type=int, required=False, default=100,
                        help='The minimum milliseconds for a silence part to be sliced')
    parser.add_argument('--hop_size', type=int, required=False, default=20,
                        help='Frame length in milliseconds')
    parser.add_argument('--max_sil_kept', type=int, required=False, default=5000,
                        help='The maximum silence length kept around the sliced clip, in milliseconds')
    args = parser.parse_args()
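
    # Example invocations (illustrative):
    #   python slicer.py input.wav --out sliced --db_thresh -40
    #   python slicer.py ./recordings --min_length 5000
    # All duration arguments (min_length, min_interval, hop_size, max_sil_kept)
    # are given in milliseconds.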

    # Determine if the input is a file or directory
    audio_path = Path(args.audio)
    is_directory = audio_path.is_dir()

    # Prepare output directory
    out = args.out
    if out is None:
        if is_directory:
            out = os.path.abspath(args.audio)
        else:
            out = os.path.dirname(os.path.abspath(args.audio))

    if not os.path.exists(out):
        os.makedirs(out)

    # Audio file extensions to process
    audio_extensions = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']

    # Process a single file or all files in a directory
    if is_directory:
        logger.info(f"Processing all audio files in directory: {args.audio}")
        audio_files = []
        for ext in audio_extensions:
            audio_files.extend(list(audio_path.glob(f'*{ext}')))

        if not audio_files:
            logger.warning(f"No audio files found in {args.audio}")
            return

        logger.info(f"Found {len(audio_files)} audio files to process")
        for audio_file in audio_files:
            process_audio_file(audio_file, out, args)
    else:
        # Process a single audio file
        logger.info(f"Processing single audio file: {args.audio}")
        process_audio_file(audio_path, out, args)


def process_audio_file(audio_file, out_dir, args):
    """Process a single audio file with the given parameters"""
    import os.path
    import librosa
    import soundfile

    try:
        logger.info(f"Loading audio file: {audio_file}")
        audio, sr = librosa.load(str(audio_file), sr=None, mono=False)

        slicer = Slicer(
            sr=sr,
            threshold=args.db_thresh,
            min_length=args.min_length,
            min_interval=args.min_interval,
            hop_size=args.hop_size,
            max_sil_kept=args.max_sil_kept
        )
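        # Note: the duration parameters above arrive in milliseconds from the
        # CLI; the Slicer constructor is assumed to convert them to
        # sample/frame units internally.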

        # Get non-silence chunks with their positions
        chunks_with_pos = slicer.slice(audio)

        file_basename = os.path.basename(str(audio_file)).rsplit('.', maxsplit=1)[0]
        logger.info(f"Saving {len(chunks_with_pos)} non-silence audio chunks from {file_basename}...")

        for i, (pos, chunk) in enumerate(chunks_with_pos):
            if len(chunk.shape) > 1:
                # soundfile expects (frames, channels); librosa returns (channels, frames)
                chunk = chunk.T

            output_file = os.path.join(out_dir, f'{file_basename}_{i}_pos_{pos}.wav')
            soundfile.write(output_file, chunk, sr)

        logger.info(f"Finished processing {audio_file}")
    except Exception as e:
        logger.error(f"Error processing {audio_file}: {str(e)}")


if __name__ == '__main__':
    main()