File size: 3,629 Bytes
014393d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import json

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch

from ArticulatoryTextFrontend import ArticulatoryTextFrontend


# One label per articulatory feature dimension, in the order the heatmap rows are drawn.
# The feature dimension of the tensor passed to visualize_one_hot_encoded_sequence must
# have exactly this many entries.
ARTICULATORY_FEATURE_LABELS = (
    "stressed", "very-high-tone", "high-tone", "mid-tone", "low-tone", "very-low-tone", "rising-tone", "falling-tone",
    "peaking-tone", "dipping-tone", "lengthened", "half-length", "shortened", "consonant", "vowel", "phoneme", "silence",
    "end of sentence", "questionmark", "exclamationmark", "fullstop", "word-boundary", "dental", "postalveolar",
    "velar", "palatal", "glottal", "uvular", "labiodental", "labial-velar", "alveolar", "bilabial", "alveolopalatal",
    "retroflex", "pharyngal", "epiglottal", "central", "back", "front_central", "front", "central_back", "mid",
    "close-mid", "close", "open-mid", "close_close-mid", "open-mid_open", "open", "rounded", "unrounded", "plosive",
    "nasal", "approximant", "trill", "flap", "fricative", "lateral-approximant", "implosive", "vibrant", "click",
    "ejective", "aspirated", "unvoiced", "voiced",
)


def visualize_one_hot_encoded_sequence(tensor, sentence, col_labels, cmap='BuGn'):
    """
    Visualize a 2D one-hot encoded feature tensor as a heatmap.

    Args:
        tensor: 2D torch tensor laid out as (phones, features). It is transposed so that
            features become rows and phones become columns; values are clamped to [0, 1].
        sentence: Text shown (in guillemets) as the plot title.
        col_labels: Per-phone tick labels (a string yields one label per character), or
            None / empty to omit x tick labels.
        cmap: Matplotlib colormap name used for the heatmap.

    Returns:
        A matplotlib ``Figure`` containing the heatmap.

    Raises:
        ValueError: If ``tensor`` is not 2D, or if the row/column label counts do not
            match the tensor's dimensions.
    """
    # Validate the rank before transposing: transpose(0, 1) itself fails on 0D/1D input.
    if tensor.ndim != 2:
        raise ValueError("Input tensor must be a 2D array")

    # detach() so .numpy() also works for tensors that track gradients.
    matrix = torch.clamp(tensor.detach(), min=0, max=1).transpose(0, 1).cpu().numpy()

    row_labels = list(ARTICULATORY_FEATURE_LABELS)
    if len(row_labels) != matrix.shape[0]:
        raise ValueError("Number of row labels must match the number of rows in the tensor")

    # list() turns a phone string into one tick label per character and avoids the
    # ambiguous truth value of array-like labels.
    col_labels = list(col_labels) if col_labels is not None else []
    if col_labels and len(col_labels) != matrix.shape[1]:
        raise ValueError("Number of column labels must match the number of columns in the tensor")

    # Build the Figure directly rather than through pyplot: pyplot keeps a reference to
    # every figure it creates until it is closed, which leaks memory in a web server
    # that renders a new plot per request.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(16, 16))
    ax = fig.subplots()

    ax.imshow(matrix, cmap=cmap, aspect='auto')

    ax.set_yticks(np.arange(matrix.shape[0]), row_labels)
    if col_labels:
        ax.set_xticks(np.arange(matrix.shape[1]), col_labels, rotation=0)

    ax.grid(False)
    ax.set_xlabel('Phones')
    ax.set_ylabel('Features')
    ax.set_title(f"»{sentence}«")
    return fig


def vis_wrapper(sentence, language):
    """
    Gradio callback: render the articulatory-feature heatmap for ``sentence``.

    ``language`` is expected to look like "<Full Name> (<iso_code>)"; the code inside
    the parentheses of its last space-separated token selects the text frontend.
    """
    last_token = language.split(" ")[-1]      # e.g. "(eng)"
    after_paren = last_token.split("(")[1]    # e.g. "eng)"
    iso_code = after_paren.split(")")[0]      # e.g. "eng"

    frontend = ArticulatoryTextFrontend(language=iso_code)
    feature_tensor = frontend.string_to_tensor(sentence)
    phone_string = frontend.get_phone_string(sentence)

    return visualize_one_hot_encoded_sequence(
        tensor=feature_tensor,
        sentence=sentence,
        col_labels=phone_string,
    )


def load_json_from_path(path):
    """Read the UTF-8 encoded JSON file at ``path`` and return the decoded object."""
    with open(path, encoding="utf8") as json_file:
        return json.load(json_file)


# Keys are used as language codes and values as display names for the dropdown.
iso_to_name = load_json_from_path("iso_to_fullname.json")
# Dropdown entries have the form "<Full Name> (<code>)", e.g. "English (eng)".
text_selection = [f"{full_name} ({code})" for code, full_name in iso_to_name.items()]

sentence_input = gr.Textbox(lines=2,
                            placeholder="write the sentence you want to visualize here...",
                            value="What I cannot create, I do not understand.",
                            label="Text input")
language_input = gr.Dropdown(text_selection,
                             type="value",
                             value='English (eng)',
                             label="Select the Language of the Text (type on your keyboard to find it quickly)")

iface = gr.Interface(fn=vis_wrapper,
                     inputs=[sentence_input, language_input],
                     outputs=[gr.Plot()],
                     allow_flagging="never",
                     live=False,
                     fill_width=True)
iface.launch()