Spaces:
Runtime error
Runtime error
| import argparse | |
| from datasets import load_dataset | |
| from models.personality_clustering import PersonalityClustering | |
| import os | |
| """Usage example | |
| python -m scripts.fit_personality_clustering --clustering-path data/models --n-clusters 500 | |
| """ | |
# Dataset identifier that main() passes to load_persona_chat_personalities().
PERSONACHAT_DATASET = "bavard/personachat_truecased"
def load_persona_chat_personalities(personachat_dataset):
    """Collect the distinct persona sentences from every split of a dataset.

    Args:
        personachat_dataset: Dataset identifier, passed unchanged to
            ``load_dataset``.

    Returns:
        A sorted list of the unique sentences found in the ``personality``
        column of all loaded splits.
    """
    # Assumes load_dataset() without a split argument returns a mapping of
    # split name -> dataset, each exposing a 'personality' column whose rows
    # are lists of sentences (this is how 'train' was already accessed).
    dataset = load_dataset(personachat_dataset)

    # The previous implementation read the 'train' split twice, so the
    # held-out personas were silently dropped.  Walking every available
    # split fixes that without hard-coding split names.
    unique_sentences = set()
    for split in dataset.values():
        for persona in split['personality']:
            unique_sentences.update(persona)

    # Sorting makes the result independent of set iteration order, which
    # varies between interpreter runs because of string-hash randomisation.
    return sorted(unique_sentences)
def parse_args(args=None):
    """Parse the command-line options of the clustering fit script.

    Args:
        args: Optional list of argument strings. ``None`` lets argparse
            read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``clustering_path`` (str),
        ``n_clusters`` (int, default 500) and ``model_name`` (str or None).

    Raises:
        SystemExit: Raised by argparse when ``--clustering-path`` is missing
            or an argument cannot be parsed.
    """
    parser = argparse.ArgumentParser(add_help=True, description="Class for personality clustering.")
    # The path is joined with the model file name when saving, so it must be
    # present; requiring it here fails before any expensive work is done.
    # The single-dash spellings are kept for backward compatibility.
    parser.add_argument('-clustering-path', '--clustering-path', type=str, required=True,
                        help='Path to clustering data.')
    parser.add_argument('-n-clusters', '--n-clusters', type=int, default=500,
                        help='The number of clusters to form.')
    parser.add_argument('-model-name', '--model-name', type=str, default=None, required=False,
                        help='File name for the saved model; generated automatically when omitted.')
    return parser.parse_args(args)
def main():
    """Load the persona sentences, fit a clustering model and save it.

    The output file name comes from ``--model-name`` when given; otherwise
    it is generated from the fitted model's ``n_clusters`` and
    ``model_name`` attributes.
    """
    cli_args = parse_args()

    persona_sentences = load_persona_chat_personalities(PERSONACHAT_DATASET)
    print('Data loaded')

    clustering = PersonalityClustering(n_clusters=cli_args.n_clusters)
    print('Model fitting')
    clustering.fit(persona_sentences)
    print('Model fitted')

    # Only build the default name when no explicit name was supplied.
    output_name = (
        cli_args.model_name
        if cli_args.model_name is not None
        else f'personality_clustering_{clustering.n_clusters}_{clustering.model_name}_k-means.pkl'
    )

    output_path = os.path.join(cli_args.clustering_path, output_name)
    clustering.save(output_path)
    print(f'{output_name} saved')
# Run the fitting pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()