import sys
import os.path

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from tqdm import tqdm

import config
import data
import model
import utils


def main():
    if len(sys.argv) > 1:
        name = ' '.join(sys.argv[1:])
    else:
        from datetime import datetime
        name = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    target_name = os.path.join('logs', '{}.pth'.format(name))
    print('will save to {}'.format(target_name))

    cudnn.benchmark = True

    train_loader = data.get_loader(train=True)
    val_loader = data.get_loader(val=True)

    net = nn.DataParallel(model.Net(train_loader.dataset.num_tokens)).cuda()
    optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad])

    tracker = utils.Tracker()
    config_as_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}

    for i in range(config.epochs):
        _ = run(net, train_loader, optimizer, tracker, train=True, prefix='train', epoch=i)
        r = run(net, val_loader, optimizer, tracker, train=False, prefix='val', epoch=i)

        # checkpoint after every epoch: weights, metric history and the latest
        # validation outputs all go into a single .pth file
        results = {
            'name': name,
            'tracker': tracker.to_dict(),
            'config': config_as_dict,
            'weights': net.state_dict(),
            'eval': {
                'answers': r[0],
                'accuracies': r[1],
                'idx': r[2],
            },
            'vocab': train_loader.dataset.vocab,
        }
        torch.save(results, target_name)


def run(net, loader, optimizer, tracker, train=False, prefix='', epoch=0):
    """ Run an epoch over the given loader """
    if train:
        net.train()
        # during training, report an exponential moving average of the metrics
        tracker_class, tracker_params = tracker.MovingMeanMonitor, {'momentum': 0.99}
    else:
        net.eval()
        # during evaluation, report a plain mean over the whole epoch
        tracker_class, tracker_params = tracker.MeanMonitor, {}
        answ = []
        idxs = []
        accs = []
    tq = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0)
    loss_tracker = tracker.track('{}_loss'.format(prefix), tracker_class(**tracker_params))
    acc_tracker = tracker.track('{}_acc'.format(prefix), tracker_class(**tracker_params))
    log_softmax = nn.LogSoftmax(dim=1).cuda()
    for v, q, a, idx, q_len in tq:
        # `Variable(..., volatile=...)` and `.cuda(async=True)` from old PyTorch
        # no longer work (`async` is now a reserved word); plain tensors with
        # non_blocking transfers and an explicit grad-mode switch are the
        # modern equivalents
        v = v.cuda(non_blocking=True)
        q = q.cuda(non_blocking=True)
        a = a.cuda(non_blocking=True)
        q_len = q_len.cuda(non_blocking=True)

        with torch.set_grad_enabled(train):
            out = net(v, q, q_len)
            nll = -log_softmax(out)
            # soft-target loss: `a` appears to hold per-answer counts out of 10
            # ground-truth answers, so /10 turns counts into scores in [0, 1]
            loss = (nll * a / 10).sum(dim=1).mean()
            acc = utils.batch_accuracy(out.data, a.data).cpu()
        if train:
            global total_iterations
            update_learning_rate(optimizer, total_iterations)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_iterations += 1
        else:
            # collect predicted answer indices, per-sample accuracies and
            # dataset indices so results can be saved after the epoch
            _, answer = out.data.cpu().max(dim=1)
            answ.append(answer.view(-1))
            accs.append(acc.view(-1))
            idxs.append(idx.view(-1).clone())
        loss_tracker.append(loss.item())  # loss.data[0] no longer works on modern PyTorch
        for sample_acc in acc:  # renamed from `a`, which shadowed the answer tensor
            acc_tracker.append(sample_acc.item())
        fmt = '{:.4f}'.format
        tq.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value))
    if not train:
        answ = list(torch.cat(answ, dim=0))
        accs = list(torch.cat(accs, dim=0))
        idxs = list(torch.cat(idxs, dim=0))
        return answ, accs, idxs
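
# The training branch references a module-level iteration counter and a
# learning-rate schedule that are not defined anywhere in this section. The
# minimal sketch below makes the file self-contained; the exponential
# half-life decay and the `config.initial_lr` / `config.lr_halflife` names
# are assumptions, not confirmed by the code above.
total_iterations = 0


def update_learning_rate(optimizer, iteration):
    # assumed schedule: halve the learning rate every `lr_halflife` iterations
    lr = config.initial_lr * 0.5 ** (float(iteration) / config.lr_halflife)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr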
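
# Standard entry-point guard; the section ends without one, so this is an
# assumed addition rather than a confirmed part of the original file.
if __name__ == '__main__':
    main()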