# THIS IS THE LITTLE CHAD VERSION, BETTER THAN THE ONE IT WAS CLONED FROM - IT TRANSCRIBES FROM MIC AND FILE (ALMOST, HEHE)
from __future__ import annotations

import os
import copy
from dataclasses import dataclass
from typing import List, Optional, Tuple, Dict

import numpy as np
import torch
import torchaudio
import gradio as gr

# NeMo
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils import rnnt_utils  # noqa: F401
from omegaconf import OmegaConf

# add near other audio utilities
import soundfile as sf


def load_mono16k(path: str) -> np.ndarray:
    """Load any audio file, convert to mono float32 at 16 kHz."""
    try:
        # soundfile: shape (T, C) or (T,)
        wav, sr = sf.read(path, dtype="float32", always_2d=True)
        wav = wav.mean(axis=1)  # to mono
        wav = wav.astype(np.float32, copy=False)
        return RESAMPLER.resample(wav, sr)
    except Exception:
        # fallback to torchaudio
        wav_t, sr = torchaudio.load(path)  # (C, T)
        if wav_t.dtype != torch.float32:
            wav_t = wav_t.float()
        wav = wav_t.mean(dim=0).numpy()  # to mono np.float32
        return RESAMPLER.resample(wav, int(sr))


# ----------------------------
# Config
# ----------------------------
MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
TARGET_SR = 16_000  # per model card
BEAM_SIZE = 8

# ----------------------------
# Logging
# ----------------------------
import logging

LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
logger = logging.getLogger("parakeet_app")
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
_handler = logging.StreamHandler()
_formatter = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
_handler.setFormatter(_formatter)
logger.handlers = [_handler]
logger.propagate = False


# ----------------------------
# Audio utilities
# ----------------------------
def to_mono_np(x: np.ndarray) -> np.ndarray:
    """Ensure mono float32 [-1, 1]. Accepts shape (T,) or (T, C)."""
    if x.ndim == 2:
        x = x.mean(axis=1)
    x = x.astype(np.float32, copy=False)
    return x


class ResamplerCache:
    """Cache one torchaudio Resample transform per source sample rate."""

    def __init__(self):
        self._cache: Dict[int, torchaudio.transforms.Resample] = {}

    def resample(self, wav: np.ndarray, src_sr: int) -> np.ndarray:
        if src_sr == TARGET_SR:
            return wav
        if src_sr not in self._cache:
            logger.info(f"create_resampler src_sr={src_sr} -> {TARGET_SR}")
            self._cache[src_sr] = torchaudio.transforms.Resample(orig_freq=src_sr, new_freq=TARGET_SR)
        t = torch.from_numpy(wav)
        if t.ndim == 1:
            t = t.unsqueeze(0)
        y = self._cache[src_sr](t)
        logger.debug(f"resampled {wav.shape[-1]} samples")
        return y.squeeze(0).numpy()


RESAMPLER = ResamplerCache()


# ----------------------------
# Text overlap merger
# ----------------------------
def _tokenize(s: str) -> List[str]:
    return s.strip().split()


def merge_overlap(prev_full: str, new_full: str, max_overlap_tokens: int = 30) -> str:
    """Return the incremental new tail to append to the running transcript.

    Compute the longest suffix of prev_full that matches a prefix of new_full.
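
    Example (illustrative strings, not real model output):
        prev_full = "the quick brown fox"
        new_full  = "brown fox jumps over"
        The longest suffix/prefix overlap is ["brown", "fox"] (2 tokens),
        so the returned tail is "jumps over".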
""" if not prev_full: return new_full prev_tokens = _tokenize(prev_full) new_tokens = _tokenize(new_full) max_k = min(len(prev_tokens), len(new_tokens), max_overlap_tokens) overlap = 0 for k in range(max_k, 0, -1): if prev_tokens[-k:] == new_tokens[:k]: overlap = k break tail_tokens = new_tokens[overlap:] return (" ".join(tail_tokens)).strip() # ---------------------------- # Model manager # ---------------------------- class ParakeetManager: def __init__(self, device: str = "cpu"): self.device = torch.device(device) logger.info(f"loading_model name={MODEL_NAME} device={self.device}") self.model: ASRModel = ASRModel.from_pretrained(model_name=MODEL_NAME) self.model.to(self.device) self.model.eval() for p in self.model.parameters(): p.requires_grad = False # Base decoding cfg differs by model class; handle both layouts. if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"): self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg) else: self._base_decoding = copy.deepcopy(self.model.cfg.decoding) self.set_beam_decoding() logger.info(f"model_loaded strategy=beam beam_size={BEAM_SIZE}") def set_beam_decoding(self): cfg = copy.deepcopy(self._base_decoding) cfg.strategy = "beam" cfg.beam = OmegaConf.create( {"beam_size": BEAM_SIZE, "return_best_hypothesis": True, "score_norm": True} ) # Enable label-looping if available. if not hasattr(cfg, "loop_labels") or cfg.loop_labels is None: OmegaConf.set_struct(cfg, False) cfg["loop_labels"] = True else: cfg.loop_labels = True self.model.change_decoding_strategy(cfg) logger.info("decoding_set strategy=beam loop_labels=True") def _nemo_transcribe(self, items: List) -> List: """ Call NeMo transcribe using positional first arg to support RNNT/TDT. Always request hypotheses; caller will normalize outputs. """ return self.model.transcribe(items, batch_size=1 if len(items) == 1 else 8, num_workers=0, return_hypotheses=True) def transcribe_np(self, audio_16k: np.ndarray, compute_timestamps: bool = False) -> Tuple[str, Optional[List]]: """Transcribe a single in-memory 16k waveform.""" dur = len(audio_16k) / TARGET_SR if TARGET_SR else 0.0 logger.debug(f"transcribe_np len={len(audio_16k)} dur={dur:.3f}s ts={compute_timestamps}") try: out = self._nemo_transcribe([audio_16k]) except Exception: logger.exception("transcribe_np_failed") raise first = out[0] # Hypotheses can be a list (N-best) or a string. 
        if isinstance(first, list) and first:
            first = first[0]
        if isinstance(first, str):
            return first, None
        text = getattr(first, "text", "")
        ts = getattr(first, "timestamps", None) if compute_timestamps else None
        logger.debug(f"transcribe_np_ok chars={len(text)} ts_count={len(ts) if ts else 0}")
        return text, ts

    def transcribe_files(self, paths: List[str], batch_size: int = 8, compute_timestamps: bool = False):
        logger.info(f"offline_transcribe files={len(paths)} batch={batch_size} ts={compute_timestamps}")
        if not paths:
            return []
        # load and normalize to mono 16k
        arrays: List[np.ndarray] = []
        for p in paths:
            try:
                arrays.append(load_mono16k(p))
            except Exception:
                logger.exception(f"load_mono16k_failed path={p}")
                raise
        try:
            out = self.model.transcribe(
                arrays,
                batch_size=min(batch_size, len(arrays)),
                num_workers=0,
                return_hypotheses=True,
            )
        except Exception:
            logger.exception("offline_transcribe_failed")
            raise
        results = []
        for p, o in zip(paths, out):
            h = o[0] if isinstance(o, list) and o else o
            if isinstance(h, str):
                text, stamp = h, None
            else:
                text = getattr(h, "text", "")
                stamp = getattr(h, "timestamps", None) if compute_timestamps else None
            results.append({"path": p, "text": text, "timestamp": stamp})
        logger.info("offline_transcribe_ok")
        return results


# ----------------------------
# Streaming engine
# ----------------------------
@dataclass
class StreamConfig:
    left_s: float = 10.0
    chunk_s: float = 1.0  # lower latency
    right_s: float = 0.5  # lower latency


class StreamingSession:
    def __init__(self, manager: ParakeetManager, cfg: StreamConfig):
        self.mgr = manager
        self.cfg = cfg
        self.reset()

    def reset(self):
        self.prev_text: str = ""
        self.running_text: str = ""
        self.processed_tail: np.ndarray = np.zeros(0, dtype=np.float32)
        self.pending: np.ndarray = np.zeros(0, dtype=np.float32)
        logger.info(f"stream_reset left={self.cfg.left_s} chunk={self.cfg.chunk_s} right={self.cfg.right_s}")

    def add_chunk(self, audio: np.ndarray, src_sr: int):
        mono = to_mono_np(audio)
        res = RESAMPLER.resample(mono, src_sr)
        self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
        logger.debug(f"stream_add_chunk src_sr={src_sr} pending_samples={self.pending.size}")

    def _step_windows(self) -> None:
        # Decode windows laid out as [left context | chunk | right context]; only the
        # chunk advances the transcript, the contexts just stabilize the hypothesis.
        L = int(self.cfg.left_s * TARGET_SR)
        C = int(self.cfg.chunk_s * TARGET_SR)
        R = int(self.cfg.right_s * TARGET_SR)
        steps = 0
        while self.pending.size >= C + R:
            # L == 0 means "no left context"; guard it because arr[-0:] returns the whole array.
            left_ctx = (
                self.processed_tail[-L:]
                if (L > 0 and self.processed_tail.size > 0)
                else np.zeros(0, dtype=np.float32)
            )
            head = self.pending[: C + R]
            window = np.concatenate([left_ctx, head]) if left_ctx.size else head
            try:
                new_text, _ = self.mgr.transcribe_np(window, compute_timestamps=False)
            except Exception:
                logger.exception("stream_step_transcribe_failed")
                break
            inc = merge_overlap(self.prev_text, new_text)
            if inc:
                if self.running_text and not self.running_text.endswith(" "):
                    self.running_text += " "
                self.running_text += inc
            self.prev_text = new_text
            chunk = self.pending[:C]
            self.processed_tail = np.concatenate([self.processed_tail, chunk]) if self.processed_tail.size else chunk
            # Keep at most the last L samples as future left context (none if L == 0).
            self.processed_tail = self.processed_tail[-L:] if L > 0 else np.zeros(0, dtype=np.float32)
            self.pending = self.pending[C:]
            steps += 1
        if steps:
            logger.debug(f"stream_step_windows steps={steps} remaining_pending={self.pending.size}")

    def flush(self) -> str:
        L = int(self.cfg.left_s * TARGET_SR)
        R = int(self.cfg.right_s * TARGET_SR)
        logger.info(f"stream_flush pending_samples={self.pending.size}")
        if self.pending.size == 0:
            return self.running_text
        left_ctx = (
            self.processed_tail[-L:]
            if (L > 0 and self.processed_tail.size > 0)
            else np.zeros(0, dtype=np.float32)
        )
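        # Pad the tail with R samples of silence so the final window still has a
        # right-context region, roughly mirroring the window layout of _step_windows().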
        pad = np.zeros(R, dtype=np.float32)
        window = np.concatenate([left_ctx, self.pending, pad]) if left_ctx.size else np.concatenate([self.pending, pad])
        try:
            new_text, _ = self.mgr.transcribe_np(window, compute_timestamps=False)
        except Exception:
            logger.exception("stream_flush_transcribe_failed")
            return self.running_text
        inc = merge_overlap(self.prev_text, new_text)
        if inc:
            if self.running_text and not self.running_text.endswith(" "):
                self.running_text += " "
            self.running_text += inc
        self.prev_text = new_text
        self.pending = np.zeros(0, dtype=np.float32)
        self.processed_tail = np.zeros(0, dtype=np.float32)
        logger.info("stream_flush_done")
        return self.running_text


# ----------------------------
# Gradio UI callbacks
# ----------------------------
MANAGER = ParakeetManager(device="cpu")


def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
    """Accept Gradio 5 Audio stream payload.

    Returns (waveform np.float32, samplerate int).
    """
    if x is None:
        return np.zeros(0, dtype=np.float32), TARGET_SR
    # tuple (sr, array)
    if isinstance(x, tuple) and len(x) == 2:
        sr = int(x[0])
        arr = np.array(x[1], dtype=np.float32)
        return arr, sr
    # dict {"sampling_rate": int, "data": np.array}
    if isinstance(x, dict) and "data" in x and "sampling_rate" in x:
        arr = np.array(x["data"], dtype=np.float32)
        sr = int(x["sampling_rate"])
        return arr, sr
    # raw numpy at known sr
    if isinstance(x, np.ndarray):
        return x.astype(np.float32, copy=False), TARGET_SR
    logger.error(f"unsupported_gr_audio_payload type={type(x)}")
    raise ValueError("Unsupported audio payload")


def streaming_reset(left, chunk, right):
    logger.info(f"ui_streaming_reset left={left} chunk={chunk} right={right}")
    sess = StreamingSession(MANAGER, StreamConfig(left_s=left, chunk_s=chunk, right_s=right))
    return sess, ""


def streaming_step(audio_chunk, sess: Optional[StreamingSession]):
    if sess is None:
        return None, ""
    try:
        wav, sr = _parse_gr_audio(audio_chunk)
    except Exception:
        logger.exception("ui_streaming_step_parse_failed")
        return sess, sess.running_text
    if wav.size:
        sess.add_chunk(wav, sr)
        sess._step_windows()
    return sess, sess.running_text


def streaming_flush(sess: Optional[StreamingSession]):
    if sess is None:
        return None, ""
    text = sess.flush()
    return None, text


# Offline / batch callback
def offline_run(files, batch_size: int, want_ts: bool):
    logger.info(f"ui_offline_run click files={0 if files is None else len(files)} batch={batch_size} ts={want_ts}")
    if not files:
        return []
    # Gradio 5 File with type="filepath" returns list[str] of paths.
    paths: List[str] = []
    for f in files:
        if isinstance(f, str):
            paths.append(f)
        elif hasattr(f, "name"):
            paths.append(f.name)
    try:
        results = MANAGER.transcribe_files(paths, batch_size=batch_size, compute_timestamps=want_ts)
    except Exception:
        logger.exception("ui_offline_run_failed")
        raise
    # Build table rows [file, text, timestamps]
    table = []
    for r in results:
        row = [os.path.basename(r["path"]), r["text"], r.get("timestamp")]
        table.append(row)
    logger.info("ui_offline_run_ok")
    return table


# ----------------------------
# Build Gradio Interface
# ----------------------------
with gr.Blocks(title="Parakeet-TDT v3: Streaming (Mic) + Offline (File)") as demo:
    gr.Markdown(
        """
# FINALLY, A SIMPLE EXPLANATION OF THE NVIDIA NEMO TECHNICALS!
This is the CHAD version (a smol one, not a GIGA one) of the idea of local, real-time transcription on cheap hardware, for example on a Raspberry Pi, so you can say "do this, do that" - and easily wire up another SMOL LLM to take an action in real time when you say "open the door".

# Beam ASR = Google Maps for Speech - MAPPING THIS APP'S FEATURES TO THE NVIDIA TECHNICAL TERMS

NOTE: this app is not using the lame naive chunking of audio like the others... It is fully compatible with the modern architecture of the Parakeet-TDT-v3 model, it's no joke anymore. Soon, cache-aware streaming will be implemented - another gamechanger. Anyway, THE WORK IS ONGOING ON THIS APP; FOR TODAY, HERE YOU HAVE THE SMOL CHAD COMMIT.

**One route vs. many**
- Greedy: pick the first route and drive.
- Beam: keep several good routes, update with traffic, follow the best.

**Beam size**
- How many alternate routes you watch at once.

**Label-looping**
- Make back-to-back turns at the same intersection when the signs are clear (e.g., "right, then immediate merge"). Faster, fewer stutters.
"""
    )

    with gr.Tab("Streaming"):
        with gr.Row():
            left = gr.Slider(0.0, 20.0, value=10.0, step=0.5, label="Left context (s)")
            chunk = gr.Slider(0.2, 5.0, value=1.0, step=0.2, label="Chunk size (s)")
            right = gr.Slider(0.0, 5.0, value=0.5, step=0.1, label="Right context (s)")
        mic = gr.Audio(
            sources=["microphone"], type="numpy", streaming=True, label="Speak here"
        )
        text_out = gr.Textbox(label="Transcript", lines=8)
        with gr.Row():
            reset_btn = gr.Button("Reset")
            flush_btn = gr.Button("Flush")
        sess_state = gr.State()
        reset_btn.click(
            streaming_reset,
            inputs=[left, chunk, right],
            outputs=[sess_state, text_out],
        )
        mic.stream(
            streaming_step,
            inputs=[mic, sess_state],
            outputs=[sess_state, text_out],
        )
        flush_btn.click(
            streaming_flush,
            inputs=[sess_state],
            outputs=[sess_state, text_out],
        )

    with gr.Tab("Offline"):
        files = gr.File(file_count="multiple", type="filepath", label="Upload audio files")
        batch_size = gr.Slider(1, 32, value=8, step=1, label="Batch size")
        want_ts = gr.Checkbox(label="Compute timestamps", value=False)
        run_btn = gr.Button("Run")
        results_table = gr.Dataframe(
            headers=["file", "text", "timestamps"],
            label="Results",
            row_count=(0, "dynamic"),
            col_count=(3, "fixed"),
        )
        run_btn.click(
            offline_run,
            inputs=[files, batch_size, want_ts],
            outputs=[results_table],
        )

demo.queue().launch(ssr_mode=False)
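
# Usage note (a sketch, not executed by the app): run this file directly, e.g.
#   LOG_LEVEL=DEBUG python parakeet_app.py   # "parakeet_app.py" is whatever filename you saved this as
# Startup loads the Parakeet-TDT model on CPU (MANAGER above; weights download on first use),
# then launch() blocks and serves the Gradio UI with the Streaming and Offline tabs.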