from __future__ import print_function
import sys
from os import path, makedirs
sys.path.append(".")
sys.path.append("..")
import argparse
from copy import deepcopy
import json
import numpy as np
import torch
import torch.nn as nn
from collections import namedtuple
from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits
from utils.models.parsing_gating import BiAffine_Parser_Gated
from utils import load_word_embeddings
from utils.tasks import parse
import time
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, SGD
import uuid
uid = uuid.uuid4().hex[:6]
logger = get_logger('GraphParser')
def read_arguments():
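    # Parse command-line arguments, resolve dataset and embedding paths, build the
    # alphabets, and return all settings packed into an immutable namedtuple.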
    args_ = argparse.ArgumentParser(description='Solving GraphParser')
args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
args_.add_argument('--domain', help='domain/language', required=True)
args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
required=True)
args_.add_argument('--gating',action='store_true', help='use gated mechanism')
args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism')
args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
    args_.add_argument('--arc_space', type=int, default=128, help='Dimension of arc space')
args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of tag space')
args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True)
args_.add_argument('--unk_replace', type=float, default=0.,
help='The rate to replace a singleton word with UNK')
args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
help='Embedding for words')
args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
    args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embeddings (disable fine-tuning).')
    args_.add_argument('--freeze_sequence_taggers', action='store_true', help='freeze the BiLSTMs of the pre-trained taggers.')
args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
required=True)
args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
required=True)
args_.add_argument('--char_path', help='path for character embedding dict')
args_.add_argument('--pos_path', help='path for pos embedding dict')
    args_.add_argument('--set_num_training_samples', type=int, help='downsample the training set to a fixed number of samples')
args_.add_argument('--model_path', help='path for saving model file.', required=True)
args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='path for loading saved sequence_tagger saved_models files.', default=None)
    args_.add_argument('--strict', action='store_true', help='if set, the loaded model state must contain exactly the same keys as the current model')
args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
args = args_.parse_args()
args_dict = {}
args_dict['dataset'] = args.dataset
args_dict['domain'] = args.domain
args_dict['rnn_mode'] = args.rnn_mode
args_dict['gating'] = args.gating
args_dict['num_gates'] = args.num_gates
args_dict['arc_decode'] = args.arc_decode
# args_dict['splits'] = ['train', 'dev', 'test']
args_dict['splits'] = ['train', 'dev', 'test','poetry','prose']
args_dict['model_path'] = args.model_path
if not path.exists(args_dict['model_path']):
makedirs(args_dict['model_path'])
args_dict['data_paths'] = {}
if args_dict['dataset'] == 'ontonotes':
data_path = 'data/Pre_MRL/onto_pos_ner_dp'
else:
data_path = 'data/Prep_MRL/ud_pos_ner_dp'
for split in args_dict['splits']:
args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain']
###################################
args_dict['data_paths']['poetry'] = data_path + '_' + 'test' + '_' + args_dict['domain']
args_dict['data_paths']['prose'] = data_path + '_' + 'test' + '_' + args_dict['domain']
###################################
args_dict['alphabet_data_paths'] = {}
for split in args_dict['splits']:
if args_dict['dataset'] == 'ontonotes':
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all'
else:
if '_' in args_dict['domain']:
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
else:
args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split]
args_dict['model_name'] = 'domain_' + args_dict['domain']
args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
args_dict['load_path'] = args.load_path
args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths
if args_dict['load_sequence_taggers_paths'] is not None:
args_dict['gating'] = True
args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1
else:
if not args_dict['gating']:
args_dict['num_gates'] = 0
args_dict['strict'] = args.strict
args_dict['num_epochs'] = args.num_epochs
args_dict['batch_size'] = args.batch_size
args_dict['hidden_size'] = args.hidden_size
args_dict['arc_space'] = args.arc_space
args_dict['arc_tag_space'] = args.arc_tag_space
args_dict['num_layers'] = args.num_layers
args_dict['num_filters'] = args.num_filters
args_dict['kernel_size'] = args.kernel_size
args_dict['learning_rate'] = args.learning_rate
args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
args_dict['opt'] = args.opt
args_dict['momentum'] = args.momentum
args_dict['betas'] = tuple(args.betas)
args_dict['epsilon'] = args.epsilon
args_dict['decay_rate'] = args.decay_rate
args_dict['clip'] = args.clip
args_dict['gamma'] = args.gamma
args_dict['schedule'] = args.schedule
args_dict['p_rnn'] = tuple(args.p_rnn)
args_dict['p_in'] = args.p_in
args_dict['p_out'] = args.p_out
args_dict['unk_replace'] = args.unk_replace
args_dict['set_num_training_samples'] = args.set_num_training_samples
args_dict['punct_set'] = None
if args.punct_set is not None:
args_dict['punct_set'] = set(args.punct_set)
logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers
args_dict['word_embedding'] = args.word_embedding
args_dict['word_path'] = args.word_path
args_dict['use_char'] = args.use_char
args_dict['char_embedding'] = args.char_embedding
args_dict['char_path'] = args.char_path
args_dict['pos_embedding'] = args.pos_embedding
args_dict['pos_path'] = args.pos_path
args_dict['use_pos'] = args.use_pos
args_dict['pos_dim'] = args.pos_dim
args_dict['word_dict'] = None
args_dict['word_dim'] = args.word_dim
if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
args_dict['word_path'])
args_dict['char_dict'] = None
args_dict['char_dim'] = args.char_dim
if args_dict['char_embedding'] != 'random':
args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
args_dict['char_path'])
args_dict['pos_dict'] = None
if args_dict['pos_embedding'] != 'random':
args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
args_dict['pos_path'])
args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
args_dict['eval_mode'] = args.eval_mode
args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
logger.info("Saving arguments to file")
save_args(args, args_dict['full_model_name'])
logger.info("Creating Alphabets")
alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict'])
args_dict = {**args_dict, **alphabet_dict}
ARGS = namedtuple('ARGS', args_dict.keys())
my_args = ARGS(**args_dict)
return my_args
def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict):
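    # Build the alphabets from the train split, extended with the remaining splits,
    # and record each alphabet's size for logging.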
train_paths = alphabet_data_paths['train']
extra_paths = [v for k,v in alphabet_data_paths.items() if k != 'train']
alphabet_dict = {}
alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path,
train_paths,
extra_paths=extra_paths,
max_vocabulary_size=100000,
embedd_dict=word_dict)
for k, v in alphabet_dict['alphabets'].items():
num_key = 'num_' + k.split('_')[0]
alphabet_dict[num_key] = v.size()
logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
return alphabet_dict
def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
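    # Build an embedding matrix aligned with the alphabet indices; tokens missing
    # from the pre-trained dictionary are initialized uniformly in [-scale, scale].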
if tokens_dict is None:
return None
scale = np.sqrt(3.0 / dim)
table = np.empty([alphabet.size(), dim], dtype=np.float32)
table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
oov_tokens = 0
for token, index in alphabet.items():
if token in ['aTA']:
embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
oov_tokens += 1
elif token in tokens_dict:
embedding = tokens_dict[token]
elif token.lower() in tokens_dict:
embedding = tokens_dict[token.lower()]
else:
embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
oov_tokens += 1
# print(token)
table[index, :] = embedding
print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
table = torch.from_numpy(table)
return table
def save_args(args, full_model_name):
arg_path = full_model_name + '.arg.json'
argparse_dict = vars(args)
with open(arg_path, 'w') as f:
json.dump(argparse_dict, f)
def generate_optimizer(args, lr, params):
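    # Only parameters with requires_grad=True are handed to the optimizer, so
    # frozen embeddings and frozen taggers are excluded from the update.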
params = filter(lambda param: param.requires_grad, params)
if args.opt == 'adam':
return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
elif args.opt == 'sgd':
return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
else:
raise ValueError('Unknown optimization algorithm: %s' % args.opt)
def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
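    # Persist model and optimizer state together with the optimizer type and the
    # current dev/test evaluation dictionaries.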
path_name = full_model_name + '.pt'
print('Saving model to: %s' % path_name)
state = {'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'opt': opt,
'dev_eval_dict': dev_eval_dict,
'test_eval_dict': test_eval_dict}
torch.save(state, path_name)
def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
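    # Restore model and optimizer state from a checkpoint, move optimizer tensors
    # to the target device, and resume from the epoch stored in the dev results.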
print('Loading saved model from: %s' % load_path)
checkpoint = torch.load(load_path, map_location=args.device)
if checkpoint['opt'] != args.opt:
raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
if strict:
        # rebuild the optimizer over the loaded model's parameters before restoring its state
        optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(args.device)
dev_eval_dict = checkpoint['dev_eval_dict']
test_eval_dict = checkpoint['test_eval_dict']
start_epoch = dev_eval_dict['in_domain']['epoch']
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
def build_model_and_optimizer(args):
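    # Build the embedding tables and the gated biaffine parser, then optionally
    # warm-start from a saved checkpoint and/or pre-trained sequence taggers.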
word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char,
args.use_pos, args.use_char, args.pos_dim, args.num_pos,
args.num_filters, args.kernel_size, args.rnn_mode,
args.hidden_size, args.num_layers, args.num_arc,
args.arc_space, args.arc_tag_space, args.num_gates,
embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer)
print(model)
optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
start_epoch = 0
dev_eval_dict = {'in_domain': initialize_eval_dict()}
test_eval_dict = {'in_domain': initialize_eval_dict()}
if args.load_path:
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
load_checkpoint(args, model, optimizer,
dev_eval_dict, test_eval_dict,
start_epoch, args.load_path, strict=args.strict)
if args.load_sequence_taggers_paths:
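        # Copy the BiLSTM weights of each pre-trained tagger into the parser's
        # extra_rnn_encoders, remapping the 'rnn_encoder.' key prefix per gate.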
pretrained_dict = {}
model_dict = model.state_dict()
        for idx, tagger_path in enumerate(args.load_sequence_taggers_paths):
            print('Loading saved sequence_tagger from: %s' % tagger_path)
            checkpoint = torch.load(tagger_path, map_location=args.device)
for k, v in checkpoint['model_state_dict'].items():
if 'rnn_encoder.' in k:
pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
if args.freeze_sequence_taggers:
print('Freezing Classifiers')
for name, parameter in model.named_parameters():
if 'extra_rnn_encoders' in name:
parameter.requires_grad = False
if args.freeze_word_embeddings:
model.rnn_encoder.word_embedd.weight.requires_grad = False
# model.rnn_encoder.char_embedd.weight.requires_grad = False
# model.rnn_encoder.pos_embedd.weight.requires_grad = False
device = args.device
model.to(device)
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
def initialize_eval_dict():
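    # All counters start at zero and are accumulated batch by batch during evaluation.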
eval_dict = {}
eval_dict['dp_uas'] = 0.0
eval_dict['dp_las'] = 0.0
eval_dict['epoch'] = 0
eval_dict['dp_ucorrect'] = 0.0
eval_dict['dp_lcorrect'] = 0.0
eval_dict['dp_total'] = 0.0
eval_dict['dp_ucomplete_match'] = 0.0
eval_dict['dp_lcomplete_match'] = 0.0
eval_dict['dp_ucorrect_nopunc'] = 0.0
eval_dict['dp_lcorrect_nopunc'] = 0.0
eval_dict['dp_total_nopunc'] = 0.0
eval_dict['dp_ucomplete_match_nopunc'] = 0.0
eval_dict['dp_lcomplete_match_nopunc'] = 0.0
eval_dict['dp_root_correct'] = 0.0
eval_dict['dp_total_root'] = 0.0
eval_dict['dp_total_inst'] = 0.0
return eval_dict
def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
best_model, best_optimizer, patient):
# In-domain evaluation
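    # Keep the running best model whenever dev labeled attachment (no punctuation)
    # does not decrease; `patient` counts consecutive epochs without such an improvement.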
curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] <= curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
(dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
if is_best_in_domain:
for key, value in curr_dev_eval_dict.items():
dev_eval_dict['in_domain'][key] = value
curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
for key, value in curr_test_eval_dict.items():
test_eval_dict['in_domain'][key] = value
best_model = deepcopy(model)
best_optimizer = deepcopy(optimizer)
patient = 0
else:
patient += 1
if epoch == args.num_epochs:
# save in-domain checkpoint
if args.set_num_training_samples is not None:
splits_to_write = datasets.keys()
else:
splits_to_write = ['dev', 'test']
for split in splits_to_write:
if split == 'dev':
eval_dict = dev_eval_dict['in_domain']
elif split == 'test':
eval_dict = test_eval_dict['in_domain']
else:
eval_dict = None
write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
print("Saving best model")
save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
print('\n')
return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient
def evaluation(args, data, split, model, domain, epoch, str_res='results'):
# evaluate performance on data
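    # Decode every batch of the split and accumulate attachment statistics:
    # with punctuation, without punctuation, and root accuracy.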
model.eval()
eval_dict = initialize_eval_dict()
eval_dict['epoch'] = epoch
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
heads_pred, arc_tags_pred, _ = model.decode(args.model_path,word, pos, ner,out_arc, out_arc_tag, mask=masks, length=lengths,
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
lengths = lengths.cpu().numpy()
word = word.data.cpu().numpy()
pos = pos.data.cpu().numpy()
ner = ner.data.cpu().numpy()
heads = heads.data.cpu().numpy()
arc_tags = arc_tags.data.cpu().numpy()
heads_pred = heads_pred.data.cpu().numpy()
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
lengths, punct_set=args.punct_set, symbolic_root=True)
ucorr, lcorr, total, ucm, lcm = stats
ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
corr_root, total_root = stats_root
eval_dict['dp_ucorrect'] += ucorr
eval_dict['dp_lcorrect'] += lcorr
eval_dict['dp_total'] += total
eval_dict['dp_ucomplete_match'] += ucm
eval_dict['dp_lcomplete_match'] += lcm
eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
eval_dict['dp_total_nopunc'] += total_nopunc
eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
eval_dict['dp_root_correct'] += corr_root
eval_dict['dp_total_root'] += total_root
eval_dict['dp_total_inst'] += num_inst
eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
print_results(eval_dict, split, domain, str_res)
return eval_dict
def print_results(eval_dict, split, domain, str_res='results'):
print('----------------------------------------------------------------------------------------------------------------------------')
print('Testing model on domain %s' % domain)
print('--------------- Dependency Parsing - %s ---------------' % split)
print(
str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'],
eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'],
eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'],
eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'],
eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'],
eval_dict['epoch']))
print(
str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'],
eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
eval_dict['epoch']))
print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
eval_dict['dp_root_correct'], eval_dict['dp_total_root'],
eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch']))
print('\n')
def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
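    # Dump the evaluation dictionary to a JSON file (when given) and write the
    # predicted and gold parses to separate files via the Writer helpers.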
str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
res_filename = str_file + '_res.txt'
pred_filename = str_file + '_pred.txt'
gold_filename = str_file + '_gold.txt'
if eval_dict is not None:
# save results dictionary into a file
with open(res_filename, 'w') as f:
json.dump(eval_dict, f)
# save predictions and gold labels into files
pred_writer = Writer(args.alphabets)
gold_writer = Writer(args.alphabets)
pred_writer.start(pred_filename)
gold_writer.start(gold_filename)
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
heads_pred, arc_tags_pred, _ = model.decode(args.model_path,word, pos,ner,out_arc, out_arc_tag, mask=masks, length=lengths,
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
lengths = lengths.cpu().numpy()
word = word.data.cpu().numpy()
pos = pos.data.cpu().numpy()
ner = ner.data.cpu().numpy()
heads = heads.data.cpu().numpy()
arc_tags = arc_tags.data.cpu().numpy()
heads_pred = heads_pred.data.cpu().numpy()
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
# writing predictions
pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
# writing gold labels
gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
pred_writer.close()
gold_writer.close()
def main():
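    # Overall flow: read arguments, load the data splits, build model and optimizer,
    # then either train with per-epoch in-domain evaluation or run evaluation only.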
logger.info("Reading and creating arguments")
args = read_arguments()
logger.info("Reading Data")
datasets = {}
for split in args.splits:
print("Splits are:",split)
dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
symbolic_root=True)
datasets[split] = dataset
if args.set_num_training_samples is not None:
print('Setting train and dev to %d samples' % args.set_num_training_samples)
datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples)
logger.info("Creating Networks")
num_data = sum(datasets['train'][1])
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
best_model = deepcopy(model)
best_optimizer = deepcopy(optimizer)
logger.info('Training INFO of in domain %s' % args.domain)
    logger.info('Training on Dependency Parsing')
logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
logger.info("num_epochs: %d" % (args.num_epochs))
print('\n')
if not args.eval_mode:
logger.info("Training")
num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
lr = args.learning_rate
patient = 0
decay = 0
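        # When dev performance has not improved for args.schedule consecutive epochs,
        # the learning rate is decayed and the optimizer is rebuilt with the new rate.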
for epoch in range(start_epoch + 1, args.num_epochs + 1):
print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
model.train()
total_loss = 0.0
total_arc_loss = 0.0
total_arc_tag_loss = 0.0
total_train_inst = 0.0
train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
start_time = time.time()
batch_num = 0
for batch_num, batch in enumerate(train_iter):
batch_num = batch_num + 1
optimizer.zero_grad()
# compute loss of main task
word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths)
loss = loss_arc + loss_arc_tag
# update losses
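                # num_insts counts non-padding tokens minus one symbolic root per sentence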
num_insts = masks.data.sum() - word.size(0)
total_arc_loss += loss_arc.item() * num_insts
total_arc_tag_loss += loss_arc_tag.item() * num_insts
total_loss += loss.item() * num_insts
total_train_inst += num_insts
# optimize parameters
loss.backward()
clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
time_ave = (time.time() - start_time) / batch_num
time_left = (num_batches - batch_num) * time_ave
# update log
if batch_num % 50 == 0:
log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time left: %.2fs' % \
(batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
total_arc_tag_loss / total_train_inst, time_left)
sys.stdout.write(log_info)
sys.stdout.write('\n')
sys.stdout.flush()
print('\n')
print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time: %.2fs' %
(batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
total_arc_tag_loss / total_train_inst, time.time() - start_time))
dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
if patient >= args.schedule:
lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
optimizer = generate_optimizer(args, lr, model.parameters())
print('updated learning rate to %.6f' % lr)
patient = 0
print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
print('\n')
for split in datasets.keys():
eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
            write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
else:
logger.info("Evaluating")
epoch = start_epoch
for split in ['train', 'dev', 'test','poetry','prose']:
eval_dict = evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
if __name__ == '__main__':
main()