# LCM/examples/GraphParser_MRL.py
from __future__ import print_function
import sys
from os import path, makedirs
sys.path.append(".")
sys.path.append("..")
import argparse
from copy import deepcopy
import json
import numpy as np
import torch
import torch.nn as nn
from collections import namedtuple
from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits
from utils.models.parsing_gating import BiAffine_Parser_Gated
from utils import load_word_embeddings
from utils.tasks import parse
import time
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, SGD
import uuid
uid = uuid.uuid4().hex[:6]
logger = get_logger('GraphParser')
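# Illustrative invocation (a sketch only; the dataset, domain, paths and hyper-parameter
# values below are placeholders rather than values shipped with this file):
#
#   python GraphParser_MRL.py \
#       --dataset ud --domain sanskrit --rnn_mode LSTM --arc_decode mst \
#       --p_rnn 0.33 0.33 --word_embedding random --char_embedding random \
#       --pos_embedding random --opt adam --model_path saved_models/graph_parser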
def read_arguments():
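    """Parse the command-line arguments, derive the data/alphabet paths for every split,
    load any pre-trained embedding dictionaries, build the alphabets, and return
    everything as an immutable ARGS namedtuple."""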
    args_ = argparse.ArgumentParser(description='Solving GraphParser')
args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
args_.add_argument('--domain', help='domain/language', required=True)
args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
required=True)
    args_.add_argument('--gating', action='store_true', help='use the gating mechanism')
args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism')
args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
    args_.add_argument('--arc_space', type=int, default=128, help='Dimension of arc space')
args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of tag space')
args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
    args_.add_argument('--gamma', type=float, default=0.0, help='weight decay (L2 regularization) coefficient')
args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True)
args_.add_argument('--unk_replace', type=float, default=0.,
help='The rate to replace a singleton word with UNK')
args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
help='Embedding for words')
args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
    args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embeddings (disable fine-tuning).')
    args_.add_argument('--freeze_sequence_taggers', action='store_true', help='freeze the BiLSTMs of the pre-trained taggers.')
args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
required=True)
args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
required=True)
args_.add_argument('--char_path', help='path for character embedding dict')
args_.add_argument('--pos_path', help='path for pos embedding dict')
args_.add_argument('--set_num_training_samples', type=int, help='downsampling training set to a fixed number of samples')
args_.add_argument('--model_path', help='path for saving model file.', required=True)
args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='path for loading saved sequence_tagger saved_models files.', default=None)
    args_.add_argument('--strict', action='store_true', help='if True, the loaded model state must contain '
                                                              'exactly the same keys as the current model')
args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
args = args_.parse_args()
args_dict = {}
args_dict['dataset'] = args.dataset
args_dict['domain'] = args.domain
args_dict['rnn_mode'] = args.rnn_mode
args_dict['gating'] = args.gating
args_dict['num_gates'] = args.num_gates
args_dict['arc_decode'] = args.arc_decode
# args_dict['splits'] = ['train', 'dev', 'test']
args_dict['splits'] = ['train', 'dev', 'test','poetry','prose']
args_dict['model_path'] = args.model_path
if not path.exists(args_dict['model_path']):
makedirs(args_dict['model_path'])
args_dict['data_paths'] = {}
if args_dict['dataset'] == 'ontonotes':
data_path = 'data/Pre_MRL/onto_pos_ner_dp'
else:
data_path = 'data/Prep_MRL/ud_pos_ner_dp'
for split in args_dict['splits']:
args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain']
###################################
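    # NOTE: the two lines below point the 'poetry' and 'prose' splits at the same file
    # as the 'test' split for this domain, overriding the split-specific paths set above.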
args_dict['data_paths']['poetry'] = data_path + '_' + 'test' + '_' + args_dict['domain']
args_dict['data_paths']['prose'] = data_path + '_' + 'test' + '_' + args_dict['domain']
###################################
args_dict['alphabet_data_paths'] = {}
for split in args_dict['splits']:
if args_dict['dataset'] == 'ontonotes':
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all'
else:
if '_' in args_dict['domain']:
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
else:
args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split]
args_dict['model_name'] = 'domain_' + args_dict['domain']
args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
args_dict['load_path'] = args.load_path
args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths
if args_dict['load_sequence_taggers_paths'] is not None:
args_dict['gating'] = True
args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1
else:
if not args_dict['gating']:
args_dict['num_gates'] = 0
args_dict['strict'] = args.strict
args_dict['num_epochs'] = args.num_epochs
args_dict['batch_size'] = args.batch_size
args_dict['hidden_size'] = args.hidden_size
args_dict['arc_space'] = args.arc_space
args_dict['arc_tag_space'] = args.arc_tag_space
args_dict['num_layers'] = args.num_layers
args_dict['num_filters'] = args.num_filters
args_dict['kernel_size'] = args.kernel_size
args_dict['learning_rate'] = args.learning_rate
args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
args_dict['opt'] = args.opt
args_dict['momentum'] = args.momentum
args_dict['betas'] = tuple(args.betas)
args_dict['epsilon'] = args.epsilon
args_dict['decay_rate'] = args.decay_rate
args_dict['clip'] = args.clip
args_dict['gamma'] = args.gamma
args_dict['schedule'] = args.schedule
args_dict['p_rnn'] = tuple(args.p_rnn)
args_dict['p_in'] = args.p_in
args_dict['p_out'] = args.p_out
args_dict['unk_replace'] = args.unk_replace
args_dict['set_num_training_samples'] = args.set_num_training_samples
args_dict['punct_set'] = None
if args.punct_set is not None:
args_dict['punct_set'] = set(args.punct_set)
logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers
args_dict['word_embedding'] = args.word_embedding
args_dict['word_path'] = args.word_path
args_dict['use_char'] = args.use_char
args_dict['char_embedding'] = args.char_embedding
args_dict['char_path'] = args.char_path
args_dict['pos_embedding'] = args.pos_embedding
args_dict['pos_path'] = args.pos_path
args_dict['use_pos'] = args.use_pos
args_dict['pos_dim'] = args.pos_dim
args_dict['word_dict'] = None
args_dict['word_dim'] = args.word_dim
if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
args_dict['word_path'])
args_dict['char_dict'] = None
args_dict['char_dim'] = args.char_dim
if args_dict['char_embedding'] != 'random':
args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
args_dict['char_path'])
args_dict['pos_dict'] = None
if args_dict['pos_embedding'] != 'random':
args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
args_dict['pos_path'])
args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
args_dict['eval_mode'] = args.eval_mode
args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
logger.info("Saving arguments to file")
save_args(args, args_dict['full_model_name'])
logger.info("Creating Alphabets")
alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict'])
args_dict = {**args_dict, **alphabet_dict}
ARGS = namedtuple('ARGS', args_dict.keys())
my_args = ARGS(**args_dict)
return my_args
def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict):
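    """Create (or load cached) alphabets from the training path plus the remaining splits,
    record each alphabet's size under a matching num_* key, and return the resulting dict."""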
train_paths = alphabet_data_paths['train']
extra_paths = [v for k,v in alphabet_data_paths.items() if k != 'train']
alphabet_dict = {}
alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path,
train_paths,
extra_paths=extra_paths,
max_vocabulary_size=100000,
embedd_dict=word_dict)
for k, v in alphabet_dict['alphabets'].items():
num_key = 'num_' + k.split('_')[0]
alphabet_dict[num_key] = v.size()
logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
return alphabet_dict
def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
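    """Build an embedding table aligned with `alphabet`: tokens found in `tokens_dict`
    (as-is or lower-cased) keep their pre-trained vectors, all others are sampled
    uniformly from [-sqrt(3/dim), sqrt(3/dim)] and counted as OOV.
    Returns None when no pre-trained dictionary is supplied."""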
if tokens_dict is None:
return None
scale = np.sqrt(3.0 / dim)
table = np.empty([alphabet.size(), dim], dtype=np.float32)
table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
oov_tokens = 0
for token, index in alphabet.items():
if token in ['aTA']:
embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
oov_tokens += 1
elif token in tokens_dict:
embedding = tokens_dict[token]
elif token.lower() in tokens_dict:
embedding = tokens_dict[token.lower()]
else:
embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
oov_tokens += 1
# print(token)
table[index, :] = embedding
print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
table = torch.from_numpy(table)
return table
def save_args(args, full_model_name):
arg_path = full_model_name + '.arg.json'
argparse_dict = vars(args)
with open(arg_path, 'w') as f:
json.dump(argparse_dict, f)
def generate_optimizer(args, lr, params):
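    """Return an Adam or SGD optimizer over the trainable (requires_grad) parameters,
    using the learning rate and hyper-parameters carried in args."""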
params = filter(lambda param: param.requires_grad, params)
if args.opt == 'adam':
return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
elif args.opt == 'sgd':
return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
else:
raise ValueError('Unknown optimization algorithm: %s' % args.opt)
def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
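    """Serialize the model and optimizer state, the optimizer type and the current
    dev/test evaluation dictionaries to <full_model_name>.pt."""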
path_name = full_model_name + '.pt'
print('Saving model to: %s' % path_name)
state = {'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'opt': opt,
'dev_eval_dict': dev_eval_dict,
'test_eval_dict': test_eval_dict}
torch.save(state, path_name)
def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
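    """Check that the saved optimizer type matches args.opt, restore the model (and,
    under strict loading, the optimizer) state from load_path, move optimizer tensors
    to args.device, and resume from the epoch recorded in the dev evaluation dict."""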
print('Loading saved model from: %s' % load_path)
checkpoint = torch.load(load_path, map_location=args.device)
if checkpoint['opt'] != args.opt:
raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
if strict:
        # rebuild the optimizer over the current model parameters, then restore its saved state
        optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(args.device)
dev_eval_dict = checkpoint['dev_eval_dict']
test_eval_dict = checkpoint['test_eval_dict']
start_epoch = dev_eval_dict['in_domain']['epoch']
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
def build_model_and_optimizer(args):
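    """Build the embedding tables, the gated BiAffine parser and its optimizer; optionally
    restore a saved checkpoint, copy pre-trained sequence-tagger RNN encoders into the
    extra gated encoders, and freeze those encoders and/or the word embeddings."""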
word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char,
args.use_pos, args.use_char, args.pos_dim, args.num_pos,
args.num_filters, args.kernel_size, args.rnn_mode,
args.hidden_size, args.num_layers, args.num_arc,
args.arc_space, args.arc_tag_space, args.num_gates,
embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer)
print(model)
optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
start_epoch = 0
dev_eval_dict = {'in_domain': initialize_eval_dict()}
test_eval_dict = {'in_domain': initialize_eval_dict()}
if args.load_path:
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
load_checkpoint(args, model, optimizer,
dev_eval_dict, test_eval_dict,
start_epoch, args.load_path, strict=args.strict)
if args.load_sequence_taggers_paths:
pretrained_dict = {}
model_dict = model.state_dict()
        # use a distinct loop variable so the os.path import is not shadowed
        for idx, tagger_path in enumerate(args.load_sequence_taggers_paths):
            print('Loading saved sequence_tagger from: %s' % tagger_path)
            checkpoint = torch.load(tagger_path, map_location=args.device)
for k, v in checkpoint['model_state_dict'].items():
if 'rnn_encoder.' in k:
pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
if args.freeze_sequence_taggers:
print('Freezing Classifiers')
for name, parameter in model.named_parameters():
if 'extra_rnn_encoders' in name:
parameter.requires_grad = False
if args.freeze_word_embeddings:
model.rnn_encoder.word_embedd.weight.requires_grad = False
# model.rnn_encoder.char_embedd.weight.requires_grad = False
# model.rnn_encoder.pos_embedd.weight.requires_grad = False
device = args.device
model.to(device)
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
def initialize_eval_dict():
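    """Return a fresh dictionary of zeroed dependency-parsing counters and scores."""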
eval_dict = {}
eval_dict['dp_uas'] = 0.0
eval_dict['dp_las'] = 0.0
eval_dict['epoch'] = 0
eval_dict['dp_ucorrect'] = 0.0
eval_dict['dp_lcorrect'] = 0.0
eval_dict['dp_total'] = 0.0
eval_dict['dp_ucomplete_match'] = 0.0
eval_dict['dp_lcomplete_match'] = 0.0
eval_dict['dp_ucorrect_nopunc'] = 0.0
eval_dict['dp_lcorrect_nopunc'] = 0.0
eval_dict['dp_total_nopunc'] = 0.0
eval_dict['dp_ucomplete_match_nopunc'] = 0.0
eval_dict['dp_lcomplete_match_nopunc'] = 0.0
eval_dict['dp_root_correct'] = 0.0
eval_dict['dp_total_root'] = 0.0
eval_dict['dp_total_inst'] = 0.0
eval_dict['dp_total'] = 0.0
eval_dict['dp_total_inst'] = 0.0
eval_dict['dp_total_nopunc'] = 0.0
eval_dict['dp_total_root'] = 0.0
return eval_dict
def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
best_model, best_optimizer, patient):
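    """Evaluate on the in-domain dev set, keep the best model seen so far (by labelled,
    then unlabelled, attachment counts without punctuation), and on the final epoch
    write result files and save the best checkpoint."""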
# In-domain evaluation
curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
    is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] < curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
                        (dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
                         dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
if is_best_in_domain:
for key, value in curr_dev_eval_dict.items():
dev_eval_dict['in_domain'][key] = value
curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
for key, value in curr_test_eval_dict.items():
test_eval_dict['in_domain'][key] = value
best_model = deepcopy(model)
best_optimizer = deepcopy(optimizer)
patient = 0
else:
patient += 1
if epoch == args.num_epochs:
# save in-domain checkpoint
if args.set_num_training_samples is not None:
splits_to_write = datasets.keys()
else:
splits_to_write = ['dev', 'test']
for split in splits_to_write:
if split == 'dev':
eval_dict = dev_eval_dict['in_domain']
elif split == 'test':
eval_dict = test_eval_dict['in_domain']
else:
eval_dict = None
write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
print("Saving best model")
save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
print('\n')
return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient
def evaluation(args, data, split, model, domain, epoch, str_res='results'):
    """Run the parser over `data`, accumulate attachment and root statistics per batch,
    compute UAS/LAS (punctuation included), then print and return the evaluation dict."""
model.eval()
eval_dict = initialize_eval_dict()
eval_dict['epoch'] = epoch
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
heads_pred, arc_tags_pred, _ = model.decode(args.model_path,word, pos, ner,out_arc, out_arc_tag, mask=masks, length=lengths,
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
lengths = lengths.cpu().numpy()
word = word.data.cpu().numpy()
pos = pos.data.cpu().numpy()
ner = ner.data.cpu().numpy()
heads = heads.data.cpu().numpy()
arc_tags = arc_tags.data.cpu().numpy()
heads_pred = heads_pred.data.cpu().numpy()
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
lengths, punct_set=args.punct_set, symbolic_root=True)
ucorr, lcorr, total, ucm, lcm = stats
ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
corr_root, total_root = stats_root
eval_dict['dp_ucorrect'] += ucorr
eval_dict['dp_lcorrect'] += lcorr
eval_dict['dp_total'] += total
eval_dict['dp_ucomplete_match'] += ucm
eval_dict['dp_lcomplete_match'] += lcm
eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
eval_dict['dp_total_nopunc'] += total_nopunc
eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
eval_dict['dp_root_correct'] += corr_root
eval_dict['dp_total_root'] += total_root
eval_dict['dp_total_inst'] += num_inst
eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
print_results(eval_dict, split, domain, str_res)
return eval_dict
def print_results(eval_dict, split, domain, str_res='results'):
print('----------------------------------------------------------------------------------------------------------------------------')
print('Testing model on domain %s' % domain)
print('--------------- Dependency Parsing - %s ---------------' % split)
print(
str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'],
eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'],
eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'],
eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'],
eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'],
eval_dict['epoch']))
print(
str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'],
eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
eval_dict['epoch']))
print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
eval_dict['dp_root_correct'], eval_dict['dp_total_root'],
eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch']))
print('\n')
def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
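    """Dump eval_dict (when given) as JSON and write the predicted and gold dependency
    trees for `data` to the corresponding *_pred.txt and *_gold.txt files."""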
str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
res_filename = str_file + '_res.txt'
pred_filename = str_file + '_pred.txt'
gold_filename = str_file + '_gold.txt'
if eval_dict is not None:
# save results dictionary into a file
with open(res_filename, 'w') as f:
json.dump(eval_dict, f)
# save predictions and gold labels into files
pred_writer = Writer(args.alphabets)
gold_writer = Writer(args.alphabets)
pred_writer.start(pred_filename)
gold_writer.start(gold_filename)
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
heads_pred, arc_tags_pred, _ = model.decode(args.model_path,word, pos,ner,out_arc, out_arc_tag, mask=masks, length=lengths,
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
lengths = lengths.cpu().numpy()
word = word.data.cpu().numpy()
pos = pos.data.cpu().numpy()
ner = ner.data.cpu().numpy()
heads = heads.data.cpu().numpy()
arc_tags = arc_tags.data.cpu().numpy()
heads_pred = heads_pred.data.cpu().numpy()
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
# writing predictions
pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
# writing gold labels
gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
pred_writer.close()
gold_writer.close()
def main():
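    """Read arguments and data, build the parser, then either train it with in-domain
    evaluation after every epoch or, in --eval_mode, only evaluate a loaded model."""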
logger.info("Reading and creating arguments")
args = read_arguments()
logger.info("Reading Data")
datasets = {}
for split in args.splits:
print("Splits are:",split)
dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
symbolic_root=True)
datasets[split] = dataset
if args.set_num_training_samples is not None:
print('Setting train and dev to %d samples' % args.set_num_training_samples)
datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples)
logger.info("Creating Networks")
num_data = sum(datasets['train'][1])
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
best_model = deepcopy(model)
best_optimizer = deepcopy(optimizer)
logger.info('Training INFO of in domain %s' % args.domain)
    logger.info('Training on Dependency Parsing')
logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
logger.info("num_epochs: %d" % (args.num_epochs))
print('\n')
if not args.eval_mode:
logger.info("Training")
num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
lr = args.learning_rate
patient = 0
decay = 0
for epoch in range(start_epoch + 1, args.num_epochs + 1):
print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
model.train()
total_loss = 0.0
total_arc_loss = 0.0
total_arc_tag_loss = 0.0
total_train_inst = 0.0
train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
start_time = time.time()
batch_num = 0
            for batch_num, batch in enumerate(train_iter, start=1):
optimizer.zero_grad()
# compute loss of main task
word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths)
loss = loss_arc + loss_arc_tag
# update losses
num_insts = masks.data.sum() - word.size(0)
total_arc_loss += loss_arc.item() * num_insts
total_arc_tag_loss += loss_arc_tag.item() * num_insts
total_loss += loss.item() * num_insts
total_train_inst += num_insts
# optimize parameters
loss.backward()
clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
time_ave = (time.time() - start_time) / batch_num
time_left = (num_batches - batch_num) * time_ave
# update log
if batch_num % 50 == 0:
log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time left: %.2fs' % \
(batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
total_arc_tag_loss / total_train_inst, time_left)
sys.stdout.write(log_info)
sys.stdout.write('\n')
sys.stdout.flush()
print('\n')
print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time: %.2fs' %
(batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
total_arc_tag_loss / total_train_inst, time.time() - start_time))
dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
if patient >= args.schedule:
lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
optimizer = generate_optimizer(args, lr, model.parameters())
print('updated learning rate to %.6f' % lr)
patient = 0
print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
print('\n')
for split in datasets.keys():
eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
            # write predictions from the best model so they match the reported eval_dict
            write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
else:
logger.info("Evaluating")
epoch = start_epoch
for split in ['train', 'dev', 'test','poetry','prose']:
eval_dict = evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
if __name__ == '__main__':
main()