Upload folder using huggingface_hub
- app.py +93 -19
- entailment.py +1 -1
- highlighter.py +33 -42
- lcs.py +3 -3
- masking_methods.py +84 -12
- paraphraser.py +1 -1
- sampling_methods.py +31 -139
- tree.py +90 -47
app.py
CHANGED
@@ -6,7 +6,6 @@ import plotly.graph_objs as go
 import textwrap
 from transformers import pipeline
 import re
-import time
 import requests
 from PIL import Image
 import itertools
@@ -20,10 +19,7 @@ import pandas as pd
 from pprint import pprint
 from tenacity import retry
 from tqdm import tqdm
-import scipy.stats
-import torch
 from transformers import GPT2LMHeadModel
-import seaborn as sns
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM
 import random
 from nltk.corpus import stopwords
@@ -31,22 +27,92 @@ from termcolor import colored
 from nltk.translate.bleu_score import sentence_bleu
 from transformers import BertTokenizer, BertModel
 import gradio as gr
-from tree import
+from tree import generate_subplot
 from paraphraser import generate_paraphrase
 from lcs import find_common_subsequences
 from highlighter import highlight_common_words, highlight_common_words_dict
 from entailment import analyze_entailment
+from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
+from sampling_methods import sample_word
+
 
 # Function for the Gradio interface
 def model(prompt):
-    paraphrased_sentences = generate_paraphrase(
-    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(
+    user_prompt = prompt
+    paraphrased_sentences = generate_paraphrase(user_prompt)
+    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)
+    length_accepted_sentences = len(selected_sentences)
+    common_grams = find_common_subsequences(user_prompt, selected_sentences)
+
+    masked_sentences = []
+    masked_words = []
+    masked_logits = []
+    selected_sentences_list = list(selected_sentences.keys())
+
+    for sentence in selected_sentences_list:
+        # Mask non-stopword
+        masked_sent, logits, words = mask_non_stopword(sentence)
+        masked_sentences.append(masked_sent)
+        masked_words.append(words)
+        masked_logits.append(logits)
+
+        # Mask non-stopword pseudorandom
+        masked_sent, logits, words = mask_non_stopword_pseudorandom(sentence)
+        masked_sentences.append(masked_sent)
+        masked_words.append(words)
+        masked_logits.append(logits)
+
+        # High entropy words
+        masked_sent, logits, words = high_entropy_words(sentence, common_grams)
+        masked_sentences.append(masked_sent)
+        masked_words.append(words)
+        masked_logits.append(logits)
+
+    sampled_sentences = []
+    for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='inverse_transform', temperature=1.0))
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='exponential_minimum', temperature=1.0))
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='temperature', temperature=1.0))
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='greedy', temperature=1.0))
+
+    # Predefined set of colors that are visible on a white background, excluding black
+    colors = ["red", "blue", "brown", "green"]
+
+    # Function to generate color from predefined set
+    def select_color():
+        return random.choice(colors)
+
+    # Create highlight_info with selected colors
+    highlight_info = [(word, select_color()) for _, word in common_grams]
+
+    highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "User Prompt (Highlighted and Numbered)")
+    highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
+    highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
+
+    # Initialize empty list to hold the trees
+    trees = []
+
+    # Initialize the indices for masked and sampled sentences
+    masked_index = 0
+    sampled_index = 0
+
+    for i, sentence in enumerate(selected_sentences):
+        # Generate the sublists of masked and sampled sentences based on current indices
+        next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
+        next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]
+
+        # Create the tree for the current sentence
+        tree = generate_subplot(sentence, next_masked_sentences, next_sampled_sentences, highlight_info)
+        trees.append(tree)
+
+        # Update the indices for the next iteration
+        masked_index += 3
+        sampled_index += 12
+
+    # Return all the outputs together
+    return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees
 
 with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
@@ -63,15 +129,23 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
         highlighted_user_prompt = gr.HTML()
 
     with gr.Row():
+        with gr.Tabs():
+            with gr.TabItem("Paraphrased Sentences"):
+                highlighted_accepted_sentences = gr.HTML()
+            with gr.TabItem("Discarded Sentences"):
+                highlighted_discarded_sentences = gr.HTML()
 
     with gr.Row():
+        with gr.Tabs():
+            tree_tabs = []
+            for i in range(3):  # Adjust this range according to the number of trees
+                with gr.TabItem(f"Tree {i+1}"):
+                    tree = gr.Plot()
+                    tree_tabs.append(tree)
 
-    submit_button.click(model, inputs=user_input, outputs=[highlighted_user_prompt,
+    submit_button.click(model, inputs=user_input, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree_tabs)
     clear_button.click(lambda: "", inputs=None, outputs=user_input)
-    clear_button.click(lambda: "", inputs=None, outputs=[highlighted_user_prompt,
+    clear_button.click(lambda: "", inputs=None, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree_tabs)
 
 # Launch the demo
 demo.launch(share=True)
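
For context, a minimal sketch (not part of the commit) of exercising the new model function directly, assuming the sibling modules import cleanly and the underlying models have downloaded; it simply unpacks the list the function returns: three highlighted-HTML blocks followed by one Plotly tree per selected paraphrase.

# Hypothetical smoke test for the updated pipeline (illustration only).
outputs = model("Kim Beom-su, the billionaire behind Kakao, was taken into custody.")
html_blocks, tree_figs = outputs[:3], outputs[3:]
print(len(html_blocks), "HTML blocks and", len(tree_figs), "tree figures")
# The Gradio layout above exposes exactly three gr.Plot tabs, so the prompt is
# expected to yield three selected paraphrases for the outputs to line up.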
entailment.py
CHANGED
@@ -28,4 +28,4 @@ def analyze_entailment(original_sentence, paraphrased_sentences, threshold):
 
     return all_sentences, selected_sentences, discarded_sentences
 
-print(analyze_entailment("I love you", ["You're being loved by me"], 0.7))
+# print(analyze_entailment("I love you", ["You're being loved by me"], 0.7))
highlighter.py
CHANGED
@@ -39,57 +39,48 @@ def highlight_common_words(common_words, sentences, title):
     '''
 
 
+
 import re
 
-def highlight_common_words_dict(common_words,
+def highlight_common_words_dict(common_words, sentences, title):
     color_map = {}
     color_index = 0
     highlighted_html = []
 
-    highlighted_sentences = [f'<h4 style="color: #374151; margin-bottom: 5px;">{section_title}</h4>']
-    for
+    for idx, (sentence, score) in enumerate(sentences.items(), start=1):
+        sentence_with_idx = f"{idx}. {sentence}"
+        highlighted_sentence = sentence_with_idx
+
+        for index, word in common_words:
+            if word not in color_map:
+                color_map[word] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
+                color_index += 1
+            escaped_word = re.escape(word)
+            pattern = rf'\b{escaped_word}\b'
+            highlighted_sentence = re.sub(
+                pattern,
+                lambda m, idx=index, color=color_map[word]: (
+                    f'<span style="background-color: {color}; font-weight: bold;'
+                    f' padding: 1px 2px; border-radius: 2px; position: relative;">'
+                    f'<span style="background-color: black; color: white; border-radius: 50%;'
+                    f' padding: 1px 3px; margin-right: 3px; font-size: 0.8em;">{idx}</span>'
+                    f'{m.group(0)}'
                     f'</span>'
                 ),
                 highlighted_sentence,
                 flags=re.IGNORECASE
             )
-        highlighted_sentences.append(
+        highlighted_html.append(
             f'<div style="margin-bottom: 5px;">'
             f'{highlighted_sentence}'
-            f'<div style="display: inline-block; margin-left: 5px; border: 1px solid #ddd; padding: 3px 5px; border-radius: 3px; background-color: white; font-size: 0.9em;">'
+            f'<div style="display: inline-block; margin-left: 5px; padding: 3px 5px; border-radius: 3px; background-color: white; font-size: 0.9em;">'
             f'Entailment Score: {score}</div></div>'
         )
 
-    final_html = "<br>".join(
+    final_html = "<br>".join(highlighted_html)
     return f'''
-    <div style="
-    <h3 style="margin-top: 0; font-size: 1em; color: #111827;
+    <div style="background-color: #ffffff; color: #374151;">
+    <h3 style="margin-top: 0; font-size: 1em; color: #111827;">{title}</h3>
     <div style="background-color: #F5F5F5; line-height: 1.6; padding: 15px; border-radius: 8px;">{final_html}</div>
     </div>
     '''
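
An illustrative call (not from the commit), with the argument shapes inferred from how app.py uses the function: common_words is a list of (index, phrase) pairs from lcs.py and sentences is a dict mapping each paraphrase to its entailment score; the result is an HTML string.

# Hypothetical inputs shaped like the ones app.py passes in.
common_grams = [(1, "technology giant"), (2, "bidding war")]
scored_sentences = {"Kakao's founder was arrested amid a bidding war.": 0.91}
html = highlight_common_words_dict(common_grams, scored_sentences, "Paraphrased Sentences")
print(html[:120])  # a <div> with numbered, color-coded highlights and the entailment score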
lcs.py
CHANGED
@@ -40,7 +40,7 @@ def find_common_subsequences(sentence, str_list):
     return indexed_common_grams
 
 # Example usage
-sentence = "Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."
-str_list = ["The founder of South Korean technology company Kakao, billionaire Kim Beom-su, was arrested on charges of stock fraud during a bidding war for one of North Korea's biggest K-pop companies.", "In a bidding war for one of South Korea's largest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "During a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "Kim Beom-su, the founder of South Korean technology giant Kakao's billionaire investor status, was arrested on charges of stock fraud during a bidding war for one of North Korea'S top K-pop agencies.", "A bidding war over one of South Korea's biggest K-pop agencies led to the arrest and apprehension charges of Kim Beom-Su, the billionaire who owns the technology giant Kakao.", "The billionaire who owns South Korean technology giant Kakao, Kim Beom-Su, was taken into custody for allegedly engaging in stock trading during a bidding war for one of North Korea's biggest K-pop media groups.", "Accused of stockpiling during a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-Su, the founder and owner of technology firm known as Kakao, was arrested on charges of manipulating stocks.", 'Kakao, the South Korean technology giant, was involved in a bidding war with Kim Beon-su, its founder, who was arrested on charges of manipulating stocks.', "South Korea's Kakao corporation'entrepreneur husband, Kim Beom-su (pictured), was arrested on suspicion of stock fraud during a bidding war for one of the country'S top K-pop companies.", 'Kim Beom-su, the billionaire who own a South Korean technology company called Kakaof, was arrested on charges of manipulating stocks in an ongoing bidding war over one million shares.']
-print(find_common_subsequences(sentence, str_list))
+# sentence = "Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."
+# str_list = ["The founder of South Korean technology company Kakao, billionaire Kim Beom-su, was arrested on charges of stock fraud during a bidding war for one of North Korea's biggest K-pop companies.", "In a bidding war for one of South Korea's largest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "During a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "Kim Beom-su, the founder of South Korean technology giant Kakao's billionaire investor status, was arrested on charges of stock fraud during a bidding war for one of North Korea'S top K-pop agencies.", "A bidding war over one of South Korea's biggest K-pop agencies led to the arrest and apprehension charges of Kim Beom-Su, the billionaire who owns the technology giant Kakao.", "The billionaire who owns South Korean technology giant Kakao, Kim Beom-Su, was taken into custody for allegedly engaging in stock trading during a bidding war for one of North Korea's biggest K-pop media groups.", "Accused of stockpiling during a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-Su, the founder and owner of technology firm known as Kakao, was arrested on charges of manipulating stocks.", 'Kakao, the South Korean technology giant, was involved in a bidding war with Kim Beon-su, its founder, who was arrested on charges of manipulating stocks.', "South Korea's Kakao corporation'entrepreneur husband, Kim Beom-su (pictured), was arrested on suspicion of stock fraud during a bidding war for one of the country'S top K-pop companies.", 'Kim Beom-su, the billionaire who own a South Korean technology company called Kakaof, was arrested on charges of manipulating stocks in an ongoing bidding war over one million shares.']
+# print(find_common_subsequences(sentence, str_list))
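
The commented-out block above keeps the original example inputs; a smaller hypothetical call, with the return shape inferred from how highlighter.py and app.py consume the result (a numbered list of common n-grams):

# Hypothetical small example (illustration only).
grams = find_common_subsequences(
    "the cat sat on the mat",
    ["a cat sat on a mat", "the cat rested on the mat"],
)
print(grams)  # expected to look like [(1, 'cat sat'), (2, 'mat'), ...]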
masking_methods.py
CHANGED
@@ -1,3 +1,66 @@
-import
-            sentence = sentence.replace(word_to_mark, colored(word_to_mark, 'red'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
+# from transformers import AutoTokenizer, AutoModelForMaskedLM
+# from transformers import pipeline
+# import random
+# from nltk.corpus import stopwords
+# import math
+
+# # Masking Model
+# def mask_non_stopword(sentence):
+#     stop_words = set(stopwords.words('english'))
+#     words = sentence.split()
+#     non_stop_words = [word for word in words if word.lower() not in stop_words]
+#     if not non_stop_words:
+#         return sentence
+#     word_to_mask = random.choice(non_stop_words)
+#     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
+#     return masked_sentence
+
+# def mask_non_stopword_pseudorandom(sentence):
+#     stop_words = set(stopwords.words('english'))
+#     words = sentence.split()
+#     non_stop_words = [word for word in words if word.lower() not in stop_words]
+#     if not non_stop_words:
+#         return sentence
+#     random.seed(10)
+#     word_to_mask = random.choice(non_stop_words)
+#     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
+#     return masked_sentence
+
+# def high_entropy_words(sentence, non_melting_points):
+#     stop_words = set(stopwords.words('english'))
+#     words = sentence.split()
+
+#     non_melting_words = set()
+#     for _, point in non_melting_points:
+#         non_melting_words.update(point.lower().split())
+
+#     candidate_words = [word for word in words if word.lower() not in stop_words and word.lower() not in non_melting_words]
+
+#     if not candidate_words:
+#         return sentence
+
+#     max_entropy = -float('inf')
+#     max_entropy_word = None
+
+#     for word in candidate_words:
+#         masked_sentence = sentence.replace(word, '[MASK]', 1)
+#         predictions = fill_mask(masked_sentence)
+
+#         # Calculate entropy based on top 5 predictions
+#         entropy = -sum(pred['score'] * math.log(pred['score']) for pred in predictions[:5])
+
+#         if entropy > max_entropy:
+#             max_entropy = entropy
+#             max_entropy_word = word
+
+#     return sentence.replace(max_entropy_word, '[MASK]', 1)
+
+
+# # Load tokenizer and model for masked language model
+# tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
+# model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
+# fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
+
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 from transformers import pipeline
 import random
@@ -10,21 +73,27 @@ def mask_non_stopword(sentence):
     words = sentence.split()
     non_stop_words = [word for word in words if word.lower() not in stop_words]
     if not non_stop_words:
-        return sentence
+        return sentence, None, None
     word_to_mask = random.choice(non_stop_words)
     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
-    return masked_sentence
+    predictions = fill_mask(masked_sentence)
+    words = [pred['score'] for pred in predictions]
+    logits = [pred['token_str'] for pred in predictions]
+    return masked_sentence, words, logits
 
 def mask_non_stopword_pseudorandom(sentence):
     stop_words = set(stopwords.words('english'))
     words = sentence.split()
     non_stop_words = [word for word in words if word.lower() not in stop_words]
     if not non_stop_words:
-        return sentence
+        return sentence, None, None
     random.seed(10)
     word_to_mask = random.choice(non_stop_words)
     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
-    return masked_sentence
+    predictions = fill_mask(masked_sentence)
+    words = [pred['score'] for pred in predictions]
+    logits = [pred['token_str'] for pred in predictions]
+    return masked_sentence, words, logits
 
 def high_entropy_words(sentence, non_melting_points):
     stop_words = set(stopwords.words('english'))
@@ -37,10 +106,11 @@ def high_entropy_words(sentence, non_melting_points):
     candidate_words = [word for word in words if word.lower() not in stop_words and word.lower() not in non_melting_words]
 
     if not candidate_words:
-        return sentence
+        return sentence, None, None
 
     max_entropy = -float('inf')
     max_entropy_word = None
+    max_logits = None
 
     for word in candidate_words:
         masked_sentence = sentence.replace(word, '[MASK]', 1)
@@ -52,17 +122,19 @@ def high_entropy_words(sentence, non_melting_points):
         if entropy > max_entropy:
             max_entropy = entropy
             max_entropy_word = word
+            max_logits = [pred['score'] for pred in predictions]
 
-    return sentence.replace(max_entropy_word, '[MASK]', 1)
+    masked_sentence = sentence.replace(max_entropy_word, '[MASK]', 1)
+    words = [pred['score'] for pred in predictions]
+    logits = [pred['token_str'] for pred in predictions]
+    return masked_sentence, words, logits
 
 # Load tokenizer and model for masked language model
 tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
 model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
 fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
 
+non_melting_points = [(1, 'Jewish'), (2, 'messages'), (3, 'stab')]
+a, b, c = high_entropy_words("A former Cornell University student was sentenced to 21 months in prison on Monday after admitting that he had posted a series of online messages last fall in which he threatened to stab, rape and behead Jewish people", non_melting_points)
+print(f"logits type: {type(b)}")
+print(f"logits content: {b}")
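
A usage sketch of the new three-value return contract (illustration only, assuming the BERT fill-mask pipeline above has finished loading): each masking function now returns the masked sentence together with the fill-mask prediction scores and the predicted token strings, rather than the sentence alone.

# Illustrative call mirroring the trailing test above.
masked_sentence, scores, tokens = mask_non_stopword("The quick brown fox jumps over the lazy dog")
print(masked_sentence)                 # one non-stopword replaced by [MASK]
print(list(zip(tokens, scores))[:3])   # top fill-mask candidates with their scores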
paraphraser.py
CHANGED
@@ -28,4 +28,4 @@ def generate_paraphrase(question):
     res = paraphrase(question, para_tokenizer, para_model)
     return res
 
-print(generate_paraphrase("Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."))
+# print(generate_paraphrase("Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."))
sampling_methods.py
CHANGED
@@ -1,145 +1,33 @@
-import
-            sentence = sentence.replace(word_to_mark, colored(word_to_mark, 'red'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-# Function for Inverse Transform Sampling
-def inverse_transform_sampling(original_sentence, paraphrased_sentences):
-    stop_words = set(stopwords.words('english'))
-    original_sentence_lower = original_sentence.lower()
-    paraphrased_sentences_lower = [s.lower() for s in paraphrased_sentences]
-    paraphrased_sentences_no_stopwords = []
-
-    for sentence in paraphrased_sentences_lower:
-        words = re.findall(r'\b\w+\b', sentence)
-        filtered_sentence = ' '.join([word for word in words if word not in stop_words])
-        paraphrased_sentences_no_stopwords.append(filtered_sentence)
-
-    results = []
-    for idx, sentence in enumerate(paraphrased_sentences_no_stopwords):
-        common_words = set(original_sentence_lower.split()) & set(sentence.split())
-        common_substrings = ', '.join(sorted(common_words))
-
-        words_to_replace = [word for word in sentence.split() if word not in common_words]
-        if words_to_replace:
-            probabilities = [1 / len(words_to_replace)] * len(words_to_replace)
-            chosen_word = random.choices(words_to_replace, weights=probabilities)[0]
-            sentence = sentence.replace(chosen_word, colored(chosen_word, 'magenta'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-# Function for Contextual Sampling
-def contextual_sampling(original_sentence, paraphrased_sentences):
-    stop_words = set(stopwords.words('english'))
-    original_sentence_lower = original_sentence.lower()
-    paraphrased_sentences_lower = [s.lower() for s in paraphrased_sentences]
-    paraphrased_sentences_no_stopwords = []
-
-    for sentence in paraphrased_sentences_lower:
-        words = re.findall(r'\b\w+\b', sentence)
-        filtered_sentence = ' '.join([word for word in words if word not in stop_words])
-        paraphrased_sentences_no_stopwords.append(filtered_sentence)
-
-    results = []
-    for idx, sentence in enumerate(paraphrased_sentences_no_stopwords):
-        common_words = set(original_sentence_lower.split()) & set(sentence.split())
-        common_substrings = ', '.join(sorted(common_words))
-
-        words_to_replace = [word for word in sentence.split() if word not in common_words]
-        if words_to_replace:
-            context = " ".join([word for word in sentence.split() if word not in common_words])
-            chosen_word = random.choice(words_to_replace)
-            sentence = sentence.replace(chosen_word, colored(chosen_word, 'red'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-# Function for Exponential Minimum Sampling
-def exponential_minimum_sampling(original_sentence, paraphrased_sentences):
-    stop_words = set(stopwords.words('english'))
-    original_sentence_lower = original_sentence.lower()
-    paraphrased_sentences_lower = [s.lower() for s in paraphrased_sentences]
-    paraphrased_sentences_no_stopwords = []
-
-    for sentence in paraphrased_sentences_lower:
-        words = re.findall(r'\b\w+\b', sentence)
-        filtered_sentence = ' '.join([word for word in words if word not in stop_words])
-        paraphrased_sentences_no_stopwords.append(filtered_sentence)
-
-    results = []
-    for idx, sentence in enumerate(paraphrased_sentences_no_stopwords):
-        common_words = set(original_sentence_lower.split()) & set(sentence.split())
-        common_substrings = ', '.join(sorted(common_words))
-
-        words_to_replace = [word for word in sentence.split() if word not in common_words]
-        if words_to_replace:
-            num_words = len(words_to_replace)
-            probabilities = [2 ** (-i) for i in range(num_words)]
-            chosen_word = random.choices(words_to_replace, weights=probabilities)[0]
-            sentence = sentence.replace(chosen_word, colored(chosen_word, 'red'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-#---------------------------------------------------------------------------
-# aryans implementation please refactor it as you see fit
+# import torch
+# import random
+
+# def sample_word(words, logits, sampling_technique='inverse_transform', temperature=1.0):
+#     if sampling_technique == 'inverse_transform':
+#         probs = torch.softmax(torch.tensor(logits), dim=-1)
+#         cumulative_probs = torch.cumsum(probs, dim=-1)
+#         random_prob = random.random()
+#         sampled_index = torch.where(cumulative_probs >= random_prob)[0][0]
+#     elif sampling_technique == 'exponential_minimum':
+#         probs = torch.softmax(torch.tensor(logits), dim=-1)
+#         exp_probs = torch.exp(-torch.log(probs))
+#         random_probs = torch.rand_like(exp_probs)
+#         sampled_index = torch.argmax(random_probs * exp_probs)
+#     elif sampling_technique == 'temperature':
+#         scaled_logits = torch.tensor(logits) / temperature
+#         probs = torch.softmax(scaled_logits, dim=-1)
+#         sampled_index = torch.multinomial(probs, 1).item()
+#     elif sampling_technique == 'greedy':
+#         sampled_index = torch.argmax(torch.tensor(logits)).item()
+#     else:
+#         raise ValueError("Invalid sampling technique. Choose 'inverse_transform', 'exponential_minimum', 'temperature', or 'greedy'.")
+
+#     sampled_word = words[sampled_index]
+#     return sampled_word
 
 import torch
 import random
 
-def sample_word(words, logits, sampling_technique='inverse_transform', temperature=1.0):
+def sample_word(sentence, words, logits, sampling_technique='inverse_transform', temperature=1.0):
     if sampling_technique == 'inverse_transform':
         probs = torch.softmax(torch.tensor(logits), dim=-1)
         cumulative_probs = torch.cumsum(probs, dim=-1)
@@ -160,4 +48,8 @@ def sample_word(words, logits, sampling_technique='inverse_transform', temperature=1.0):
         raise ValueError("Invalid sampling technique. Choose 'inverse_transform', 'exponential_minimum', 'temperature', or 'greedy'.")
 
     sampled_word = words[sampled_index]
-    return sampled_word
+
+    # Replace [MASK] with the sampled word
+    filled_sentence = sentence.replace('[MASK]', sampled_word)
+
+    return filled_sentence
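
A sketch of the reworked sample_word, which now takes the masked sentence as its first argument and returns that sentence with [MASK] filled in; the candidate words and scores below are stand-ins for the fill-mask output produced by masking_methods.py.

# Hypothetical candidates and scores (illustration only).
candidates = ["mat", "floor", "sofa"]
scores = [0.6, 0.3, 0.1]
filled = sample_word("The cat sat on the [MASK]", candidates, scores,
                     sampling_technique='greedy', temperature=1.0)
print(filled)  # "The cat sat on the mat" under greedy sampling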
tree.py
CHANGED
@@ -1,29 +1,31 @@
-import plotly.
+import plotly.graph_objects as go
 import textwrap
 import re
 from collections import defaultdict
+
+def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, highlight_info):
+    # Combine nodes into one list with appropriate labels
+    nodes = [paraphrased_sentence] + scheme_sentences + sampled_sentence
+    nodes[0] += ' L0'  # Paraphrased sentence is level 0
+    para_len = len(scheme_sentences)
+    for i in range(1, para_len + 1):
+        nodes[i] += ' L1'  # Scheme sentences are level 1
+    for i in range(para_len + 1, len(nodes)):
+        nodes[i] += ' L2'  # Sampled sentences are level 2
+
+    # Define the highlight_words function
+    def highlight_words(sentence, color_map):
+        for word, color in color_map.items():
+            sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
+        return sentence
+
+    # Clean and wrap nodes, and highlight specified words globally
     cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
+    global_color_map = dict(highlight_info)
+    highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
+    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=30)) for node in highlighted_nodes]
+
+    # Function to determine tree levels and create edges dynamically
     def get_levels_and_edges(nodes):
         levels = {}
         edges = []
@@ -37,58 +39,99 @@ def generate_plot(original_sentence, selected_sentences):
             if level == 1:
                 edges.append((root_node, i))
 
+        # Add edges from each L1 node to their corresponding L2 nodes
+        l1_indices = [i for i, level in levels.items() if level == 1]
+        l2_indices = [i for i, level in levels.items() if level == 2]
+
+        for i, l1_node in enumerate(l1_indices):
+            l2_start = i * 4
+            for j in range(4):
+                l2_index = l2_start + j
+                if l2_index < len(l2_indices):
+                    edges.append((l1_node, l2_indices[l2_index]))
+
+        # Add edges from each L2 node to their corresponding L3 nodes
+        l2_indices = [i for i, level in levels.items() if level == 2]
+        l3_indices = [i for i, level in levels.items() if level == 3]
+
+        l2_to_l3_map = {l2_node: [] for l2_node in l2_indices}
+
+        # Map L3 nodes to L2 nodes
+        for l3_node in l3_indices:
+            l2_node = l3_node % len(l2_indices)
+            l2_to_l3_map[l2_indices[l2_node]].append(l3_node)
+
+        for l2_node, l3_nodes in l2_to_l3_map.items():
+            for l3_node in l3_nodes:
+                edges.append((l2_node, l3_node))
 
         return levels, edges
 
     # Get levels and dynamic edges
     levels, edges = get_levels_and_edges(nodes)
-    max_level = max(levels.values())
+    max_level = max(levels.values(), default=0)
 
     # Calculate positions
     positions = {}
+    level_heights = defaultdict(int)
     for node, level in levels.items():
+        level_heights[level] += 1
 
+    y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
+    x_gap = 2
+    l1_y_gap = 10
+    l2_y_gap = 6
 
     for node, level in levels.items():
+        if level == 1:
+            positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
+        elif level == 2:
+            positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
+        else:
+            positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
+        y_offsets[level] += 1
+
+    # Function to highlight words in a wrapped node string
+    def color_highlighted_words(node, color_map):
+        parts = re.split(r'(\{\{.*?\}\})', node)
+        colored_parts = []
+        for part in parts:
+            match = re.match(r'\{\{(.*?)\}\}', part)
+            if match:
+                word = match.group(1)
+                color = color_map.get(word, 'black')
+                colored_parts.append(f"<span style='color: {color};'>{word}</span>")
+            else:
+                colored_parts.append(part)
+        return ''.join(colored_parts)
 
     # Create figure
     fig = go.Figure()
 
     # Add nodes to the figure
     for i, node in enumerate(wrapped_nodes):
+        colored_node = color_highlighted_words(node, global_color_map)
         x, y = positions[i]
         fig.add_trace(go.Scatter(
-            x=[x],
+            x=[-x],  # Reflect the x coordinate
             y=[y],
             mode='markers',
             marker=dict(size=10, color='blue'),
             hoverinfo='none'
         ))
         fig.add_annotation(
-            x
+            x=-x,  # Reflect the x coordinate
             y=y,
-            text=
+            text=colored_node,
             showarrow=False,
+            xshift=15,
             align="center",
-            font=dict(size=
+            font=dict(size=8),
             bordercolor='black',
             borderwidth=1,
-            borderpad=
+            borderpad=2,
             bgcolor='white',
-            width=
+            width=150
         )
 
     # Add edges to the figure
@@ -96,19 +139,19 @@ def generate_plot(original_sentence, selected_sentences):
         x0, y0 = positions[edge[0]]
         x1, y1 = positions[edge[1]]
         fig.add_trace(go.Scatter(
-            x=[x0, x1],
+            x=[-x0, -x1],  # Reflect the x coordinates
             y=[y0, y1],
             mode='lines',
-            line=dict(color='black', width=
+            line=dict(color='black', width=1)
         ))
 
     fig.update_layout(
         showlegend=False,
-        margin=dict(t=
+        margin=dict(t=20, b=20, l=20, r=20),
         xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
         yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
-        width=
-        height=
+        width=1200,  # Adjusted width to accommodate more levels
+        height=1000  # Adjusted height to accommodate more levels
     )
 
     return fig
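
A minimal sketch of calling the new generate_subplot directly (illustration only); app.py passes one paraphrased sentence, its three masked variants, the twelve sentences sampled from them, and (word, color) highlight pairs, but any consistent lists should render.

# Hypothetical inputs; app.py supplies 3 masked and 12 sampled sentences per paraphrase.
paraphrase = "The cat sat on the mat"
masked = ["The cat sat on the [MASK]", "The [MASK] sat on the mat", "The cat [MASK] on the mat"]
sampled = ["The cat sat on the rug", "The dog sat on the mat", "The cat slept on the mat"]
highlight_info = [("cat", "red"), ("mat", "blue")]

fig = generate_subplot(paraphrase, masked, sampled, highlight_info)
fig.show()  # renders the three-level paraphrase tree as a Plotly figure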