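"""FastPitch + HiFi-GAN inference for a multi-speaker, multilingual Sámi TTS model.

Loads a FastPitch acoustic model (from a local path, or fetched at runtime from a
Hugging Face repo given an hf:// path or the HF_MODEL_ID secret), a HiFi-GAN
vocoder, and a text processor, then synthesizes 22.05 kHz speech for input text.
Runs on CPU by default.
"""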
import argparse
import models
import time
import sys
import warnings
#from pathlib import Path

import torch
import numpy as np
from scipy.stats import norm
from scipy.io.wavfile import write
from torch.nn.utils.rnn import pad_sequence

#import style_controller
from common.utils import load_wav_to_torch
from common import utils, layers
from common.text.text_processing import TextProcessing

import os
#os.environ["CUDA_VISIBLE_DEVICES"]=""

#device = "cuda:0"
device = "cpu"
vocoder = "hifigan"
SHARPEN = True

from hifigan.data_function import MAX_WAV_VALUE, mel_spectrogram
from hifigan.models import Denoiser

import json
from scipy import ndimage


def parse_args(parser):
    """
    Parse commandline arguments.
    """
    parser.add_argument('-i', '--input', type=str, required=False,
                        help='Full path to the input text (phrases separated by newlines)')
    parser.add_argument('-o', '--output', default=None,
                        help='Output folder to save audio (file per phrase)')
    parser.add_argument('--log-file', type=str, default=None,
                        help='Path to a DLLogger log file')
    parser.add_argument('--save-mels', action='store_true', help='')
    parser.add_argument('--cuda', action='store_true',
                        help='Run inference on a GPU using CUDA')
    parser.add_argument('--cudnn-benchmark', action='store_true',
                        help='Enable cudnn benchmark mode')
    parser.add_argument('--fastpitch', type=str,
                        default='hf://kathiasi/multi-sami/FastPitch_checkpoint_200.pt',
                        help='Full path to the generator checkpoint file (skip to use ground truth mels). '
                             'Can be a local path or an HF path like hf://<repo_id>/<filename>')
    parser.add_argument('-d', '--denoising-strength', default=0.01, type=float,
                        help='Denoising strength for the vocoder denoiser')
    parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
                        help='Sampling rate')
    parser.add_argument('--stft-hop-length', type=int, default=256,
                        help='STFT hop length for estimating audio length from mel size')
    parser.add_argument('--amp', action='store_true', default=False,
                        help='Inference with AMP')
    parser.add_argument('-bs', '--batch-size', type=int, default=1)
    parser.add_argument('--ema', action='store_true',
                        help='Use EMA averaged model (if saved in checkpoints)')
    parser.add_argument('--speaker', type=int, default=0,
                        help='Speaker ID for a multi-speaker model')
    parser.add_argument('--language', type=int, default=0,
                        help='Language ID for a multilingual model')
    parser.add_argument('--p-arpabet', type=float, default=0.0, help='')

    text_processing = parser.add_argument_group('Text processing parameters')
    text_processing.add_argument('--text-cleaners', nargs='*',
                                 default=['basic_cleaners'], type=str,
                                 help='Type of text cleaners for input text')
    text_processing.add_argument('--symbol-set', type=str, default='all_sami',
                                 help='Define symbol set for input text')

    cond = parser.add_argument_group('conditioning on additional attributes')
    cond.add_argument('--n-speakers', type=int, default=10,
                      help='Number of speakers in the model.')
    cond.add_argument('--n-languages', type=int, default=3,
                      help='Number of languages in the model.')
    return parser


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def load_model_from_ckpt(checkpoint_path, ema, model):
    checkpoint_data = torch.load(checkpoint_path, map_location=device)
    status = ''

    if 'state_dict' in checkpoint_data:
        sd = checkpoint_data['state_dict']
        if ema and 'ema_state_dict' in checkpoint_data:
            sd = checkpoint_data['ema_state_dict']
            status += ' (EMA)'
        elif ema and 'ema_state_dict' not in checkpoint_data:
            print(f'WARNING: EMA weights missing for {checkpoint_path}')

        if any(key.startswith('module.') for key in sd):
            sd = {k.replace('module.', ''): v for k, v in sd.items()}
        status += ' ' + str(model.load_state_dict(sd, strict=False))
    else:
        model = checkpoint_data['model']

    print(f'Loaded {checkpoint_path}{status}')
    return model


def load_and_setup_model(model_name, parser, checkpoint, amp, device,
                         unk_args=[], forward_is_infer=False, ema=True,
                         jitable=False):
    model_parser = models.parse_model_args(model_name, parser, add_help=False)
    model_args, model_unk_args = model_parser.parse_known_args()
    unk_args[:] = list(set(unk_args) & set(model_unk_args))
    setattr(model_args, "energy_conditioning", True)

    model_config = models.get_model_config(model_name, model_args)
    # print(model_config)
    model = models.get_model(model_name, model_config, device,
                             forward_is_infer=forward_is_infer,
                             jitable=jitable)

    if checkpoint is not None:
        model = load_model_from_ckpt(checkpoint, ema, model)

    # Half precision is force-disabled for CPU inference.
    amp = False
    if amp:
        model.half()
    model.eval()
    return model.to(device)


class Synthesizer:

    def _load_pyt_or_ts_model(self, model_name, ckpt_path, format='pyt'):
        if format == 'ts':
            model = models.load_and_setup_ts_model(model_name, ckpt_path,
                                                   False, device)
            model_train_setup = {}
            return model, model_train_setup
        is_ts_based_infer = False
        model, _, model_train_setup = models.load_and_setup_model(
            model_name, self.parser, ckpt_path, False, device,
            unk_args=self.unk_args, forward_is_infer=True,
            jitable=is_ts_based_infer)
        if is_ts_based_infer:
            model = torch.jit.script(model)
        return model, model_train_setup

    def __init__(self):
        parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference',
                                         allow_abbrev=False)
        self.parser = parse_args(parser)
        self.args, self.unk_args = self.parser.parse_known_args()

        # If an HF_MODEL_ID secret is present, prefer it as the FastPitch source.
        # HF_MODEL_ID should be in the form "<repo_id>/<path/inside/repo>" or start with "hf://".
        env_hf_model = os.environ.get('HF_MODEL_ID')
        if env_hf_model:
            # Only override when the CLI argument looks like a HF placeholder or wasn't explicitly set.
            # This lets a user pass an explicit --fastpitch on the command line to override the secret.
            if (not self.args.fastpitch) or self.args.fastpitch.startswith('hf://'):
                if env_hf_model.startswith('hf://'):
                    self.args.fastpitch = env_hf_model
                else:
                    self.args.fastpitch = 'hf://' + env_hf_model

        # If the fastpitch path points to a Hugging Face private repo, download it at runtime.
        # Expected format: hf://<repo_id>/<path/inside/repo>,
        # e.g. hf://username/private-fastpitch/FastPitch_checkpoint_200.pt
        if isinstance(self.args.fastpitch, str) and self.args.fastpitch.startswith('hf://'):
            try:
                from huggingface_hub import hf_hub_download
            except Exception as e:
                raise RuntimeError('huggingface_hub is required to download model from HF. '
                                   'Add it to requirements and set HF_TOKEN secret.') from e
            hf_ref = self.args.fastpitch[len('hf://'):]
            # Expected format: <owner>/<repo>/<path/inside/repo>
            parts = hf_ref.split('/')
            if len(parts) < 3:
                raise ValueError('fastpitch HF path must be in format '
                                 'hf://<owner>/<repo>/<path/inside/repo> '
                                 'e.g. hf://username/repo/FastPitch_checkpoint_200.pt')
            repo_id = parts[0] + '/' + parts[1]
            filename = '/'.join(parts[2:])
            token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGINGFACE_HUB_TOKEN')
            if not token:
                raise RuntimeError('HF token not found. Set HF_TOKEN or HUGGINGFACE_HUB_TOKEN '
                                   'in the environment/secrets of the Space.')
            # Download to the local cache and replace the checkpoint path with the local file path.
            local_ckpt = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
            self.args.fastpitch = local_ckpt

        self.generator = load_and_setup_model(
            'FastPitch', parser, self.args.fastpitch, self.args.amp, device,
            unk_args=self.unk_args, forward_is_infer=True, ema=self.args.ema,
            jitable=False)

        self.hifigan_model = "pretrained_models/hifigan/hifigan_gen_checkpoint_10000_ft.pt"  # Better with Sander!
        #self.hifigan_model = "pretrained_models/hifigan/hifigan_gen_checkpoint_6500.pt"
        #self.vocoder = UnivNetModel.from_pretrained(model_name="tts_en_libritts_univnet")
        self.vocoder, voc_train_setup = self._load_pyt_or_ts_model('HiFi-GAN', self.hifigan_model)
        self.denoiser = Denoiser(self.vocoder, device=device)  #, win_length=self.args.win_length).to(device)

        self.tp = TextProcessing(self.args.symbol_set, self.args.text_cleaners, p_arpabet=0.0)

    def unsharp_mask(self, img, radius=1, amount=1):
        # Unsharp masking: add back the difference between the image and a
        # Gaussian-blurred copy to boost local contrast in the mel spectrogram.
        blurred = ndimage.gaussian_filter(img, radius)
        sharpened = img + amount * (img - blurred)
        return sharpened

    def speak(self, text, output_file="/tmp/tmp", spkr=0, lang=0,
              l_weight=1, s_weight=1, pace=0.95, clarity=1):
        text = self.tp.encode_text(text)
        #text = [9]+self.tp.encode_text(text)+[9]
        text = torch.LongTensor([text]).to(device)
        #probs = surprisals

        for p in [0]:
            with torch.no_grad():
                print(s_weight, l_weight)
                mel, mel_lens, *_ = self.generator(
                    text, pace, max_duration=15, speaker=spkr, language=lang,
                    speaker_weight=s_weight, language_weight=l_weight)
                # speaker 0 = bad audio, speaker 1 = better audio

            if SHARPEN:
                # Sharpen the predicted mel spectrogram, tilt it towards the higher
                # bands according to `clarity`, and rescale to a fixed dynamic range.
                mel_np = mel.float().data.cpu().numpy()[0]
                tgt_min = -11
                tgt_max = 1.25
                #print(np.min(mel_np), np.max(mel_np))
                mel_np = self.unsharp_mask(mel_np, radius=0.5, amount=0.5)
                mel_np = self.unsharp_mask(mel_np, radius=3, amount=0.05)
                # mel_np = self.unsharp_mask(mel_np, radius=7, amount=0.05)
                for i in range(0, 80):
                    mel_np[i, :] += (i - 30) * clarity * 0.02
                mel_np = (mel_np - np.min(mel_np)) / (np.max(mel_np) - np.min(mel_np)) * (tgt_max - tgt_min) + tgt_min
                mel[0] = torch.from_numpy(mel_np).float().to(device)
                """
                mel_np = mel.float().data.cpu().numpy()[0]
                blurred_f = ndimage.gaussian_filter(mel_np, 1.0) #3
                alpha = 0.2 #0.3 ta
                mel_np = mel_np + alpha * (mel_np - blurred_f)
                blurred_f = ndimage.gaussian_filter(mel_np, 3.0) #3
                alpha = 0.1 # 0.1 ta
                sharpened = mel_np + alpha * (mel_np - blurred_f)
                for i in range(0,80):
                    sharpened[i, :] += (i-40)*0.01 #0.01 ta
                mel[0] = torch.from_numpy(sharpened).float().to(device)
                """

            with torch.no_grad():
                y_g_hat = self.vocoder(mel).float()
                #y_g_hat = self.denoiser(y_g_hat.squeeze(1), strength=0.01) #[:, 0]
                audio = y_g_hat.squeeze()
                # Normalize volume and convert to 16-bit PCM.
                audio = audio / torch.max(torch.abs(audio)) * 0.95 * 32768
                audio = audio.cpu().numpy().astype('int16')
                write(output_file + ".wav", 22050, audio)
                os.system("play -q " + output_file + ".wav")
        return audio


if __name__ == '__main__':
    syn = Synthesizer()

    hifigan = syn.hifigan_model
    hifigan_n = hifigan.replace(".pt", "")
    fastpitch = syn.args.fastpitch
    fastpitch_n = fastpitch.replace(".pt", "")
    print(hifigan_n + " " + fastpitch_n)

    # Short tags derived from the checkpoint file names, used in output file names.
    hifigan_n_short = hifigan_n.split("/")
    hifigan_n_shorter = hifigan_n_short[2].split("_")
    hifigan_n_shortest = hifigan_n_shorter[3]
    fastpitch_n_short = fastpitch_n.split("/")
    fastpitch_n_shorter = fastpitch_n_short[1].split("_")
    fastpitch_n_shortest = fastpitch_n_shorter[2]

    #syn.speak("Gå lij riek mælggadav vádtsám, de bådij vijmak tjáppa vuobmáj.")
    i = 0
    spkr = 1
    lang = 1
    while True:
        text = input(">")
        text1 = text.split(" ")
        syn.speak(text, output_file="/tmp/tmp", spkr=6, lang=1)
        syn.speak(text, output_file="/tmp/tmp", spkr=7, lang=1)
        continue
        # Unreachable sweep over all speaker/language pairs, kept for reference.
        for s in range(1, 10):
            for l in range(3):
                print("speaker", s, "language", l)
                syn.speak(text, output_file="/tmp/"+str(i)+"_"+text1[0]+"_"+str(s)+"_"+str(l)+"_FP_"+fastpitch_n_shortest+"univnet", spkr=s, lang=l)
                #syn.speak(text, output_file="/home/hiovain/DeepLearningExamples/PyTorch/SpeechSynthesis/FastPitchMulti/inf_output_multi/"+str(i)+"_"+text1[0]+"_"+str(s)+"_"+str(l)+"_FP_"+fastpitch_n_shortest+"univnet", spkr=s, lang=l)
        i += 1
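
# Programmatic usage (a minimal sketch; the output path and speaker/language IDs
# below are illustrative, and valid IDs depend on how the checkpoint was trained):
#
#     syn = Synthesizer()
#     audio = syn.speak("Gå lij riek mælggadav vádtsám.", output_file="/tmp/demo",
#                       spkr=1, lang=0)
#     # Writes /tmp/demo.wav at 22050 Hz and returns the int16 waveform.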