import atexit
import glob
import re
from itertools import product
from pathlib import Path

import dllogger
import torch
import numpy as np
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
from torch.utils.tensorboard import SummaryWriter


tb_loggers = {}


class TBLogger:
    """
    dummies: stretch the screen with empty plots so the legend would
             always fit for other plots
    """
    def __init__(self, enabled, log_dir, name, interval=1, dummies=True):
        self.enabled = enabled
        self.interval = interval
        self.cache = {}
        if self.enabled:
            self.summary_writer = SummaryWriter(
                log_dir=Path(log_dir, name), flush_secs=120, max_queue=200)
            atexit.register(self.summary_writer.close)
            if dummies:
                for key in ('_', '✕'):
                    self.summary_writer.add_scalar(key, 0.0, 1)

    def log(self, step, data):
        for k, v in data.items():
            self.log_value(step, k,
                           v.item() if isinstance(v, torch.Tensor) else v)

    def log_value(self, step, key, val, stat='mean'):
        if self.enabled:
            if key not in self.cache:
                self.cache[key] = []
            self.cache[key].append(val)
            # Aggregate every `interval` values, then reset the cache for the key
            if len(self.cache[key]) == self.interval:
                agg_val = getattr(np, stat)(self.cache[key])
                self.summary_writer.add_scalar(key, agg_val, step)
                del self.cache[key]

    def log_grads(self, step, model):
        if self.enabled:
            norms = [p.grad.norm().item() for p in model.parameters()
                     if p.grad is not None]
            for stat in ('max', 'min', 'mean'):
                self.log_value(step, f'grad_{stat}', getattr(np, stat)(norms),
                               stat=stat)
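
# Illustrative usage sketch (not part of the original module): the run name,
# log directory and metric key below are made up. Values are cached and only
# written to TensorBoard once `interval` of them have accumulated.
#
#   tb = TBLogger(enabled=True, log_dir='runs', name='train', interval=10)
#   for step in range(1, 101):
#       tb.log(step, {'loss': 1.0 / step})
#   tb.summary_writer.flush()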


def unique_log_fpath(fpath):
    if not Path(fpath).is_file():
        return fpath

    # Avoid overwriting old logs
    saved = [re.search(r'\.(\d+)$', f) for f in glob.glob(f'{fpath}.*')]
    saved = [0] + [int(m.group(1)) for m in saved if m is not None]
    return f'{fpath}.{max(saved) + 1}'
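
# Example behaviour (hypothetical paths): if 'nvlog.json' and 'nvlog.json.1'
# already exist, unique_log_fpath('nvlog.json') returns 'nvlog.json.2';
# if the file does not exist yet, the path is returned unchanged.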


def stdout_step_format(step):
    if isinstance(step, str):
        return step
    fields = []
    if len(step) > 0:
        fields.append("epoch {:>4}".format(step[0]))
    if len(step) > 1:
        fields.append("iter {:>3}".format(step[1]))
    if len(step) > 2:
        fields[-1] += "/{}".format(step[2])
    return " | ".join(fields)


def stdout_metric_format(metric, metadata, value):
    name = metadata.get("name", metric + " : ")
    unit = metadata.get("unit", None)
    fmt = f'{{{metadata.get("format", "")}}}'
    fields = [name, fmt.format(value) if value is not None else value, unit]
    fields = [f for f in fields if f is not None]
    return "| " + " ".join(fields)


def init(log_fpath, log_dir, enabled=True, tb_subsets=[], **tb_kw):

    if enabled:
        backends = [JSONStreamBackend(Verbosity.DEFAULT,
                                      unique_log_fpath(log_fpath)),
                    StdOutBackend(Verbosity.VERBOSE,
                                  step_format=stdout_step_format,
                                  metric_format=stdout_metric_format)]
    else:
        backends = []

    dllogger.init(backends=backends)

    dllogger.metadata("train_lrate",
                      {"name": "lrate", "unit": None, "format": ":>3.2e"})

    for id_, pref in [('train', ''), ('train_avg', 'avg train '),
                      ('val', ' avg val '), ('val_ema', ' EMA val ')]:
        dllogger.metadata(f"{id_}_loss",
                          {"name": f"{pref}loss", "unit": None, "format": ":>5.2f"})
        dllogger.metadata(f"{id_}_mel_loss",
                          {"name": f"{pref}mel loss", "unit": None, "format": ":>5.2f"})
        dllogger.metadata(f"{id_}_kl_loss",
                          {"name": f"{pref}kl loss", "unit": None, "format": ":>5.5f"})
        dllogger.metadata(f"{id_}_kl_weight",
                          {"name": f"{pref}kl weight", "unit": None, "format": ":>5.5f"})
        dllogger.metadata(f"{id_}_frames/s",
                          {"name": None, "unit": "frames/s", "format": ":>10.2f"})
        dllogger.metadata(f"{id_}_took",
                          {"name": "took", "unit": "s", "format": ":>3.2f"})

    global tb_loggers
    tb_loggers = {s: TBLogger(enabled, log_dir, name=s, **tb_kw)
                  for s in tb_subsets}


def init_inference_metadata(batch_size=None):

    modalities = [('latency', 's', ':>10.5f'), ('RTF', 'x', ':>10.2f'),
                  ('frames/s', 'frames/s', ':>10.2f'),
                  ('samples/s', 'samples/s', ':>10.2f'),
                  ('letters/s', 'letters/s', ':>10.2f'),
                  ('tokens/s', 'tokens/s', ':>10.2f')]

    if batch_size is not None:
        modalities.append((f'RTF@{batch_size}', 'x', ':>10.2f'))

    percs = ['', 'avg', '90%', '95%', '99%']
    models = ['', 'fastpitch', 'waveglow', 'hifigan']

    for perc, model, (mod, unit, fmt) in product(percs, models, modalities):
        name = f'{perc} {model} {mod}'.strip().replace('  ', ' ')
        dllogger.metadata(name.replace(' ', '_'),
                          {'name': f'{name: <26}', 'unit': unit, 'format': fmt})
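
# For illustration, the loop above registers metric names such as 'latency',
# 'avg_fastpitch_latency', '90%_hifigan_RTF' or 'avg_frames/s', i.e.
# '<percentile>_<model>_<modality>' with the empty components dropped.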


def log(step, tb_total_steps=None, data={}, subset='train'):

    if tb_total_steps is not None:
        tb_loggers[subset].log(tb_total_steps, data)

    if subset != '':
        data = {f'{subset}_{key}': v for key, v in data.items()}
    dllogger.log(step, data=data)


def log_grads_tb(tb_total_steps, grads, tb_subset='train'):
    # `grads` is the model whose parameter gradient norms get summarized
    tb_loggers[tb_subset].log_grads(tb_total_steps, grads)


def parameters(data, verbosity=0, tb_subset=None):
    # Log run parameters (config/hyperparameters) to dllogger and, optionally,
    # to TensorBoard as hparams
    for k, v in data.items():
        dllogger.log(step="PARAMETER", data={k: v}, verbosity=verbosity)

    if tb_subset is not None and tb_loggers[tb_subset].enabled:
        tb_data = {k: v for k, v in data.items()
                   if type(v) in (str, bool, int, float)}
        tb_loggers[tb_subset].summary_writer.add_hparams(tb_data, {})
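
# Typical call (illustrative; `args` is assumed to be an argparse.Namespace of
# run options):
#   parameters(vars(args), tb_subset='train')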


def flush():
    dllogger.flush()
    for tbl in tb_loggers.values():
        if tbl.enabled:
            tbl.summary_writer.flush()
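

# Minimal smoke-test sketch, not part of the original module. It only
# illustrates the intended call order (init() once, log() per step, flush()
# at the end); the paths, subsets, metric names and values are made up.
if __name__ == '__main__':
    Path('logs').mkdir(exist_ok=True)
    init(log_fpath='logs/nvlog.json', log_dir='logs', tb_subsets=['train', 'val'])

    total_iters = 3
    for epoch in range(1, 3):
        for it in range(1, total_iters + 1):
            tb_step = (epoch - 1) * total_iters + it
            log((epoch, it, total_iters), tb_total_steps=tb_step,
                data={'loss': 1.0 / it, 'took': 0.1}, subset='train')
        log((epoch,), data={'loss': 0.5 / epoch}, subset='val')
    flush()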