# NOTE: stray page-scrape residue ("Spaces:" / "Runtime error") removed from the top of this file.
import argparse
import json
import torch
from sklearn.neighbors import KDTree
class PersonalityManager:
    """Resolve a free-text persona description to the closest known persona.

    On construction, embeds each persona's cluster-center text with the
    clustering object's sentence transformer and indexes the embeddings in a
    KD-tree, so lookups in `get_prompt` are a single nearest-neighbour query.
    """

    def __init__(self, prompt_paths, personality_clustering):
        # prompt_paths: mapping persona_id -> path of that persona's prompt.
        self.prompt_paths = prompt_paths
        self.personality_clustering = personality_clustering
        self.persona_ids = list(prompt_paths.keys())
        # NOTE(review): reaches into the private `_cluster_centers` attribute
        # of the clustering object — assumed to map persona_id -> center text.
        self.personalities = [
            personality_clustering._cluster_centers[pid]
            for pid in self.persona_ids
        ]
        self.embeddings = personality_clustering.sentence_transformer.encode(
            self.personalities
        )
        self._nearest_neighbours = KDTree(self.embeddings, metric='euclidean')

    def get_prompt(self, description):
        """Return `(prompt_path, cluster_center)` for the persona nearest to `description`."""
        query_vec = self.personality_clustering.sentence_transformer.encode(
            [description]
        )
        _, neighbour_idx = self._nearest_neighbours.query(query_vec, k=1)
        matched_id = self.persona_ids[neighbour_idx[0][0]]
        prompt_path = self.prompt_paths[matched_id]
        cluster_center = self.personality_clustering._cluster_centers[matched_id]
        return prompt_path, cluster_center
class PersonalizedChatBot:
    """Dialogue wrapper around a prompt-tuned language model.

    Keeps the running conversation as one plain string (`self.dialog`); each
    call to `answer` appends the user phrase, generates a continuation with
    the model, appends the reply, and returns it.
    """

    def __init__(self, model, tokenizer, prompt_path=None, generation_config=None):
        self.model = model
        # Optionally load tuned soft-prompt weights before first use.
        if prompt_path is not None:
            self.load_prompt(prompt_path)
        self.tokenizer = tokenizer
        self.separator = '\n'  # turns are newline-delimited
        self.dialog = ''       # full conversation accumulated so far
        self.generation_config = generation_config

    def load_prompt(self, path):
        """Load tuned soft-prompt embeddings from `path` into the model in place."""
        self.model.transformer.prompt_embeddings.load_state_dict(torch.load(path))

    def load_config(self, path):
        """Read a JSON generation config from `path`.

        The parsed keys become attributes of `self.generation_config`
        (e.g. TEMPERATURE, TOP_K, MAX_TOKENS) via `argparse.Namespace`.
        """
        with open(path, 'r') as f:
            config = json.load(f)
        self.generation_config = argparse.Namespace(**config)

    def reset_dialog(self):
        """Forget the accumulated conversation history."""
        self.dialog = ''

    def answer(self, phrase):
        """Append `phrase` to the dialogue and return the model's reply.

        Returns None (leaving state untouched) for an empty phrase.
        Requires `generation_config` with TEMPERATURE, TOP_K and MAX_TOKENS.
        """
        # Idiomatic emptiness check; also safely rejects None instead of
        # raising TypeError on len(None).
        if not phrase:
            return None
        self.dialog += f"{phrase}{self.separator}"
        inputs = self.tokenizer([self.dialog], return_tensors='pt')['input_ids']
        outputs = self.model.generate(
            inputs,
            temperature=self.generation_config.TEMPERATURE,
            do_sample=True,
            top_k=self.generation_config.TOP_K,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=self.generation_config.MAX_TOKENS,
        )
        # Strip the echoed prompt prefix and keep only the first generated line.
        bloom_answer = self.tokenizer.batch_decode(outputs)[0]
        bloom_answer = bloom_answer[len(self.dialog):].split("\n")[0]
        self.dialog += f"{bloom_answer}{self.separator}"
        return bloom_answer