|
|
from __future__ import print_function |
|
|
import sys |
|
|
from os import path, makedirs |
|
|
|
|
|
sys.path.append(".") |
|
|
sys.path.append("..") |
|
|
|
|
|
import argparse |
|
|
from copy import deepcopy |
|
|
import json |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from collections import namedtuple |
|
|
from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits |
|
|
from utils.models.parsing_gating import BiAffine_Parser_Gated |
|
|
from utils import load_word_embeddings |
|
|
from utils.tasks import parse |
|
|
import time |
|
|
from torch.nn.utils import clip_grad_norm_ |
|
|
from torch.optim import Adam, SGD |
|
|
import uuid |
|
|
|
|
|
uid = uuid.uuid4().hex[:6] |
|
|
|
|
|
logger = get_logger('GraphParser') |
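
# Training / evaluation script for the gated BiAffine graph-based dependency parser.
# It parses command-line arguments, prepares data and alphabets, optionally restores a
# saved checkpoint and pre-trained sequence taggers, trains with patience-based
# learning-rate decay, and writes per-split predictions and evaluation results.
#
# Example invocation (the script name, domain and model path below are hypothetical and
# shown for illustration only; only the required flags are listed):
#   python train_parser.py --dataset ud --domain sanskrit --rnn_mode LSTM \
#       --arc_decode mst --p_rnn 0.33 0.33 --char_embedding random \
#       --pos_embedding random --model_path saved_models/parser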
|
|
|
|
|
def read_arguments(): |
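    # Parse the command-line arguments, derive data and alphabet paths, load pre-trained
    # embedding dictionaries when requested, build the alphabets, and return everything
    # as an immutable ARGS namedtuple.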
|
|
    args_ = argparse.ArgumentParser(description='Solving GraphParser')
|
|
args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True) |
|
|
args_.add_argument('--domain', help='domain/language', required=True) |
|
|
args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn', |
|
|
required=True) |
|
|
    args_.add_argument('--gating', action='store_true', help='use gated mechanism')
|
|
args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism') |
|
|
args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs') |
|
|
args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch') |
|
|
args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN') |
|
|
    args_.add_argument('--arc_space', type=int, default=128, help='Dimension of arc space')
|
|
    args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of arc tag space')
|
|
args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN') |
|
|
args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN') |
|
|
args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN') |
|
|
args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.') |
|
|
args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.') |
|
|
args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings') |
|
|
args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings') |
|
|
args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings') |
|
|
args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters') |
|
|
args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm') |
|
|
args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer') |
|
|
args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer') |
|
|
args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate') |
|
|
args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate') |
|
|
args_.add_argument('--schedule', type=int, help='schedule for learning rate decay') |
|
|
args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping') |
|
|
args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization') |
|
|
args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam') |
|
|
args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN') |
|
|
args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings') |
|
|
args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer') |
|
|
args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True) |
|
|
args_.add_argument('--unk_replace', type=float, default=0., |
|
|
help='The rate to replace a singleton word with UNK') |
|
|
args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations') |
|
|
args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'], |
|
|
help='Embedding for words') |
|
|
args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random') |
|
|
    args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embeddings (disable fine-tuning).')
|
|
    args_.add_argument('--freeze_sequence_taggers', action='store_true', help='freeze the BiLSTMs of the pre-trained taggers.')
|
|
args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters', |
|
|
required=True) |
|
|
args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos', |
|
|
required=True) |
|
|
args_.add_argument('--char_path', help='path for character embedding dict') |
|
|
args_.add_argument('--pos_path', help='path for pos embedding dict') |
|
|
args_.add_argument('--set_num_training_samples', type=int, help='downsampling training set to a fixed number of samples') |
|
|
args_.add_argument('--model_path', help='path for saving model file.', required=True) |
|
|
args_.add_argument('--load_path', help='path for loading saved source model file.', default=None) |
|
|
    args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='paths for loading saved sequence tagger model files.', default=None)
|
|
    args_.add_argument('--strict', action='store_true', help='if True loaded model state should contain '
|
|
'exactly the same keys as current model') |
|
|
args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it') |
|
|
args = args_.parse_args() |
|
|
args_dict = {} |
|
|
args_dict['dataset'] = args.dataset |
|
|
args_dict['domain'] = args.domain |
|
|
args_dict['rnn_mode'] = args.rnn_mode |
|
|
args_dict['gating'] = args.gating |
|
|
args_dict['num_gates'] = args.num_gates |
|
|
args_dict['arc_decode'] = args.arc_decode |
|
|
|
|
|
    args_dict['splits'] = ['train', 'dev', 'test', 'poetry', 'prose']
|
|
args_dict['model_path'] = args.model_path |
|
|
if not path.exists(args_dict['model_path']): |
|
|
makedirs(args_dict['model_path']) |
|
|
args_dict['data_paths'] = {} |
|
|
if args_dict['dataset'] == 'ontonotes': |
|
|
data_path = 'data/Pre_MRL/onto_pos_ner_dp' |
|
|
else: |
|
|
data_path = 'data/Prep_MRL/ud_pos_ner_dp' |
|
|
for split in args_dict['splits']: |
|
|
args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'] |
|
|
|
|
|
args_dict['data_paths']['poetry'] = data_path + '_' + 'test' + '_' + args_dict['domain'] |
|
|
args_dict['data_paths']['prose'] = data_path + '_' + 'test' + '_' + args_dict['domain'] |
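    # Note: the 'poetry' and 'prose' splits currently point at the same file as 'test'.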
|
|
|
|
|
args_dict['alphabet_data_paths'] = {} |
|
|
for split in args_dict['splits']: |
|
|
if args_dict['dataset'] == 'ontonotes': |
|
|
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all' |
|
|
else: |
|
|
if '_' in args_dict['domain']: |
|
|
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0] |
|
|
else: |
|
|
args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split] |
|
|
args_dict['model_name'] = 'domain_' + args_dict['domain'] |
|
|
args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name']) |
|
|
args_dict['load_path'] = args.load_path |
|
|
args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths |
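    # Loading pre-trained sequence taggers implies the gated architecture:
    # one gate per loaded tagger plus one for the parser's own encoder.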
|
|
if args_dict['load_sequence_taggers_paths'] is not None: |
|
|
args_dict['gating'] = True |
|
|
args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1 |
|
|
else: |
|
|
if not args_dict['gating']: |
|
|
args_dict['num_gates'] = 0 |
|
|
args_dict['strict'] = args.strict |
|
|
args_dict['num_epochs'] = args.num_epochs |
|
|
args_dict['batch_size'] = args.batch_size |
|
|
args_dict['hidden_size'] = args.hidden_size |
|
|
args_dict['arc_space'] = args.arc_space |
|
|
args_dict['arc_tag_space'] = args.arc_tag_space |
|
|
args_dict['num_layers'] = args.num_layers |
|
|
args_dict['num_filters'] = args.num_filters |
|
|
args_dict['kernel_size'] = args.kernel_size |
|
|
args_dict['learning_rate'] = args.learning_rate |
|
|
args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None |
|
|
args_dict['opt'] = args.opt |
|
|
args_dict['momentum'] = args.momentum |
|
|
args_dict['betas'] = tuple(args.betas) |
|
|
args_dict['epsilon'] = args.epsilon |
|
|
args_dict['decay_rate'] = args.decay_rate |
|
|
args_dict['clip'] = args.clip |
|
|
args_dict['gamma'] = args.gamma |
|
|
args_dict['schedule'] = args.schedule |
|
|
args_dict['p_rnn'] = tuple(args.p_rnn) |
|
|
args_dict['p_in'] = args.p_in |
|
|
args_dict['p_out'] = args.p_out |
|
|
args_dict['unk_replace'] = args.unk_replace |
|
|
args_dict['set_num_training_samples'] = args.set_num_training_samples |
|
|
args_dict['punct_set'] = None |
|
|
if args.punct_set is not None: |
|
|
args_dict['punct_set'] = set(args.punct_set) |
|
|
logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set']))) |
|
|
args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings |
|
|
args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers |
|
|
args_dict['word_embedding'] = args.word_embedding |
|
|
args_dict['word_path'] = args.word_path |
|
|
args_dict['use_char'] = args.use_char |
|
|
args_dict['char_embedding'] = args.char_embedding |
|
|
args_dict['char_path'] = args.char_path |
|
|
args_dict['pos_embedding'] = args.pos_embedding |
|
|
args_dict['pos_path'] = args.pos_path |
|
|
args_dict['use_pos'] = args.use_pos |
|
|
args_dict['pos_dim'] = args.pos_dim |
|
|
args_dict['word_dict'] = None |
|
|
args_dict['word_dim'] = args.word_dim |
|
|
if args_dict['word_embedding'] != 'random' and args_dict['word_path']: |
|
|
args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'], |
|
|
args_dict['word_path']) |
|
|
args_dict['char_dict'] = None |
|
|
args_dict['char_dim'] = args.char_dim |
|
|
if args_dict['char_embedding'] != 'random': |
|
|
args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'], |
|
|
args_dict['char_path']) |
|
|
args_dict['pos_dict'] = None |
|
|
if args_dict['pos_embedding'] != 'random': |
|
|
args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'], |
|
|
args_dict['pos_path']) |
|
|
args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/') |
|
|
args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name']) |
|
|
args_dict['eval_mode'] = args.eval_mode |
|
|
args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune' |
|
|
args_dict['char_status'] = 'enabled' if args.use_char else 'disabled' |
|
|
args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled' |
|
|
logger.info("Saving arguments to file") |
|
|
save_args(args, args_dict['full_model_name']) |
|
|
logger.info("Creating Alphabets") |
|
|
alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict']) |
|
|
args_dict = {**args_dict, **alphabet_dict} |
|
|
ARGS = namedtuple('ARGS', args_dict.keys()) |
|
|
my_args = ARGS(**args_dict) |
|
|
return my_args |
|
|
|
|
|
|
|
|
def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict): |
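    # Build the alphabets from the training file (extended with the other splits) and
    # record each alphabet's size under a matching 'num_*' key.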
|
|
train_paths = alphabet_data_paths['train'] |
|
|
extra_paths = [v for k,v in alphabet_data_paths.items() if k != 'train'] |
|
|
alphabet_dict = {} |
|
|
alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path, |
|
|
train_paths, |
|
|
extra_paths=extra_paths, |
|
|
max_vocabulary_size=100000, |
|
|
embedd_dict=word_dict) |
|
|
for k, v in alphabet_dict['alphabets'].items(): |
|
|
num_key = 'num_' + k.split('_')[0] |
|
|
alphabet_dict[num_key] = v.size() |
|
|
logger.info("%s : %d" % (num_key, alphabet_dict[num_key])) |
|
|
return alphabet_dict |
|
|
|
|
|
def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'): |
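    # Build a [vocab_size, dim] embedding matrix from a pre-trained dictionary; tokens
    # missing from the dictionary (and the UNK entry) get small random vectors.
    # Returns None when no pre-trained dictionary is given.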
|
|
if tokens_dict is None: |
|
|
return None |
|
|
scale = np.sqrt(3.0 / dim) |
|
|
table = np.empty([alphabet.size(), dim], dtype=np.float32) |
|
|
table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32) |
|
|
oov_tokens = 0 |
|
|
for token, index in alphabet.items(): |
|
|
if token in ['aTA']: |
|
|
embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32) |
|
|
oov_tokens += 1 |
|
|
elif token in tokens_dict: |
|
|
embedding = tokens_dict[token] |
|
|
elif token.lower() in tokens_dict: |
|
|
embedding = tokens_dict[token.lower()] |
|
|
else: |
|
|
embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32) |
|
|
oov_tokens += 1 |
|
|
|
|
|
table[index, :] = embedding |
|
|
print('token type : %s, number of oov: %d' % (token_type, oov_tokens)) |
|
|
table = torch.from_numpy(table) |
|
|
return table |
|
|
|
|
|
def save_args(args, full_model_name): |
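    # Dump the raw argparse namespace to '<full_model_name>.arg.json' for reproducibility.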
|
|
arg_path = full_model_name + '.arg.json' |
|
|
argparse_dict = vars(args) |
|
|
with open(arg_path, 'w') as f: |
|
|
json.dump(argparse_dict, f) |
|
|
|
|
|
def generate_optimizer(args, lr, params): |
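    # Create an Adam or SGD optimizer over the trainable (requires_grad) parameters only.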
|
|
params = filter(lambda param: param.requires_grad, params) |
|
|
if args.opt == 'adam': |
|
|
return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon) |
|
|
elif args.opt == 'sgd': |
|
|
return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True) |
|
|
else: |
|
|
raise ValueError('Unknown optimization algorithm: %s' % args.opt) |
|
|
|
|
|
|
|
|
def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name): |
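    # Persist model and optimizer state together with the best dev/test evaluation dicts.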
|
|
path_name = full_model_name + '.pt' |
|
|
print('Saving model to: %s' % path_name) |
|
|
state = {'model_state_dict': model.state_dict(), |
|
|
'optimizer_state_dict': optimizer.state_dict(), |
|
|
'opt': opt, |
|
|
'dev_eval_dict': dev_eval_dict, |
|
|
'test_eval_dict': test_eval_dict} |
|
|
torch.save(state, path_name) |
|
|
|
|
|
|
|
|
def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True): |
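    # Restore a saved checkpoint: model weights (optionally non-strict), the optimizer
    # state when strict=True (moved to the target device), and the recorded evaluation
    # dicts and epoch.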
|
|
print('Loading saved model from: %s' % load_path) |
|
|
checkpoint = torch.load(load_path, map_location=args.device) |
|
|
if checkpoint['opt'] != args.opt: |
|
|
raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt)) |
|
|
model.load_state_dict(checkpoint['model_state_dict'], strict=strict) |
|
|
|
|
|
if strict: |
|
|
        optimizer = generate_optimizer(args, args.learning_rate, model.parameters())  # rebuild the optimizer over the current model parameters before restoring its state
|
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
|
|
for state in optimizer.state.values(): |
|
|
for k, v in state.items(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
state[k] = v.to(args.device) |
|
|
dev_eval_dict = checkpoint['dev_eval_dict'] |
|
|
test_eval_dict = checkpoint['test_eval_dict'] |
|
|
start_epoch = dev_eval_dict['in_domain']['epoch'] |
|
|
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch |
|
|
|
|
|
|
|
|
def build_model_and_optimizer(args): |
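    # Construct the gated BiAffine parser and its optimizer, optionally restore a saved
    # checkpoint, inject pre-trained tagger BiLSTM weights into the extra_rnn_encoders,
    # apply the requested freezing, and move the model to the target device.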
|
|
word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word') |
|
|
char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char') |
|
|
pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos') |
|
|
model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char, |
|
|
args.use_pos, args.use_char, args.pos_dim, args.num_pos, |
|
|
args.num_filters, args.kernel_size, args.rnn_mode, |
|
|
args.hidden_size, args.num_layers, args.num_arc, |
|
|
args.arc_space, args.arc_tag_space, args.num_gates, |
|
|
embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table, |
|
|
p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn, |
|
|
biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer) |
|
|
print(model) |
|
|
optimizer = generate_optimizer(args, args.learning_rate, model.parameters()) |
|
|
start_epoch = 0 |
|
|
dev_eval_dict = {'in_domain': initialize_eval_dict()} |
|
|
test_eval_dict = {'in_domain': initialize_eval_dict()} |
|
|
if args.load_path: |
|
|
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \ |
|
|
load_checkpoint(args, model, optimizer, |
|
|
dev_eval_dict, test_eval_dict, |
|
|
start_epoch, args.load_path, strict=args.strict) |
|
|
if args.load_sequence_taggers_paths: |
|
|
pretrained_dict = {} |
|
|
model_dict = model.state_dict() |
|
|
for idx, path in enumerate(args.load_sequence_taggers_paths): |
|
|
print('Loading saved sequence_tagger from: %s' % path) |
|
|
checkpoint = torch.load(path, map_location=args.device) |
|
|
for k, v in checkpoint['model_state_dict'].items(): |
|
|
if 'rnn_encoder.' in k: |
|
|
pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v |
|
|
model_dict.update(pretrained_dict) |
|
|
model.load_state_dict(model_dict) |
|
|
if args.freeze_sequence_taggers: |
|
|
print('Freezing Classifiers') |
|
|
for name, parameter in model.named_parameters(): |
|
|
if 'extra_rnn_encoders' in name: |
|
|
parameter.requires_grad = False |
|
|
if args.freeze_word_embeddings: |
|
|
model.rnn_encoder.word_embedd.weight.requires_grad = False |
|
|
|
|
|
|
|
|
device = args.device |
|
|
model.to(device) |
|
|
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch |
|
|
|
|
|
|
|
|
def initialize_eval_dict(): |
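    # All dependency-parsing counters start at zero; they are accumulated batch by batch
    # during evaluation.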
|
|
eval_dict = {} |
|
|
eval_dict['dp_uas'] = 0.0 |
|
|
eval_dict['dp_las'] = 0.0 |
|
|
eval_dict['epoch'] = 0 |
|
|
eval_dict['dp_ucorrect'] = 0.0 |
|
|
eval_dict['dp_lcorrect'] = 0.0 |
|
|
eval_dict['dp_total'] = 0.0 |
|
|
eval_dict['dp_ucomplete_match'] = 0.0 |
|
|
eval_dict['dp_lcomplete_match'] = 0.0 |
|
|
eval_dict['dp_ucorrect_nopunc'] = 0.0 |
|
|
eval_dict['dp_lcorrect_nopunc'] = 0.0 |
|
|
eval_dict['dp_total_nopunc'] = 0.0 |
|
|
eval_dict['dp_ucomplete_match_nopunc'] = 0.0 |
|
|
eval_dict['dp_lcomplete_match_nopunc'] = 0.0 |
|
|
eval_dict['dp_root_correct'] = 0.0 |
|
|
eval_dict['dp_total_root'] = 0.0 |
|
|
eval_dict['dp_total_inst'] = 0.0 |
|
|
eval_dict['dp_total'] = 0.0 |
|
|
eval_dict['dp_total_inst'] = 0.0 |
|
|
eval_dict['dp_total_nopunc'] = 0.0 |
|
|
eval_dict['dp_total_root'] = 0.0 |
|
|
return eval_dict |
|
|
|
|
|
def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, |
|
|
best_model, best_optimizer, patient): |
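    # Evaluate on dev; when the labelled (then unlabelled) no-punctuation counts improve,
    # also evaluate on test, snapshot the best model/optimizer and reset patience.
    # At the final epoch, write predictions and save the best checkpoint.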
|
|
|
|
|
curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results') |
|
|
    is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] < curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
                        (dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
                         dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
|
|
|
|
|
if is_best_in_domain: |
|
|
for key, value in curr_dev_eval_dict.items(): |
|
|
dev_eval_dict['in_domain'][key] = value |
|
|
curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results') |
|
|
for key, value in curr_test_eval_dict.items(): |
|
|
test_eval_dict['in_domain'][key] = value |
|
|
best_model = deepcopy(model) |
|
|
best_optimizer = deepcopy(optimizer) |
|
|
patient = 0 |
|
|
else: |
|
|
patient += 1 |
|
|
if epoch == args.num_epochs: |
|
|
|
|
|
if args.set_num_training_samples is not None: |
|
|
splits_to_write = datasets.keys() |
|
|
else: |
|
|
splits_to_write = ['dev', 'test'] |
|
|
for split in splits_to_write: |
|
|
if split == 'dev': |
|
|
eval_dict = dev_eval_dict['in_domain'] |
|
|
elif split == 'test': |
|
|
eval_dict = test_eval_dict['in_domain'] |
|
|
else: |
|
|
eval_dict = None |
|
|
write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict) |
|
|
print("Saving best model") |
|
|
save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name) |
|
|
|
|
|
print('\n') |
|
|
return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient |
|
|
|
|
|
|
|
|
def evaluation(args, data, split, model, domain, epoch, str_res='results'): |
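    # Run the parser over one split, accumulate attachment statistics batch by batch and
    # report UAS/LAS (with and without punctuation) plus root accuracy.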
|
|
|
|
|
model.eval() |
|
|
|
|
|
eval_dict = initialize_eval_dict() |
|
|
eval_dict['epoch'] = epoch |
|
|
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device): |
|
|
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch |
|
|
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths) |
|
|
        heads_pred, arc_tags_pred, _ = model.decode(args.model_path, word, pos, ner, out_arc, out_arc_tag, mask=masks, length=lengths,
                                                    leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
|
lengths = lengths.cpu().numpy() |
|
|
word = word.data.cpu().numpy() |
|
|
pos = pos.data.cpu().numpy() |
|
|
ner = ner.data.cpu().numpy() |
|
|
heads = heads.data.cpu().numpy() |
|
|
arc_tags = arc_tags.data.cpu().numpy() |
|
|
heads_pred = heads_pred.data.cpu().numpy() |
|
|
arc_tags_pred = arc_tags_pred.data.cpu().numpy() |
|
|
stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads, |
|
|
arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'], |
|
|
lengths, punct_set=args.punct_set, symbolic_root=True) |
|
|
ucorr, lcorr, total, ucm, lcm = stats |
|
|
ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc |
|
|
corr_root, total_root = stats_root |
|
|
eval_dict['dp_ucorrect'] += ucorr |
|
|
eval_dict['dp_lcorrect'] += lcorr |
|
|
eval_dict['dp_total'] += total |
|
|
eval_dict['dp_ucomplete_match'] += ucm |
|
|
eval_dict['dp_lcomplete_match'] += lcm |
|
|
eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc |
|
|
eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc |
|
|
eval_dict['dp_total_nopunc'] += total_nopunc |
|
|
eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc |
|
|
eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc |
|
|
eval_dict['dp_root_correct'] += corr_root |
|
|
eval_dict['dp_total_root'] += total_root |
|
|
eval_dict['dp_total_inst'] += num_inst |
|
|
|
|
|
eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] |
|
|
eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] |
|
|
print_results(eval_dict, split, domain, str_res) |
|
|
return eval_dict |
|
|
|
|
|
|
|
|
def print_results(eval_dict, split, domain, str_res='results'): |
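    # Pretty-print the attachment scores gathered in eval_dict for the given split and domain.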
|
|
print('----------------------------------------------------------------------------------------------------------------------------') |
|
|
print('Testing model on domain %s' % domain) |
|
|
print('--------------- Dependency Parsing - %s ---------------' % split) |
|
|
print( |
|
|
str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % ( |
|
|
eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'], |
|
|
eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'], |
|
|
eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'], |
|
|
eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'], |
|
|
eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'], |
|
|
eval_dict['epoch'])) |
|
|
print( |
|
|
str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % ( |
|
|
eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'], |
|
|
eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'], |
|
|
eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'], |
|
|
eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'], |
|
|
eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'], |
|
|
eval_dict['epoch'])) |
|
|
print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % ( |
|
|
eval_dict['dp_root_correct'], eval_dict['dp_total_root'], |
|
|
eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch'])) |
|
|
print('\n') |
|
|
|
|
|
def write_results(args, data, data_domain, split, model, model_domain, eval_dict): |
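    # Write the evaluation dict (if any) to a JSON results file and dump gold and predicted
    # parses for the split via the Writer utility.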
|
|
str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain |
|
|
res_filename = str_file + '_res.txt' |
|
|
pred_filename = str_file + '_pred.txt' |
|
|
gold_filename = str_file + '_gold.txt' |
|
|
if eval_dict is not None: |
|
|
|
|
|
with open(res_filename, 'w') as f: |
|
|
json.dump(eval_dict, f) |
|
|
|
|
|
|
|
|
pred_writer = Writer(args.alphabets) |
|
|
gold_writer = Writer(args.alphabets) |
|
|
pred_writer.start(pred_filename) |
|
|
gold_writer.start(gold_filename) |
|
|
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device): |
|
|
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch |
|
|
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths) |
|
|
        heads_pred, arc_tags_pred, _ = model.decode(args.model_path, word, pos, ner, out_arc, out_arc_tag, mask=masks, length=lengths,
                                                    leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
|
lengths = lengths.cpu().numpy() |
|
|
word = word.data.cpu().numpy() |
|
|
pos = pos.data.cpu().numpy() |
|
|
ner = ner.data.cpu().numpy() |
|
|
heads = heads.data.cpu().numpy() |
|
|
arc_tags = arc_tags.data.cpu().numpy() |
|
|
heads_pred = heads_pred.data.cpu().numpy() |
|
|
arc_tags_pred = arc_tags_pred.data.cpu().numpy() |
|
|
|
|
|
pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True) |
|
|
|
|
|
gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True) |
|
|
|
|
|
pred_writer.close() |
|
|
gold_writer.close() |
|
|
|
|
|
def main(): |
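    # End-to-end driver: read arguments, load the data splits, build the model, then either
    # train with patience-based learning-rate decay or, in --eval_mode, only evaluate.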
|
|
logger.info("Reading and creating arguments") |
|
|
args = read_arguments() |
|
|
logger.info("Reading Data") |
|
|
datasets = {} |
|
|
for split in args.splits: |
|
|
print("Splits are:",split) |
|
|
dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device, |
|
|
symbolic_root=True) |
|
|
datasets[split] = dataset |
|
|
if args.set_num_training_samples is not None: |
|
|
print('Setting train and dev to %d samples' % args.set_num_training_samples) |
|
|
datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples) |
|
|
logger.info("Creating Networks") |
|
|
num_data = sum(datasets['train'][1]) |
|
|
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args) |
|
|
best_model = deepcopy(model) |
|
|
best_optimizer = deepcopy(optimizer) |
|
|
|
|
|
logger.info('Training INFO of in domain %s' % args.domain) |
|
|
    logger.info('Training on Dependency Parsing')
|
|
logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace)) |
|
|
logger.info('number of training samples for %s is: %d' % (args.domain, num_data)) |
|
|
logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn)) |
|
|
logger.info("num_epochs: %d" % (args.num_epochs)) |
|
|
print('\n') |
|
|
|
|
|
if not args.eval_mode: |
|
|
logger.info("Training") |
|
|
num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size) |
|
|
lr = args.learning_rate |
|
|
patient = 0 |
|
|
decay = 0 |
|
|
for epoch in range(start_epoch + 1, args.num_epochs + 1): |
|
|
print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % ( |
|
|
epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay)) |
|
|
model.train() |
|
|
total_loss = 0.0 |
|
|
total_arc_loss = 0.0 |
|
|
total_arc_tag_loss = 0.0 |
|
|
total_train_inst = 0.0 |
|
|
|
|
|
train_iter = prepare_data.iterate_batch_rand_bucket_choosing( |
|
|
datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace) |
|
|
start_time = time.time() |
|
|
batch_num = 0 |
|
|
            for batch_num, batch in enumerate(train_iter, 1):
|
|
optimizer.zero_grad() |
|
|
|
|
|
word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch |
|
|
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths) |
|
|
loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths) |
|
|
loss = loss_arc + loss_arc_tag |
|
|
|
|
|
|
|
|
num_insts = masks.data.sum() - word.size(0) |
|
|
total_arc_loss += loss_arc.item() * num_insts |
|
|
total_arc_tag_loss += loss_arc_tag.item() * num_insts |
|
|
total_loss += loss.item() * num_insts |
|
|
total_train_inst += num_insts |
|
|
|
|
|
loss.backward() |
|
|
clip_grad_norm_(model.parameters(), args.clip) |
|
|
optimizer.step() |
|
|
|
|
|
time_ave = (time.time() - start_time) / batch_num |
|
|
time_left = (num_batches - batch_num) * time_ave |
|
|
|
|
|
|
|
|
if batch_num % 50 == 0: |
|
|
log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time left: %.2fs' % \ |
|
|
(batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst, |
|
|
total_arc_tag_loss / total_train_inst, time_left) |
|
|
sys.stdout.write(log_info) |
|
|
sys.stdout.write('\n') |
|
|
sys.stdout.flush() |
|
|
print('\n') |
|
|
print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time: %.2fs' % |
|
|
(batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst, |
|
|
total_arc_tag_loss / total_train_inst, time.time() - start_time)) |
|
|
|
|
|
dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient) |
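            # Patience-based schedule: when the dev score has not improved for `schedule`
            # consecutive epochs, decay the learning rate and rebuild the optimizer.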
|
|
if patient >= args.schedule: |
|
|
lr = args.learning_rate / (1.0 + epoch * args.decay_rate) |
|
|
optimizer = generate_optimizer(args, lr, model.parameters()) |
|
|
print('updated learning rate to %.6f' % lr) |
|
|
patient = 0 |
|
|
print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results') |
|
|
print('\n') |
|
|
for split in datasets.keys(): |
|
|
eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results') |
|
|
            write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
|
|
|
|
|
else: |
|
|
logger.info("Evaluating") |
|
|
epoch = start_epoch |
|
|
        for split in ['train', 'dev', 'test', 'poetry', 'prose']:
|
|
eval_dict = evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results') |
|
|
write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict) |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
main() |
|
|
|