miasambolec committed
Commit 786c704 · verified · 1 parent: 6f7dade

Create app.py

Files changed (1)
app.py +172 -0
app.py ADDED
@@ -0,0 +1,172 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force CPU

from huggingface_hub import hf_hub_download
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn as nn
import pickle
import numpy as np
import re
import fasttext

# SVM classifier and its vectorizer, downloaded from the Hub
svm_repo_id = "HighFive-OPJ/svm-sentiment-model"
svm_model_path = hf_hub_download(repo_id=svm_repo_id, filename="svm_model.pkl")
with open(svm_model_path, "rb") as f:
    svm_model = pickle.load(f)
vectorizer_path = hf_hub_download(repo_id=svm_repo_id, filename="vectorizer.pkl")
with open(vectorizer_path, "rb") as f:
    vectorizer = pickle.load(f)

# FastText word vectors used as input features for the LSTM
fasttext_path = hf_hub_download(
    repo_id="HighFive-OPJ/Deep_Learning",
    filename="FastText.bin",
    repo_type="dataset"
)
ft_model = fasttext.load_model(fasttext_path)


class LSTMClassifier(nn.Module):
    def __init__(self, input_dim=300, hidden_dim=256, num_classes=3):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        _, (hn, _) = self.lstm(x)
        # concatenate the final forward and backward hidden states
        hn = torch.cat((hn[-2], hn[-1]), dim=1)
        out = self.fc(hn)
        return out


lstm_repo_id = "HighFive-OPJ/lstm-sentiment-model"
lstm_model_path = hf_hub_download(repo_id=lstm_repo_id, filename="fasttext_lstm.pt")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lstm_model = LSTMClassifier()
lstm_model.load_state_dict(torch.load(lstm_model_path, map_location=device))
lstm_model.to(device)
lstm_model.eval()

# BERTić transformer classifier
bert_repo_id = "HighFive-OPJ/bertic_sentiment"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_repo_id)
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_repo_id)
bert_model.to(device)
bert_model.eval()


def preprocess_text(text):
    # lowercase and keep only ASCII letters and whitespace
    text = text.lower()
    text = re.sub(r"[^a-zA-Z\s]", "", text).strip()
    return text


def text_to_fasttext_tensor(text, max_len=200):
    # embed each token with FastText and zero-pad the sequence to max_len
    tokens = preprocess_text(text).split()
    vectors = []
    for t in tokens[:max_len]:
        vec = ft_model.get_word_vector(t)
        vectors.append(vec)
    while len(vectors) < max_len:
        vectors.append(np.zeros(300))
    return torch.tensor([vectors], dtype=torch.float32).to(device)


def predict_with_svm(text):
    transformed = vectorizer.transform([text])
    prediction = svm_model.predict(transformed)
    return int(prediction[0])


def predict_with_lstm(text):
    input_tensor = text_to_fasttext_tensor(text)
    with torch.no_grad():
        outputs = lstm_model(input_tensor)
    pred = torch.argmax(outputs, dim=1).item()
    return pred


def predict_with_bert(text):
    inputs = bert_tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = bert_model(**inputs)
    logits = outputs.logits
    predictions = logits.argmax(axis=-1).cpu().numpy()
    bert_score = int(predictions[0])
    # map the raw class index into the shared 0/1/2 scheme
    if bert_score <= 2:
        return 0
    elif bert_score == 3:
        return 1
    else:
        return 2


def analyze_sentiment(text):
    try:
        svm_result = predict_with_svm(text)
    except Exception as e:
        svm_result = f"Error: {str(e)}"

    try:
        lstm_result = predict_with_lstm(text)
    except Exception as e:
        lstm_result = f"Error: {str(e)}"

    try:
        bert_result = predict_with_bert(text)
    except Exception as e:
        bert_result = f"Error: {str(e)}"

    try:
        # average only the predictions that succeeded
        scores = []
        for r in [svm_result, lstm_result, bert_result]:
            if isinstance(r, int):
                scores.append(r)
        average = np.mean(scores) if scores else float("nan")
        stats = f"Average Score (0=Pos,1=Neg,2=Neu): {average:.2f}\n"
    except Exception as e:
        stats = f"Error calculating stats: {str(e)}"

    def format_output(result):
        # integers become star ratings, error strings pass through unchanged
        return convert_to_stars(result) if isinstance(result, int) else result

    return (
        format_output(svm_result),
        format_output(lstm_result),
        format_output(bert_result),
        stats
    )


def convert_to_stars(score):
    star_map = {0: 5, 1: 1, 2: 3}
    stars = star_map.get(score, 3)
    return "★" * stars + "☆" * (5 - stars)


def process_input(text):
    if not text.strip():
        return ("", "", "", "Please enter valid text.")
    try:
        return analyze_sentiment(text)
    except Exception as e:
        error_message = f"Error during sentiment analysis:\n{str(e)}"
        return ("error", "error", "error", error_message)


with gr.Blocks() as demo:
    gr.Markdown("# Sentiment Analysis Demo")
    gr.Markdown("""
    Enter a review and see how different models evaluate its sentiment! This app uses:
    - SVM for classic machine learning
    - LSTM for deep learning (using FastText)
    - BERTić for transformer-based analysis
    """)

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Enter your review:", lines=3)
            analyze_button = gr.Button("Analyze Sentiment")

        with gr.Column():
            svm_output = gr.Textbox(label="SVM", interactive=False)
            lstm_output = gr.Textbox(label="LSTM", interactive=False)
            bert_output = gr.Textbox(label="BERTić", interactive=False)
            stats_output = gr.Textbox(label="Statistics", interactive=False)

    analyze_button.click(
        process_input,
        inputs=[input_text],
        outputs=[svm_output, lstm_output, bert_output, stats_output]
    )

demo.launch()
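
Below is a small usage sketch, not part of this commit: it exercises the three predictors without the web UI, assuming the loading code and function definitions from app.py above have already been run in the same interpreter (importing app.py as a module would also trigger demo.launch() at the bottom of the file). The sample review text is hypothetical; the 0=Pos / 1=Neg / 2=Neu convention and convert_to_stars come from the code above.

# Usage sketch (assumes the definitions from app.py are already in scope; not part of the commit).
sample = "Ovaj film je odličan, preporučujem svima!"  # hypothetical review text

for name, predict in [("SVM", predict_with_svm),
                      ("LSTM", predict_with_lstm),
                      ("BERTić", predict_with_bert)]:
    label = predict(sample)  # 0 = Pos, 1 = Neg, 2 = Neu
    print(f"{name}: class {label} -> {convert_to_stars(label)}")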