import numpy as np
import torch
import torchaudio
import gradio as gr
import tempfile
import gc
import traceback
import os
import requests
import spaces
from slicer import Slicer
from infer import (
    load_models,
    load_audio,
    apply_fade,
    process_segment,
    batch_process_segments
)

# Global variables for the loaded models (populated by initialize_models)
svc_model = vocoder = rmvpe = hubert = rms_extractor = spk2idx = dataset_cfg = None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Default model path
DEFAULT_MODEL_PATH = "pretrained/dit-768-12_nanami.ckpt"

# HuggingFace repository URL for the model
HF_MODEL_URL = "https://huggingface.co/Pur1zumu/RIFT-SVC-finetuned/resolve/main/dit-768-12_nanami.ckpt"

# Maximum audio duration in seconds to avoid memory issues
MAX_AUDIO_DURATION = 300  # 5 minutes


def initialize_models(model_path=DEFAULT_MODEL_PATH):
    global svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg

    # Always use FP16 by default
    use_fp16 = True

    # Drop references to any previously loaded models so their memory can be
    # reclaimed before loading a new checkpoint
    if svc_model is not None:
        svc_model = vocoder = rmvpe = hubert = rms_extractor = None
        torch.cuda.empty_cache()
        gc.collect()

    try:
        # Check if the model file exists at the given path
        temp_model_path = None
        if not os.path.exists(model_path):
            print(f"Model not found at {model_path}, attempting to download from HuggingFace...")
            # Use a persistent temp directory path for reuse between sessions
            temp_model_path = os.path.join(tempfile.gettempdir(), "RIFT-SVC-model.ckpt")
            # Only download if the model is not already in the temp location
            if not os.path.exists(temp_model_path):
                try:
                    # Make sure the download directory exists
                    os.makedirs(os.path.dirname(temp_model_path), exist_ok=True)
                    # Download the model to the temporary location
                    response = requests.get(HF_MODEL_URL, stream=True)
                    response.raise_for_status()
                    with open(temp_model_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
                    print(f"Model downloaded successfully to {temp_model_path}")
                except Exception as e:
                    print(f"Failed to download model: {str(e)}")
                    raise Exception(f"Model not found at {DEFAULT_MODEL_PATH} and download failed: {str(e)}")
            else:
                print(f"Using previously downloaded model from {temp_model_path}")
            model_path = temp_model_path

        svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg = load_models(model_path, device, use_fp16)
        available_speakers = list(spk2idx.keys())
        return available_speakers, f"✅ 模型加载成功!可用说话人: {', '.join(available_speakers)}"
    except Exception as e:
        error_trace = traceback.format_exc()
        return [], f"❌ 加载模型出错: {str(e)}\n\n详细信息: {error_trace}"


def check_audio_length(audio_path, max_duration=MAX_AUDIO_DURATION):
    """Check whether the audio file is short enough to process safely."""
    try:
        info = torchaudio.info(audio_path)
        duration = info.num_frames / info.sample_rate
        return duration <= max_duration, duration
    except Exception:
        # If we can't determine the length, try to process it anyway
        return True, 0
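
# NOTE (assumption): `spaces` imported above is the Hugging Face ZeroGPU helper
# package. On ZeroGPU hardware the GPU-bound entry point is normally wrapped with
# the `spaces.GPU` decorator, e.g.:
#
#   @spaces.GPU(duration=120)
#   def process_with_progress(...):
#       ...
#
# No such decorator appears in this file; whether one is required depends on the
# Space's hardware configuration.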


def process_with_progress(
    progress=gr.Progress(),
    input_audio=None,
    speaker=None,
    key_shift=0,
    infer_steps=32,
    robust_f0=1,
    # Advanced CFG parameters
    ds_cfg_strength=0.1,
    spk_cfg_strength=1.0,
    skip_cfg_strength=0.0,
    cfg_skip_layers=6,
    cfg_rescale=0.7,
    cvec_downsample_rate=2,
    # Slicer parameters
    slicer_threshold=-30.0,
    slicer_min_length=3000,
    slicer_min_interval=100,
    slicer_hop_size=10,
    slicer_max_sil_kept=200,
    # Batch processing
    batch_size=1
):
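    # Gradio fills `progress` automatically because its default value is a
    # gr.Progress instance; the event's `inputs` list maps onto the remaining
    # parameters in order.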
    global svc_model, vocoder, rmvpe, hubert, rms_extractor, spk2idx, dataset_cfg, device

    # Fixed parameters
    target_loudness = -18.0
    restore_loudness = True
    fade_duration = 20.0
    sliced_inference = False
    use_fp16 = True  # Always use FP16 by default

    # Input validation
    if input_audio is None:
        return None, "❌ 错误: 未提供输入音频。"
    if svc_model is None:
        return None, "❌ 错误: 模型未加载。请重新加载页面或检查模型路径。"
    if speaker is None or speaker not in spk2idx:
        return None, f"❌ 错误: 无效的说话人选择。可用说话人: {', '.join(spk2idx.keys())}"

    # Check audio length to avoid memory issues
    is_safe_length, duration = check_audio_length(input_audio)
    if not is_safe_length:
        return None, f"❌ 错误: 音频过长 ({duration:.1f} 秒)。允许的最大时长为 {MAX_AUDIO_DURATION} 秒。"

    # Process the audio
    try:
        # Update status message
        progress(0, desc="处理中: 加载音频...")

        # Convert speaker name to ID
        speaker_id = spk2idx[speaker]

        # Processing configuration (hop length and sample rate)
        hop_length = 512
        sample_rate = 44100

        # Load audio
        audio = load_audio(input_audio, sample_rate)

        # Initialize Slicer
        slicer = Slicer(
            sr=sample_rate,
            threshold=slicer_threshold,
            min_length=slicer_min_length,
            min_interval=slicer_min_interval,
            hop_size=slicer_hop_size,
            max_sil_kept=slicer_max_sil_kept
        )

        progress(0.1, desc="处理中: 切分音频...")

        # Slice the input audio
        segments_with_pos = slicer.slice(audio)
        if not segments_with_pos:
            return None, "❌ 错误: 在输入文件中未找到有效的音频片段。"

        # Calculate fade size in samples
        fade_samples = int(fade_duration * sample_rate / 1000)

        # Process segments
        result_audio = np.zeros(len(audio) + fade_samples)  # Extra space for potential overlap

        progress(0.2, desc="处理中: 开始转换...")
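
        # Overlap-add with linear crossfades: each converted segment is written back
        # at its original offset; at every boundary the part of the previous segment
        # already written there is faded out over `fade_samples` samples while the
        # incoming segment is faded in, so slices join without clicks.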
        with torch.no_grad():
            if batch_size > 1:
                # Use batch processing
                progress_desc = "处理中: 批次 {0}/{1}"
                processed_segments = batch_process_segments(
                    segments_with_pos, svc_model, vocoder, rmvpe, hubert, rms_extractor,
                    speaker_id, sample_rate, hop_length, device,
                    key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
                    skip_cfg_strength, cfg_skip_layers, cfg_rescale,
                    cvec_downsample_rate, target_loudness, restore_loudness,
                    robust_f0, use_fp16, batch_size, progress, progress_desc
                )
                for idx, (start_sample, audio_out, expected_length) in enumerate(processed_segments):
                    # Apply fades
                    if idx > 0:  # Not first segment
                        audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
                        result_audio[start_sample:start_sample + fade_samples] *= \
                            np.linspace(1, 0, fade_samples)  # Fade out previous
                    if idx < len(processed_segments) - 1:  # Not last segment
                        audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out

                    # Add to result
                    result_audio[start_sample:start_sample + len(audio_out)] += audio_out

                    # Clean up memory every 5 segments
                    if idx % 5 == 0:
                        torch.cuda.empty_cache()
            else:
                # Use sequential processing
                for i, (start_sample, chunk) in enumerate(segments_with_pos):
                    segment_progress = 0.2 + (0.7 * (i / len(segments_with_pos)))
                    progress(segment_progress, desc=f"处理中: 片段 {i+1}/{len(segments_with_pos)}")

                    # Process the segment
                    audio_out = process_segment(
                        chunk, svc_model, vocoder, rmvpe, hubert, rms_extractor,
                        speaker_id, sample_rate, hop_length, device,
                        key_shift, infer_steps, ds_cfg_strength, spk_cfg_strength,
                        skip_cfg_strength, cfg_skip_layers, cfg_rescale,
                        cvec_downsample_rate, target_loudness, restore_loudness, sliced_inference,
                        robust_f0, use_fp16
                    )

                    # Ensure consistent length
                    expected_length = len(chunk)
                    if len(audio_out) > expected_length:
                        audio_out = audio_out[:expected_length]
                    elif len(audio_out) < expected_length:
                        audio_out = np.pad(audio_out, (0, expected_length - len(audio_out)), 'constant')

                    # Apply fades
                    if i > 0:  # Not first segment
                        audio_out = apply_fade(audio_out.copy(), fade_samples, fade_in=True)
                        result_audio[start_sample:start_sample + fade_samples] *= \
                            np.linspace(1, 0, fade_samples)  # Fade out previous
                    if i < len(segments_with_pos) - 1:  # Not last segment
                        audio_out[-fade_samples:] *= np.linspace(1, 0, fade_samples)  # Fade out

                    # Add to result
                    result_audio[start_sample:start_sample + len(audio_out)] += audio_out

                    # Clean up memory after each segment
                    torch.cuda.empty_cache()

        progress(0.9, desc="处理中: 完成音频...")

        # Trim any extra padding
        result_audio = result_audio[:len(audio)]

        # Create a temporary file and save the result
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            output_path = temp_file.name
        torchaudio.save(output_path, torch.from_numpy(result_audio).unsqueeze(0).float(), sample_rate)

        progress(1.0, desc="处理完成!")

        batch_text = f"批处理大小 {batch_size}" if batch_size > 1 else "顺序处理"
        return (sample_rate, result_audio), f"✅ 转换完成! 已转换为 **{speaker}** 并调整 **{key_shift}** 个半音。{batch_text}"
    except RuntimeError as e:
        # Handle CUDA out-of-memory errors
        if "CUDA out of memory" in str(e):
            # Clean up memory
            torch.cuda.empty_cache()
            gc.collect()
            return None, "❌ 错误: 内存不足。请尝试更短的音频文件或减少推理步骤。"
        else:
            return None, f"❌ 转换过程中出错: {str(e)}"
    except Exception as e:
        error_trace = traceback.format_exc()
        return None, f"❌ 转换过程中出错: {str(e)}\n\n详细信息: {error_trace}"
    finally:
        # Clean up memory
        torch.cuda.empty_cache()
        gc.collect()
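
# Hypothetical local smoke test (not part of the Space UI): the conversion can be
# driven directly by passing a no-op progress callback in place of gr.Progress().
# Assumes a checkpoint is reachable (DEFAULT_MODEL_PATH or the HF download) and
# that "input.wav" exists:
#
#   speakers, status = initialize_models()
#   result, message = process_with_progress(
#       progress=lambda *args, **kwargs: None,
#       input_audio="input.wav",
#       speaker=speakers[0] if speakers else None,
#   )
#   print(message)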


def create_ui():
    # CSS for better styling
    css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .container {
        max-width: 1200px;
        margin: auto;
    }
    .footer {
        margin-top: 20px;
        text-align: center;
        font-size: 0.9em;
        color: #666;
    }
    .title {
        text-align: center;
        margin-bottom: 10px;
    }
    .subtitle {
        text-align: center;
        margin-bottom: 20px;
        color: #666;
    }
    .button-primary {
        background-color: #5460DE !important;
    }
    .output-message {
        margin-top: 10px;
        padding: 10px;
        border-radius: 4px;
        background-color: #f8f9fa;
        border-left: 4px solid #5460DE;
    }
    .error-message {
        color: #d62828;
        font-weight: bold;
    }
    .success-message {
        color: #588157;
        font-weight: bold;
    }
    .info-box {
        background-color: #f8f9fa;
        border-left: 4px solid #5460DE;
        padding: 10px;
        margin: 10px 0;
        border-radius: 4px;
    }
    """

    # Initialize models
    available_speakers, init_message = initialize_models()

    with gr.Blocks(css=css, theme=gr.themes.Soft(), title="RIFT-SVC 声音转换") as app:
        gr.HTML("""
        <div class="title">
            <h1>🎤 RIFT-SVC 歌声音色转换 (七海Nanami demo)</h1>
        </div>
        <div class="subtitle">
            <h3>使用 RIFT-SVC 模型将歌声或语音转换为七海Nanami的音色</h3>
        </div>
        <div class="info-box">
            <p>🔗 <strong>想要微调自己的说话人?</strong> 请访问 <a href="https://github.com/Pur1zumu/RIFT-SVC" target="_blank">RIFT-SVC GitHub 仓库</a> 获取完整的训练和微调指南。</p>
        </div>
        <div class="info-box">
            <p>🎤 <strong>数据来源说明:</strong> 该demo数据来源为b站上快速爬取的约30分钟七海唱歌片段,直接分离人声后进行训练,没有额外筛选。</p>
        </div>
        <div class="info-box">
            <p>📝 <strong>注意:</strong> 为获得最佳效果,请使用背景噪音较少的干净音频。最大音频长度为5分钟。<strong>建议用较短的音频测试避免平台意外中断任务。</strong></p>
        </div>
        """)

        with gr.Row():
            # Left column (input parameters)
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### 📥 输入")
                    model_path = gr.Textbox(label="模型路径", value=DEFAULT_MODEL_PATH, interactive=True)
                    input_audio = gr.Audio(label="输入音频文件", type="filepath", elem_id="input_audio")
                    reload_btn = gr.Button("🔄 重新加载模型", elem_id="reload_btn")

                with gr.Accordion("⚙️ 基本参数", open=True):
                    speaker = gr.Dropdown(choices=available_speakers, label="目标说话人", interactive=True, elem_id="speaker")
                    key_shift = gr.Slider(minimum=-12, maximum=12, step=1, value=0, label="音调调整(半音)", elem_id="key_shift")
                    infer_steps = gr.Slider(minimum=8, maximum=64, step=1, value=32, label="推理步数", elem_id="infer_steps",
                                            info="更低的值 = 更快但质量较低,更高的值 = 更慢但质量更好")
                    robust_f0 = gr.Radio(choices=[0, 1, 2], value=1, label="音高滤波",
                                         info="0=无,1=轻度过滤,2=强力过滤(有助于解决断音/破音问题)",
                                         elem_id="robust_f0")
                    batch_size = gr.Slider(minimum=1, maximum=64, step=1, value=4, label="批处理大小",
                                           info="使用批处理可以加速转换,但需要更多VRAM。1=不使用批处理",
                                           elem_id="batch_size")
| with gr.Accordion("🔬 高级CFG参数", open=True): | |
| ds_cfg_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, | |
| label="内容向量引导强度", | |
| info="更高的值可以改善内容保留和咬字清晰度。过高会用力过猛。", | |
| elem_id="ds_cfg_strength") | |
| spk_cfg_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, | |
| label="说话人引导强度", | |
| info="更高的值可以增强说话人相似度。过高可能导致音色失真。", | |
| elem_id="spk_cfg_strength") | |
| skip_cfg_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, | |
| label="层引导强度(实验性功能)", | |
| info="增强指定层的特征渲染。效果取决于目标层的功能。", | |
| elem_id="skip_cfg_strength") | |
| cfg_skip_layers = gr.Number(value=6, label="CFG跳过层(实验性功能)", precision=0, | |
| info="目标增强层下标", | |
| elem_id="cfg_skip_layers") | |
| cfg_rescale = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, | |
| label="CFG重缩放因子", | |
| info="约束整体引导强度。当引导效果过于强烈时使用调高该值,减少失真和噪音。", | |
| elem_id="cfg_rescale") | |
| cvec_downsample_rate = gr.Radio(choices=[1, 2, 4, 8], value=2, | |
| label="用于反向引导的内容向量下采样率", | |
| info="更高的值(可能)可以提高内容清晰度。", | |
| elem_id="cvec_downsample_rate") | |
| with gr.Accordion("✂️ 切片参数", open=False): | |
| slicer_threshold = gr.Slider(minimum=-60.0, maximum=-20.0, step=0.1, value=-30.0, | |
| label="阈值 (dB)", | |
| info="静音检测阈值", | |
| elem_id="slicer_threshold") | |
| slicer_min_length = gr.Slider(minimum=1000, maximum=10000, step=100, value=3000, | |
| label="最小长度 (毫秒)", | |
| info="最小片段长度", | |
| elem_id="slicer_min_length") | |
| slicer_min_interval = gr.Slider(minimum=10, maximum=500, step=10, value=100, | |
| label="最小静音间隔 (毫秒)", | |
| info="分割片段的最小间隔", | |
| elem_id="slicer_min_interval") | |
| slicer_hop_size = gr.Slider(minimum=1, maximum=20, step=1, value=10, | |
| label="跳跃大小 (毫秒)", | |
| info="片段检测窗口大小", | |
| elem_id="slicer_hop_size") | |
| slicer_max_sil_kept = gr.Slider(minimum=10, maximum=1000, step=10, value=200, | |
| label="保留的最大静音 (毫秒)", | |
| info="保留在每个片段边缘的最大静音长度", | |
| elem_id="slicer_max_sil_kept") | |

            # Right column (output)
            with gr.Column(scale=1):
                convert_btn = gr.Button("🎵 转换声音", variant="primary", elem_id="convert_btn")
                gr.Markdown("### 📤 输出")
                output_audio = gr.Audio(label="转换后的音频", elem_id="output_audio", autoplay=False, show_share_button=False)
                output_message = gr.Markdown(init_message, elem_id="output_message", elem_classes="output-message")

                gr.HTML("""
                <div class="info-box">
                    <h4>🔍 快速提示</h4>
                    <ul>
                        <li><strong>音调调整:</strong> 以半音为单位上调或下调音高。</li>
                        <li><strong>推理步数:</strong> 步数越多 = 质量越好但速度越慢。</li>
                        <li><strong>音高滤波:</strong> 有助于提高具有挑战性的音频中的音高稳定性。</li>
                        <li><strong>批处理大小:</strong> 值越大 = 转换越快,但需要更多GPU内存。遇到内存不足时降低此值。</li>
                        <li><strong>CFG参数:</strong> 调整转换质量和音色。</li>
                    </ul>
                </div>
                """)

        # Define button click events
        reload_btn.click(
            fn=initialize_models,
            inputs=[model_path],
            outputs=[speaker, output_message]
        )
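
        # NOTE (assumption): depending on the installed Gradio version, returning a
        # plain list to a Dropdown output updates its *value*, not its choices. If the
        # speaker list does not refresh after a reload, the handler may need a small
        # wrapper that returns gr.update(choices=speakers, value=speakers[0] if speakers
        # else None) for the `speaker` output instead.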

        # Convert button: show a processing notice immediately, then run the conversion
        convert_btn.click(
            fn=lambda: "⏳ 处理中... 请稍候。",
            inputs=None,
            outputs=output_message,
            queue=False
        ).then(
            fn=process_with_progress,
            inputs=[
                input_audio, speaker, key_shift, infer_steps, robust_f0,
                ds_cfg_strength, spk_cfg_strength, skip_cfg_strength, cfg_skip_layers, cfg_rescale, cvec_downsample_rate,
                slicer_threshold, slicer_min_length, slicer_min_interval, slicer_hop_size, slicer_max_sil_kept,
                batch_size
            ],
            outputs=[output_audio, output_message],
            show_progress_on=output_audio
        )

    return app


if __name__ == "__main__":
    app = create_ui()
    app.launch(share=True)