Spaces:
Sleeping
Sleeping
| # ***************************************************************************** | |
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Redistribution and use in source and binary forms, with or without | |
| # modification, are permitted provided that the following conditions are met: | |
| # * Redistributions of source code must retain the above copyright | |
| # notice, this list of conditions and the following disclaimer. | |
| # * Redistributions in binary form must reproduce the above copyright | |
| # notice, this list of conditions and the following disclaimer in the | |
| # documentation and/or other materials provided with the distribution. | |
| # * Neither the name of the NVIDIA CORPORATION nor the | |
| # names of its contributors may be used to endorse or promote products | |
| # derived from this software without specific prior written permission. | |
| # | |
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |
| # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | |
| # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | |
| # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | |
| # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | |
| # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | |
| # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
| # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
| # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| # | |
| # ***************************************************************************** | |
| import argparse | |
| import copy | |
| import os | |
| import time | |
| from collections import defaultdict, OrderedDict | |
| from itertools import cycle | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| import amp_C | |
| from apex.optimizers import FusedAdam, FusedLAMB | |
| from torch.nn.parallel import DistributedDataParallel | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data.distributed import DistributedSampler | |
| import common.tb_dllogger as logger | |
| import models | |
| from common.tb_dllogger import log | |
| from common.repeated_dataloader import (RepeatedDataLoader, | |
| RepeatedDistributedSampler) | |
| from common.text import cmudict | |
| from common.utils import BenchmarkStats, Checkpointer, prepare_tmp | |
| from fastpitch.attn_loss_function import AttentionBinarizationLoss | |
| from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset | |
| from fastpitch.loss_function import FastPitchLoss | |
| import matplotlib.pyplot as plt | |
def parse_args(parser):
    """Register all FastPitch training options on *parser* and return it.

    Options are grouped into: general I/O, training setup, optimization,
    dataset, conditioning, audio, and distributed-run parameters.

    Fixes vs. previous revision: the distributed argument group variable no
    longer shadows the module-level ``torch.distributed as dist`` import,
    and two help-string typos were corrected.
    """
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('-d', '--dataset-path', type=str, default='./',
                        help='Path to dataset')
    parser.add_argument('--log-file', type=str, default=None,
                        help='Path to a DLLogger log file')

    train = parser.add_argument_group('training setup')
    train.add_argument('--epochs', type=int, required=True,
                       help='Number of total epochs to run')
    train.add_argument('--epochs-per-checkpoint', type=int, default=50,
                       help='Number of epochs per checkpoint')
    train.add_argument('--checkpoint-path', type=str, default=None,
                       help='Checkpoint path to resume training')
    train.add_argument('--keep-milestones', default=list(range(100, 1000, 100)),
                       type=int, nargs='+',
                       help='Milestone checkpoints to keep from removing')
    train.add_argument('--resume', action='store_true',
                       help='Resume training from the last checkpoint')
    train.add_argument('--seed', type=int, default=1234,
                       help='Seed for PyTorch random number generators')
    train.add_argument('--amp', action='store_true',
                       help='Enable AMP')
    train.add_argument('--cuda', action='store_true',
                       help='Run on GPU using CUDA')
    train.add_argument('--cudnn-benchmark', action='store_true',
                       help='Enable cudnn benchmark mode')
    train.add_argument('--ema-decay', type=float, default=0,
                       help='Discounting factor for training weights EMA')
    train.add_argument('--grad-accumulation', type=int, default=1,
                       help='Training steps to accumulate gradients for')
    train.add_argument('--kl-loss-start-epoch', type=int, default=250,
                       help='Start adding the hard attention loss term')
    train.add_argument('--kl-loss-warmup-epochs', type=int, default=100,
                       help='Gradually increase the hard attention loss term')
    train.add_argument('--kl-loss-weight', type=float, default=1.0,
                       help='Gradually increase the hard attention loss term')
    train.add_argument('--benchmark-epochs-num', type=int, default=20,
                       help='Number of epochs for calculating final stats')
    train.add_argument('--validation-freq', type=int, default=1,
                       help='Validate every N epochs to use less compute')

    opt = parser.add_argument_group('optimization setup')
    opt.add_argument('--optimizer', type=str, default='lamb',
                     help='Optimization algorithm')
    opt.add_argument('-lr', '--learning-rate', type=float, required=True,
                     help='Learning rate')
    opt.add_argument('--weight-decay', default=1e-6, type=float,
                     help='Weight decay')
    opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
                     help='Clip threshold for gradients')
    opt.add_argument('-bs', '--batch-size', type=int, required=True,
                     help='Batch size per GPU')
    opt.add_argument('--warmup-steps', type=int, default=1000,
                     help='Number of steps for lr warmup')
    opt.add_argument('--dur-predictor-loss-scale', type=float,
                     default=1.0, help='Rescale duration predictor loss')
    opt.add_argument('--pitch-predictor-loss-scale', type=float,
                     default=1.0, help='Rescale pitch predictor loss')
    opt.add_argument('--attn-loss-scale', type=float,
                     default=1.0, help='Rescale alignment loss')

    data = parser.add_argument_group('dataset parameters')
    data.add_argument('--training-files', type=str, nargs='*', required=True,
                      help='Paths to training filelists.')
    data.add_argument('--validation-files', type=str, nargs='*',
                      required=True, help='Paths to validation filelists')
    data.add_argument('--text-cleaners', nargs='*',
                      default=['english_cleaners'], type=str,
                      help='Type of text cleaners for input text')
    data.add_argument('--symbol-set', type=str, default='english_basic',
                      help='Define symbol set for input text')
    data.add_argument('--p-arpabet', type=float, default=0.0,
                      help='Probability of using arpabets instead of graphemes '
                           'for each word; set 0 for pure grapheme training')
    data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
                      help='Path to the list of heteronyms')
    data.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
                      help='Path to the pronouncing dictionary')
    data.add_argument('--prepend-space-to-text', action='store_true',
                      help='Capture leading silence with a space token')
    data.add_argument('--append-space-to-text', action='store_true',
                      help='Capture trailing silence with a space token')
    data.add_argument('--num-workers', type=int, default=2,
                      help='Subprocesses for train and val DataLoaders')
    data.add_argument('--trainloader-repeats', type=int, default=100,
                      help='Repeats the dataset to prolong epochs')

    cond = parser.add_argument_group('data for conditioning')
    cond.add_argument('--n-speakers', type=int, default=1,
                      help='Number of speakers in the dataset. '
                           'n_speakers > 1 enables speaker embeddings')
    # Language conditioning added on top of upstream FastPitch.
    cond.add_argument('--n-languages', type=int, default=1,
                      help='Number of languages in the dataset. '
                           'n_languages > 1 enables language embeddings')
    cond.add_argument('--load-pitch-from-disk', action='store_true',
                      help='Use pitch cached on disk with prepare_dataset.py')
    cond.add_argument('--pitch-online-method', default='pyin',
                      choices=['pyin'],
                      help='Calculate pitch on the fly during training')
    cond.add_argument('--pitch-online-dir', type=str, default=None,
                      help='A directory for storing pitch calculated on-line')
    # Dataset-specific normalization stats (upstream LJSpeech defaults were
    # mean=214.72203, std=65.72038).
    cond.add_argument('--pitch-mean', type=float, default=125.626816,
                      help='Normalization value for pitch')
    cond.add_argument('--pitch-std', type=float, default=37.52,
                      help='Normalization value for pitch')
    cond.add_argument('--load-mel-from-disk', action='store_true',
                      help='Use mel-spectrograms cache on the disk')

    audio = parser.add_argument_group('audio parameters')
    audio.add_argument('--max-wav-value', default=32768.0, type=float,
                       help='Maximum audiowave value')
    audio.add_argument('--sampling-rate', default=22050, type=int,
                       help='Sampling rate')
    audio.add_argument('--filter-length', default=1024, type=int,
                       help='Filter length')
    audio.add_argument('--hop-length', default=256, type=int,
                       help='Hop (stride) length')
    audio.add_argument('--win-length', default=1024, type=int,
                       help='Window length')
    audio.add_argument('--mel-fmin', default=0.0, type=float,
                       help='Minimum mel frequency')
    audio.add_argument('--mel-fmax', default=8000.0, type=float,
                       help='Maximum mel frequency')

    # Renamed from `dist` to avoid shadowing the module-level
    # `torch.distributed as dist` import inside this function.
    dist_group = parser.add_argument_group('distributed setup')
    dist_group.add_argument(
        '--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
        help='Rank of the process for multiproc; do not set manually')
    dist_group.add_argument(
        '--world_size', type=int, default=os.getenv('WORLD_SIZE', 1),
        help='Number of processes for multiproc; do not set manually')
    return parser
