Spaces:
Runtime error
Runtime error
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.cluster import KMeans | |
| import pickle | |
class PersonalityClustering:
    """Group persona descriptions into clusters of similar sentence embeddings.

    Texts are embedded with a ``SentenceTransformer`` model and clustered
    with ``KMeans``.  After :meth:`fit`, every cluster is represented by the
    training text whose embedding lies closest to the cluster centroid.

    Both heavy collaborators (the embedding model and the KMeans estimator)
    are created lazily on first access, so constructing an instance is cheap.
    """

    DEFAULT_SENTENCE_TRANSFORMER = 'paraphrase-MiniLM-L6-v2'

    def __init__(self, n_clusters=None, device='cpu', model_name=None):
        """Store the configuration; no model is loaded here.

        Args:
            n_clusters: Value forwarded to ``KMeans(n_clusters=...)``.
            device: Value forwarded to ``SentenceTransformer(device=...)``.
            model_name: Sentence-transformer model name; falls back to
                ``DEFAULT_SENTENCE_TRANSFORMER`` when ``None``.
        """
        if model_name is None:
            self.model_name = self.DEFAULT_SENTENCE_TRANSFORMER
        else:
            self.model_name = model_name
        self.device = device
        self.n_clusters = n_clusters
        # Representative text per cluster; filled by fit() or load().
        self._cluster_centers = None
        self.__clustering = None
        self.__sentence_transformer = None

    @property
    def sentence_transformer(self):
        """Return the embedding model, creating it on first access."""
        # `is None` rather than truthiness: a model object that defines
        # __len__ could be falsy and would then be rebuilt on every access.
        if self.__sentence_transformer is None:
            self.__sentence_transformer = SentenceTransformer(
                self.model_name, device=self.device
            )
        return self.__sentence_transformer

    @property
    def clustering(self):
        """Return the KMeans estimator, creating it on first access."""
        if self.__clustering is None:
            self.__clustering = KMeans(n_clusters=self.n_clusters)
        return self.__clustering

    def load(self, path):
        """Restore the estimator and cluster representatives written by save().

        Args:
            path: File path of a pickle produced by :meth:`save`.
        """
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only load files that come from a trusted source.
        with open(path, "rb") as f:
            self.__clustering, self._cluster_centers = pickle.load(f)

    def save(self, path):
        """Pickle the estimator and cluster representatives to ``path``.

        The sentence-transformer model itself is not saved.
        """
        with open(path, "wb") as f:
            pickle.dump((self.__clustering, self._cluster_centers), f)

    def fit(self, personalities):
        """Cluster ``personalities`` and pick one representative per cluster.

        Args:
            personalities: Iterable of persona texts.

        Returns:
            ``self``, to allow call chaining.
        """
        personalities = np.array(list(personalities))
        # Assumes encode() returns a 2-D array of shape (n_texts, dim);
        # the boolean-mask indexing and axis=1 norm below rely on it.
        train_embeddings = self.sentence_transformer.encode(personalities)
        clusters = self.clustering.fit_predict(train_embeddings)

        persona_cluster_centers = []
        for clust, center in enumerate(self.clustering.cluster_centers_):
            in_cluster = clusters == clust
            members = train_embeddings[in_cluster]
            if len(members) == 0:
                # No training text landed in this cluster: keep a None
                # placeholder so indices still line up with cluster ids.
                persona_cluster_centers.append(None)
                continue
            # argmin returns the first minimum, i.e. the earliest text wins
            # a tie, same as a strict `<` scan over the members.
            nearest = np.linalg.norm(members - center, axis=1).argmin()
            persona_cluster_centers.append(personalities[in_cluster][nearest])

        self._cluster_centers = np.array(persona_cluster_centers)
        return self

    def predict(self, personalities):
        """Return the cluster index assigned to each text in ``personalities``."""
        personalities = np.array(list(personalities))
        embeddings = self.sentence_transformer.encode(personalities)
        return self.clustering.predict(embeddings)

    def predict_nearest_personality(self, personalities):
        """Map each text to the representative text of its predicted cluster.

        Requires cluster representatives to be present (set by :meth:`fit`
        or :meth:`load`).
        """
        clusters = self.predict(personalities)
        return np.array([self._cluster_centers[clust] for clust in clusters])

    def fit_predict(self, personalities):
        """Fit on ``personalities`` and return their cluster indices."""
        self.fit(personalities)
        return self.predict(personalities)