|
|
import sys |
|
|
import os |
|
|
|
|
|
import time |
|
|
import json |
|
|
import torch |
|
|
import torchaudio |
|
|
import numpy as np |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
from codeclm.trainer.codec_song_pl import CodecLM_PL |
|
|
from codeclm.models import CodecLM |
|
|
from third_party.demucs.models.pretrained import get_model_from_yaml |
|
|
|
|
|
|
|
|
class Separator:
    """Split a prompt clip into full mix, vocals and accompaniment using Demucs.

    The Demucs model is built once in ``__init__`` and reused by ``run``.
    Every clip handed back to the caller is resampled to ``SAMPLE_RATE`` and
    is exactly ``SAMPLE_RATE * PROMPT_SECONDS`` samples long.
    """

    # Sample rate (Hz) that every loaded clip is resampled to.
    SAMPLE_RATE = 48000
    # Length in seconds of the clips returned by ``load_audio``.
    PROMPT_SECONDS = 10

    def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
        """Select the compute device and load the Demucs model.

        Args:
            dm_model_path: Path to the Demucs checkpoint.
            dm_config_path: Path to the Demucs YAML config.
            gpu_id: CUDA device index; the CPU is used when CUDA is not
                available or the index is not a visible device.
        """
        if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
            self.device = torch.device(f"cuda:{gpu_id}")
        else:
            self.device = torch.device("cpu")
        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)

    def init_demucs_model(self, model_path, config_path):
        """Build the Demucs model, move it to ``self.device`` and set eval mode.

        Args:
            model_path: Path to the Demucs checkpoint.
            config_path: Path to the Demucs YAML config.

        Returns:
            The model returned by ``get_model_from_yaml``.
        """
        model = get_model_from_yaml(config_path, model_path)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _fit_length(a, target_len):
        """Trim or loop ``a`` along its last axis to exactly ``target_len`` samples.

        Args:
            a: Tensor whose last axis is time.
            target_len: Required number of samples on the last axis.

        Returns:
            A tensor with ``target_len`` samples on the last axis.

        Raises:
            ValueError: If ``a`` has no samples, so it cannot be looped.
        """
        num_samples = a.shape[-1]
        if num_samples == 0:
            raise ValueError("cannot build a prompt from empty audio")
        if num_samples < target_len:
            # Ceiling division: a single doubling is not enough for clips
            # shorter than half the target, so repeat as often as required.
            repeats = -(-target_len // num_samples)
            a = torch.cat([a] * repeats, dim=-1)
        return a[..., :target_len]

    def load_audio(self, f):
        """Load ``f`` as a fixed-length clip at ``SAMPLE_RATE``.

        The audio is resampled when its native rate differs, then trimmed to
        the first ``PROMPT_SECONDS`` seconds, or looped when it is shorter.

        Args:
            f: Audio source accepted by ``torchaudio.load``.

        Returns:
            Waveform tensor with ``SAMPLE_RATE * PROMPT_SECONDS`` samples on
            its last axis.

        Raises:
            ValueError: If the file contains no samples.
        """
        a, fs = torchaudio.load(f)
        if fs != self.SAMPLE_RATE:
            a = torchaudio.functional.resample(a, fs, self.SAMPLE_RATE)
        return self._fit_length(a, self.SAMPLE_RATE * self.PROMPT_SECONDS)

    def run(self, audio_path, output_dir='tmp', ext=".flac"):
        """Separate ``audio_path`` and return full, vocal and background clips.

        Args:
            audio_path: Path of the audio file to separate.
            output_dir: Directory for stem files; created if missing.
            ext: File extension used when looking for existing stem files.

        Returns:
            Tuple ``(full_audio, vocal_audio, bgm_audio)`` of clips produced
            by ``load_audio``, where ``bgm_audio`` is ``full_audio`` minus
            ``vocal_audio``.
        """
        os.makedirs(output_dir, exist_ok=True)
        name = os.path.splitext(os.path.basename(audio_path))[0]
        candidates = (
            os.path.join(output_dir, f"{name}_{stem}{ext}")
            for stem in self.demucs_model.sources
        )
        output_paths = [path for path in candidates if os.path.exists(path)]

        if len(output_paths) == 1:
            # A previous run deletes every stem but the vocals, so a single
            # leftover file is reused as the vocal stem.
            # NOTE(review): the stem name of that file is not checked.
            vocal_path = output_paths[0]
        else:
            # NOTE(review): relies on separate() returning the stems in the
            # order drums, bass, other, vocals - confirm against the model.
            drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(
                audio_path, output_dir, device=self.device
            )
            # Only the vocal stem is needed afterwards.
            for path in (drums_path, bass_path, other_path):
                os.remove(path)

        full_audio = self.load_audio(audio_path)
        vocal_audio = self.load_audio(vocal_path)
        # The accompaniment is whatever remains once the vocals are removed.
        bgm_audio = full_audio - vocal_audio
        return full_audio, vocal_audio, bgm_audio
|
|
|
|
|
|
|
|
def main_sep():
    """Generate one song per JSONL item, conditioned on separated prompt audio.

    Command-line arguments, read from ``sys.argv``:
        1. Path to the OmegaConf YAML config.
        2. Output directory.
        3. Input JSONL file; each line is an object with the keys ``idx``,
           ``descriptions``, ``gt_lyric`` and ``prompt_audio_path``.
        4. Sample index, appended to every output name as ``_s<index>``.

    For each item the prompt audio is split by ``Separator``, tokens are
    generated and decoded, and the results are written to
    ``<save_dir>/token/<idx>_s<sidx>.npy`` and
    ``<save_dir>/audios/<idx>_s<sidx>.flac``. A JSONL listing the processed
    items, with ``idx`` updated and ``tk_path`` added, is written to
    ``<save_dir>/jsonl/`` once all items are done.

    Raises:
        SystemExit: If fewer than four arguments are supplied.
    """
    if len(sys.argv) < 5:
        raise SystemExit(
            f"usage: {sys.argv[0]} <config.yaml> <save_dir> <input.jsonl> <sample_idx>"
        )

    torch.backends.cudnn.enabled = False
    # SECURITY: this resolver evaluates arbitrary expressions found in the
    # config file; only load configs from trusted sources.
    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
    cfg = OmegaConf.load(sys.argv[1])
    save_dir = sys.argv[2]
    input_jsonl = sys.argv[3]
    sidx = sys.argv[4]
    cfg.mode = 'inference'
    max_duration = cfg.max_dur

    model_light = CodecLM_PL(cfg)
    model_light = model_light.eval().cuda()
    model_light.audiolm.cfg = cfg
    model = CodecLM(
        name="tmp",
        lm=model_light.audiolm,
        audiotokenizer=model_light.audio_tokenizer,
        max_duration=max_duration,
        seperate_tokenizer=model_light.seperate_tokenizer,
    )
    separator = Separator()

    # Sampling configuration passed to the language model.
    cfg_coef = 1.5
    temp = 1.0
    top_k = 50
    top_p = 0.0
    record_tokens = True
    record_window = 50
    # Generated token sequences are cut to this length before decoding.
    max_token_len = 3000

    model.set_generation_params(
        duration=max_duration,
        extend_stride=5,
        temperature=temp,
        cfg_coef=cfg_coef,
        top_k=top_k,
        top_p=top_p,
        record_tokens=record_tokens,
        record_window=record_window,
    )
    os.makedirs(save_dir + "/token", exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)

    # Read as UTF-8 to match the encoding used for the output JSONL; the
    # platform default may not be able to decode non-ASCII lyrics.
    with open(input_jsonl, "r", encoding="utf-8") as fp:
        lines = fp.readlines()

    new_items = []
    for line in lines:
        if not line.strip():
            # Tolerate blank lines, e.g. a trailing newline at end of file.
            continue
        item = json.loads(line)
        target_name = f"{save_dir}/token/{item['idx']}_s{sidx}.npy"
        target_wav_name = f"{save_dir}/audios/{item['idx']}_s{sidx}.flac"
        descriptions = item["descriptions"]
        lyric = item["gt_lyric"]

        start_time = time.time()
        pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
        generate_inp = {
            # Collapse double spaces; the previous pattern replaced a single
            # space with itself, which did nothing.
            'lyrics': [lyric.replace("  ", " ")],
            'descriptions': [descriptions],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
        }

        mid_time = time.time()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = model.generate(**generate_inp, return_tokens=True)
        end_time = time.time()
        if tokens.shape[-1] > max_token_len:
            tokens = tokens[..., :max_token_len]

        with torch.no_grad():
            wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
        torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
        np.save(target_name, tokens.cpu().squeeze(0).numpy())
        print(f"process{item['idx']}, demucs cost {mid_time - start_time}s, lm cos {end_time - mid_time}")

        item["idx"] = f"{item['idx']}_s{sidx}"
        item["tk_path"] = target_name
        new_items.append(item)

    src_jsonl_name = os.path.split(input_jsonl)[-1]
    with open(f"{save_dir}/jsonl/{src_jsonl_name}-s{sidx}.jsonl", "w", encoding='utf-8') as fw:
        for item in new_items:
            # write(), not writelines(): the payload is one string, and
            # writelines() would iterate it character by character.
            fw.write(json.dumps(item, ensure_ascii=False) + "\n")
|
|
|
|
|
|
|
|
# Run the generation pipeline only when this file is executed as a script,
# so importing the module has no side effects.
if __name__ == "__main__":
    main_sep()
|
|
|