def reduce_tensor(tensor, num_gpus):
    """All-reduce *tensor* (SUM) across workers and divide by *num_gpus*.

    The input is cloned first, so the caller's tensor is left untouched.
    Requires an initialized torch.distributed process group.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return summed.true_divide(num_gpus)
def init_distributed(args, world_size, rank):
    """Initialize the torch.distributed process group for this worker.

    Picks the NCCL backend when training on GPU (``args.cuda``) and gloo
    otherwise; rendezvous information is read from the environment
    (``env://``).  Binds this process to its GPU before any collective.
    """
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing distributed training")

    # Pin this process to its GPU so all subsequent work lands there.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    backend = 'nccl' if args.cuda else 'gloo'
    dist.init_process_group(backend=backend, init_method='env://')
    print("Done initializing distributed training")
def validate(model, epoch, total_iter, criterion, val_loader, distributed_run,
             batch_to_gpu, local_rank, ema=False):
    """Run one evaluation pass over *val_loader* and log averaged metrics.

    Args:
        model: FastPitch model (or its EMA copy) to evaluate.
        epoch: current epoch number, or None for the final post-training run.
        total_iter: global step counter, used for TensorBoard logging.
        criterion: FastPitchLoss instance; called with ``meta_agg='sum'``.
        val_loader: validation DataLoader.
        distributed_run: whether metrics must be all-reduced across workers.
        batch_to_gpu: function moving a collated batch to the device.
        local_rank: rank of this process; debug plots only on rank 0.
        ema: True when validating the EMA weights (separate log subset,
            and debug plotting is skipped).

    Returns:
        Dict of validation metrics averaged over the validation set.

    Side effects: on rank 0 (non-EMA) saves a debug figure of the first
    validation utterance to ``debug_epoch/<epoch>.png``.  Restores the
    model's previous train/eval mode on exit.
    """
    was_training = model.training
    model.eval()

    tik = time.perf_counter()
    with torch.no_grad():
        val_meta = defaultdict(float)
        val_num_frames = 0
        for i, batch in enumerate(val_loader):
            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x)
            loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')

            if distributed_run:
                for k, v in meta.items():
                    val_meta[k] += reduce_tensor(v, 1)
                val_num_frames += reduce_tensor(num_frames.data, 1).item()
            else:
                for k, v in meta.items():
                    val_meta[k] += v
                val_num_frames += num_frames.item()

            # NOTE: ugly patch to visualize the first utterance of the
            # validation corpus; the goal is to check that training is
            # progressing properly.
            if (i == 0) and (local_rank == 0) and (not ema):
                fig, axs = plt.subplots(2, 2, figsize=(21, 14))
                # Mel-spectrograms: target (top-left) vs prediction
                # (bottom-left).  The prediction is transposed to match
                # the target's (mel, frames) layout.
                pred_mel = y_pred[0][0, :, :].cpu().detach().numpy().astype(np.float32).T
                orig_mel = y[0][0, :, :].cpu().detach().numpy().astype(np.float32)
                axs[0, 0].imshow(orig_mel, aspect='auto', origin='lower',
                                 interpolation='nearest')
                axs[1, 0].imshow(pred_mel, aspect='auto', origin='lower',
                                 interpolation='nearest')
                # Pitch: ground truth and prediction overlaid (bottom-right).
                f0_pred = y_pred[4][0, :].cpu().detach().numpy().astype(np.float32)
                f0_ori = y_pred[5][0, :].cpu().detach().numpy().astype(np.float32)
                axs[1, 1].plot(f0_ori)
                axs[1, 1].plot(f0_pred)

                # BUGFIX: previously created "debug_epoch_laila/" while
                # saving into "debug_epoch/", so savefig crashed whenever
                # the target directory did not exist yet.
                os.makedirs("debug_epoch/", exist_ok=True)
                fig.savefig(f'debug_epoch/{epoch:06d}.png', bbox_inches='tight')
                # BUGFIX: close the figure; otherwise one figure leaks per
                # validation pass and memory grows over long trainings.
                plt.close(fig)

        val_meta = {k: v / len(val_loader.dataset) for k, v in val_meta.items()}

    val_meta['took'] = time.perf_counter() - tik

    log((epoch,) if epoch is not None else (), tb_total_steps=total_iter,
        subset='val_ema' if ema else 'val',
        data=OrderedDict([
            ('loss', val_meta['loss'].item()),
            ('mel_loss', val_meta['mel_loss'].item()),
            ('frames/s', val_num_frames / val_meta['took']),
            ('took', val_meta['took'])]),
        )

    if was_training:
        model.train()
    return val_meta
def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
    """Apply a Noam-style warmup/decay schedule to every param group of *opt*.

    During warmup the scale grows linearly as ``step / warmup**1.5``; after
    warmup it decays as ``1/sqrt(step)``.  ``warmup_iters == 0`` disables
    scaling entirely.  NOTE: the ``None`` default is never usable — callers
    must pass an int (comparing with None would raise TypeError).
    """
    if warmup_iters == 0:
        scale = 1.0
    elif total_iter <= warmup_iters:
        # Linear ramp-up phase.
        scale = total_iter / (warmup_iters ** 1.5)
    else:
        # Inverse-sqrt decay phase.
        scale = 1. / (total_iter ** 0.5)

    for group in opt.param_groups:
        group['lr'] = learning_rate * scale
def apply_ema_decay(model, ema_model, decay):
    """In-place EMA update: ema = decay * ema + (1 - decay) * model.

    No-op when *decay* is falsy (0 disables EMA).  Handles the case where
    *model* is DDP-wrapped (keys prefixed with 'module.') but *ema_model*
    is not, by prefixing the lookup key.
    """
    if not decay:
        return

    src_state = model.state_dict()
    needs_prefix = hasattr(model, 'module') and not hasattr(ema_model, 'module')
    for name, weight in ema_model.state_dict().items():
        key = name
        if needs_prefix and not key.startswith('module.'):
            key = 'module.' + key
        weight.copy_(decay * weight + (1 - decay) * src_state[key])
def init_multi_tensor_ema(model, ema_model):
    """Prepare flat weight lists and a CUDA overflow flag for fused EMA.

    Returns (model_weights, ema_weights, overflow_buf), the arguments
    consumed by apply_multi_tensor_ema().  Requires CUDA (the overflow
    buffer is a CUDA int tensor).
    """
    weights = list(model.state_dict().values())
    ema_weights = list(ema_model.state_dict().values())
    overflow_buf = torch.cuda.IntTensor([0])
    return weights, ema_weights, overflow_buf
def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
    """Fused in-place EMA step across all weight tensors at once.

    Computes ema = decay * ema + (1 - decay) * model via apex's
    multi_tensor_axpby kernel (single launch over chunked tensor lists).
    """
    chunk_size = 65536  # elements processed per kernel chunk
    amp_C.multi_tensor_axpby(
        chunk_size, overflow_buf,
        [ema_weights, model_weights, ema_weights],
        decay, 1 - decay, -1)
def main():
    """Entry point: parse CLI options, build FastPitch, and run training.

    Handles (in order): argument parsing, logger/checkpoint setup, optional
    distributed initialization, model/optimizer/data construction, the
    epoch loop with gradient accumulation + optional EMA and hard-attention
    KL loss, periodic validation/checkpointing, and a final validation.

    BUGFIX vs. previous revision: the final post-training validate() call
    omitted the required `local_rank` argument and raised TypeError.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if args.p_arpabet > 0.0:
        cmudict.initialize(args.cmudict_path, args.heteronyms_path)

    distributed_run = args.world_size > 1

    # Per-rank seeding so workers draw different data orders.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    # Second parse: model-specific arguments registered by the model zoo.
    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)
    else:
        if args.trainloader_repeats > 1:
            print('WARNING: Disabled --trainloader-repeats, supported only for'
                  ' multi-GPU data loading.')
            args.trainloader_repeats = 1

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)
    attention_kl_loss = AttentionBinarizationLoss()

    # Store pitch mean/std as params to translate from Hz during inference.
    model.pitch_mean[0] = args.pitch_mean
    model.pitch_std[0] = args.pitch_std

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    # Keep a separate EMA copy of the weights when requested.
    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)

    train_state = {'epoch': 1, 'total_iter': 1}
    checkpointer = Checkpointer(args.output, args.keep_milestones)
    checkpointer.maybe_load(model, optimizer, scaler, train_state, args,
                            ema_model)
    start_epoch = train_state['epoch']
    total_iter = train_state['total_iter']

    criterion = FastPitchLoss(
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
        attn_loss_scale=args.attn_loss_scale)

    collate_fn = TTSCollate()

    if args.local_rank == 0:
        prepare_tmp(args.pitch_online_dir)

    trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
    valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))

    if distributed_run:
        train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,
                                                   trainset, drop_last=True)
        val_sampler = DistributedSampler(valset)
        shuffle = False
    else:
        # NOTE(review): shuffle was True upstream; disabled here,
        # presumably deliberately — confirm before re-enabling.
        train_sampler, val_sampler, shuffle = None, None, False

    kw = {'num_workers': args.num_workers, 'batch_size': args.batch_size,
          'collate_fn': collate_fn}
    train_loader = RepeatedDataLoader(args.trainloader_repeats, trainset,
                                      shuffle=shuffle, drop_last=True,
                                      sampler=train_sampler, pin_memory=True,
                                      persistent_workers=True, **kw)
    val_loader = DataLoader(valset, shuffle=False, sampler=val_sampler,
                            pin_memory=False, **kw)

    if args.ema_decay:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    model.train()

    bmark_stats = BenchmarkStats()

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        # Accumulators covering one optimizer step (grad accumulation).
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}
        iter_start_time = time.perf_counter()

        epoch_iter = 1
        for batch, accum_step in zip(train_loader,
                                     cycle(range(1, args.grad_accumulation + 1))):
            if accum_step == 1:
                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)
                model.zero_grad(set_to_none=True)

            x, y, num_frames = batch_to_gpu(batch)

            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss, meta = criterion(y_pred, y)

                # Hard-attention (binarization) KL term, phased in from
                # --kl-loss-start-epoch over --kl-loss-warmup-epochs.
                if (args.kl_loss_start_epoch is not None
                        and epoch >= args.kl_loss_start_epoch):

                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
                        print('Begin hard_attn loss')

                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
                    kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
                    meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
                    loss += kl_weight * binarization_loss
                else:
                    meta['kl_loss'] = torch.zeros_like(loss)
                    kl_weight = 0
                    binarization_loss = 0

                loss /= args.grad_accumulation

            meta = {k: v / args.grad_accumulation
                    for k, v in meta.items()}

            if args.amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            # Optimizer step only once per full accumulation window.
            if accum_step % args.grad_accumulation == 0:

                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    # Unscale before clipping so the threshold applies to
                    # true gradient magnitudes.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)
                    optimizer.step()

                if args.ema_decay > 0.0:
                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)

                iter_mel_loss = iter_meta['mel_loss'].item()
                iter_kl_loss = iter_meta['kl_loss'].item()
                iter_time = time.perf_counter() - iter_start_time
                epoch_frames_per_sec += iter_num_frames / iter_time
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                epoch_mel_loss += iter_mel_loss
                num_iters = len(train_loader) // args.grad_accumulation

                log((epoch, epoch_iter, num_iters), tb_total_steps=total_iter,
                    subset='train', data=OrderedDict([
                        ('loss', iter_loss),
                        ('mel_loss', iter_mel_loss),
                        ('kl_loss', iter_kl_loss),
                        ('kl_weight', kl_weight),
                        ('frames/s', iter_num_frames / iter_time),
                        ('took', iter_time),
                        ('lrate', optimizer.param_groups[0]['lr'])]),
                    )

                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}
                iter_start_time = time.perf_counter()

                if epoch_iter == num_iters:
                    break
                epoch_iter += 1
                total_iter += 1

        # Finished epoch
        epoch_loss /= epoch_iter
        epoch_mel_loss /= epoch_iter

        epoch_time = time.perf_counter() - epoch_start_time
        log((epoch,), tb_total_steps=None, subset='train_avg',
            data=OrderedDict([
                ('loss', epoch_loss),
                ('mel_loss', epoch_mel_loss),
                ('frames/s', epoch_num_frames / epoch_time),
                ('took', epoch_time)]),
            )
        bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
                           epoch_time)

        if epoch % args.validation_freq == 0:
            validate(model, epoch, total_iter, criterion, val_loader,
                     distributed_run, batch_to_gpu, ema=False,
                     local_rank=args.local_rank)

            if args.ema_decay > 0:
                validate(ema_model, epoch, total_iter, criterion, val_loader,
                         distributed_run, batch_to_gpu, args.local_rank,
                         ema=True)

        # Save before making sched.step() for proper loading of LR.
        checkpointer.maybe_save(args, model, ema_model, optimizer, scaler,
                                epoch, total_iter, model_config)
        logger.flush()

    # Finished training
    if len(bmark_stats) > 0:
        log((), tb_total_steps=None, subset='train_avg',
            data=bmark_stats.get(args.benchmark_epochs_num))

    # BUGFIX: this call previously omitted the required `local_rank`
    # argument and raised TypeError at the very end of training.
    validate(model, None, total_iter, criterion, val_loader, distributed_run,
             batch_to_gpu, local_rank=args.local_rank)
# Standard script entry point.
if __name__ == '__main__':
    main()