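# NOTE: the lines below are page residue from the file viewer this source was
# copied from; they are not part of the program.
# Flux9665's picture
# Create app.py
# 014393d verified
# raw
# history blame
# 3.63 kB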
import json
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from ArticulatoryTextFrontend import ArticulatoryTextFrontend
# Labels for the rows (y-axis) of the heatmap, one per feature dimension.
_FEATURE_ROW_LABELS = [
    "stressed", "very-high-tone", "high-tone", "mid-tone", "low-tone", "very-low-tone", "rising-tone", "falling-tone",
    "peaking-tone", "dipping-tone", "lengthened", "half-length", "shortened", "consonant", "vowel", "phoneme",
    "silence", "end of sentence", "questionmark", "exclamationmark", "fullstop", "word-boundary", "dental", "postalveolar",
    "velar", "palatal", "glottal", "uvular", "labiodental", "labial-velar", "alveolar", "bilabial",
    "alveolopalatal", "retroflex", "pharyngal", "epiglottal", "central", "back", "front_central", "front",
    "central_back", "mid", "close-mid", "close", "open-mid", "close_close-mid", "open-mid_open", "open",
    "rounded", "unrounded", "plosive", "nasal", "approximant", "trill", "flap", "fricative",
    "lateral-approximant", "implosive", "vibrant", "click", "ejective", "aspirated", "unvoiced", "voiced",
]


def visualize_one_hot_encoded_sequence(tensor, sentence, col_labels, cmap='BuGn'):
    """
    Visualize a 2D one-hot encoded tensor as a heatmap.

    The values are clamped to the range [0, 1] and the tensor is transposed
    before plotting, so dimension 0 of the input runs along the x-axis
    ("Phones") and dimension 1 runs along the y-axis ("Features").

    Args:
        tensor: 2D torch tensor. Its second dimension must have exactly one
            entry per feature row label.
        sentence: Text shown in the plot title, wrapped in »...«.
        col_labels: Sequence with one tick label per column of the heatmap.
            If empty or None, no x tick labels are set.
        cmap: Name of the matplotlib colormap used for the heatmap.

    Returns:
        The matplotlib figure containing the heatmap. It is already closed
        in pyplot, but can still be rendered or saved by the caller.

    Raises:
        ValueError: If the tensor is not 2D, or if the number of row or
            column labels does not match the plotted matrix.
    """
    clamped = torch.clamp(tensor, min=0, max=1)
    # Validate the rank before transposing: transpose(0, 1) on a tensor with
    # fewer than two dimensions would fail with an unrelated IndexError.
    if clamped.ndim != 2:
        raise ValueError("Input tensor must be a 2D array")
    # detach() so tensors that still track gradients can be converted as well.
    matrix = clamped.transpose(0, 1).detach().cpu().numpy()

    # Check the size of labels matches the tensor dimensions
    n_rows, n_cols = matrix.shape
    if len(_FEATURE_ROW_LABELS) != n_rows:
        raise ValueError("Number of row labels must match the number of rows in the tensor")
    if col_labels and len(col_labels) != n_cols:
        raise ValueError("Number of column labels must match the number of columns in the tensor")

    fig, ax = plt.subplots(figsize=(16, 16))
    # Create the heatmap
    ax.imshow(matrix, cmap=cmap, aspect='auto')

    # Add labels
    ax.set_yticks(np.arange(n_rows), _FEATURE_ROW_LABELS)
    if col_labels:
        ax.set_xticks(np.arange(n_cols), col_labels, rotation=0)
    ax.grid(False)
    ax.set_xlabel('Phones')
    ax.set_ylabel('Features')
    ax.set_title(f"»{sentence}«")

    # pyplot keeps every figure it creates alive until it is closed. Closing
    # it here prevents one figure per call from piling up in a long-running
    # process; the returned Figure object itself stays usable.
    plt.close(fig)
    return fig
def vis_wrapper(sentence, language):
    """
    Build the feature heatmap for a sentence in the selected language.

    Args:
        sentence: The text to visualize.
        language: Selection string whose last space-separated token holds the
            language code in parentheses, e.g. "English (eng)".

    Returns:
        The figure produced by visualize_one_hot_encoded_sequence.
    """
    # Pull the code out of the trailing "(code)" token of the selection.
    last_token = language.split(" ")[-1]
    after_open_paren = last_token.split("(")[1]
    iso_code = after_open_paren.split(")")[0]

    frontend = ArticulatoryTextFrontend(language=iso_code)
    feature_tensor = frontend.string_to_tensor(sentence)
    phone_string = frontend.get_phone_string(sentence)
    # The phone string is used as column labels, so it presumably has one
    # entry per position of the feature tensor - the plotting function
    # raises a ValueError if the counts do not match.
    return visualize_one_hot_encoded_sequence(
        tensor=feature_tensor,
        sentence=sentence,
        col_labels=phone_string,
    )
def load_json_from_path(path):
    """
    Read the UTF-8 encoded JSON file at the given path and return the parsed object.
    """
    with open(path, "r", encoding="utf8") as json_file:
        return json.load(json_file)
# Mapping from ISO language code to the full language name, loaded from a
# JSON file given as a relative path.
iso_to_name = load_json_from_path("iso_to_fullname.json")

# Dropdown entries of the form "<full name> (<iso code>)"; vis_wrapper reads
# the code back out of the parentheses.
text_selection = [f"{full_name} ({iso_code})" for iso_code, full_name in iso_to_name.items()]

# Free-text input for the sentence to visualize.
sentence_box = gr.Textbox(
    lines=2,
    placeholder="write the sentence you want to visualize here...",
    value="What I cannot create, I do not understand.",
    label="Text input",
)

# Language selection; the chosen entry is passed to vis_wrapper as a string.
language_picker = gr.Dropdown(
    text_selection,
    type="value",
    value='English (eng)',
    label="Select the Language of the Text (type on your keyboard to find it quickly)",
)

iface = gr.Interface(
    fn=vis_wrapper,
    inputs=[sentence_box, language_picker],
    outputs=[gr.Plot()],
    allow_flagging="never",
    live=False,
    fill_width=True,
)

# Start the web app as soon as the script is executed.
iface.launch()