import time
import unicodedata

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset, DataLoader
# Reference point for time_since() in the training progress log.
start_time = time.time()

# ---- hyper-parameters -------------------------------------------------------
HIDDEN_SIZE = 100  # width of the character embedding and of each GRU layer
BATCH_SIZE = 256
N_LAYER = 2        # number of stacked GRU layers
N_EPOCH = 5
N_CHAR = 128       # embedding vocabulary size: character codes 0..127
USE_GPU = False    # NOTE(review): not referenced anywhere in the visible code
class NameDataset(Dataset):
    """Map-style dataset of (name, country-index) pairs read from a header-less CSV.

    Column 0 of the CSV holds the name and column 1 the country label.
    Countries are numbered by their position in the sorted list of distinct
    labels found in this file.
    """

    def __init__(self, filename):
        frame = pd.read_csv(filename, header=None)
        self.len = len(frame)
        self.names, self.countries = frame[0], frame[1]
        # Sorted so the label -> index mapping is deterministic for a given file.
        self.country_list = sorted(set(self.countries))
        self.country_dict = self.get_country_dict()
        self.country_num = len(self.country_list)

    def __getitem__(self, idx):
        """Return the raw name and the integer index of its country."""
        name = self.names[idx]
        label = self.country_dict[self.countries[idx]]
        return name, label

    def __len__(self):
        return self.len

    def get_country_dict(self):
        """Build the country-name -> index lookup from ``country_list``."""
        return {country: index for index, country in enumerate(self.country_list)}

    def idx2country(self, idx):
        """Inverse of the lookup: country name for a class index."""
        return self.country_list[idx]

    def get_countries_num(self):
        """Number of distinct countries in this file."""
        return self.country_num
train_dataset = NameDataset('./dataset/name/names_train.csv')
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = NameDataset('./dataset/name/names_test.csv')
# Each NameDataset numbers countries from the labels in its own file, so a test
# file whose set of countries differs from the training file would give the same
# country a different index than the one the classifier was trained on. Reuse the
# training vocabulary so test labels line up with the classifier's output units
# (a country never seen in training then fails loudly with KeyError in
# __getitem__ instead of being silently mislabelled).
test_dataset.country_list = train_dataset.country_list
test_dataset.country_dict = train_dataset.country_dict
test_dataset.country_num = train_dataset.country_num
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Number of output classes = countries in the training vocabulary.
N_COUNTRY = train_dataset.get_countries_num()
class RNNClassifier(nn.Module):
    """Character-level GRU classifier.

    Embeds character codes, runs an (optionally bidirectional) multi-layer GRU
    over each padded sequence, and maps the top layer's final hidden state(s)
    to ``output_size`` logits.

    Args:
        input_size: number of distinct character codes (rows of the embedding).
        hidden_size: embedding width and GRU hidden width.
        output_size: number of classes produced by the final linear layer.
        n_layers: number of stacked GRU layers.
        bidirectional: run the GRU in both directions and concatenate the top
            layer's forward and backward final states.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.n_directions = 2 if bidirectional else 1

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=bidirectional)
        self.fc = nn.Linear(hidden_size * self.n_directions, output_size)

    def _init_hidden(self, batch_size, device=None):
        """Return an all-zero initial GRU state.

        Shape is (n_layers * n_directions, batch_size, hidden_size). With
        ``device=None`` the tensor goes to the default device, as before;
        forward() passes the input's device so the model also works after
        being moved to a GPU.
        """
        return torch.zeros(
            self.n_layers * self.n_directions, batch_size, self.hidden_size, device=device
        )

    def forward(self, input, seq_lengths):
        """Classify a batch of padded character-code sequences.

        Args:
            input: integer tensor of shape (batch, max_len), padded on the right.
            seq_lengths: true length of each row (tensor or list), in descending
                order; pack_padded_sequence is used with its default
                ``enforce_sorted=True``.

        Returns:
            Logits of shape (batch, output_size).
        """
        # The GRU is not batch_first, so switch to (max_len, batch).
        input = input.t()
        batch_size = input.size(1)
        hidden = self._init_hidden(batch_size, device=input.device)
        embedding = self.embedding(input)

        # pack_padded_sequence requires the lengths on the CPU.
        if torch.is_tensor(seq_lengths):
            seq_lengths = seq_lengths.cpu()
        gru_input = pack_padded_sequence(embedding, seq_lengths)

        output, hidden = self.gru(gru_input, hidden)
        # h_n is ordered layer by layer with directions inside each layer, so the
        # last two entries are the top layer's forward ([-2]) and backward ([-1])
        # final states.
        if self.n_directions == 2:
            hidden_cat = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden_cat = hidden[-1]
        return self.fc(hidden_cat)
# Model, loss and optimiser shared by train() and test().
classifier = RNNClassifier(N_CHAR, HIDDEN_SIZE, N_COUNTRY, n_layers=N_LAYER)
criterion = nn.CrossEntropyLoss()  # takes the classifier's raw logits + class indices
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)
def time_since(since):
    """Format the wall-clock time elapsed since timestamp ``since`` as '<m>m <s>s'."""
    minutes, seconds = divmod(time.time() - since, 60)
    return '%dm %ds' % (minutes, seconds)
def name2list(name, n_char=128):
    """Encode ``name`` as a list of character codes plus its length.

    Codes are ``ord(c)``. The embedding only has ``n_char`` rows, so any
    character with a code >= ``n_char`` would crash it later with an opaque
    IndexError. Accented letters are therefore transliterated first ('É' -> 'E').
    A character that still cannot be represented raises ValueError here.

    Returns:
        (codes, len(codes))

    Raises:
        ValueError: if a character remains outside [0, n_char) after stripping
            diacritics.
    """
    if any(ord(c) >= n_char for c in name):
        # NFD splits a precomposed letter into base letter + combining mark
        # (category 'Mn'); dropping the marks leaves the base letter.
        name = ''.join(
            c for c in unicodedata.normalize('NFD', name) if unicodedata.category(c) != 'Mn'
        )
        unsupported = sorted({c for c in name if ord(c) >= n_char})
        if unsupported:
            raise ValueError(
                f'name {name!r} contains characters outside the {n_char}-symbol '
                f'vocabulary: {unsupported!r}'
            )
    arr = [ord(c) for c in name]
    return arr, len(arr)
def make_tensors(names, countries):
    """Turn a batch of names and country labels into model-ready tensors.

    Each name is encoded with name2list() and right-padded with 0 to the
    longest name in the batch. Rows are then reordered by decreasing length,
    as the classifier's pack_padded_sequence call expects. Labels are
    reordered to match.

    Returns:
        (padded codes of shape (batch, max_len), sorted lengths, reordered labels)
    """
    encoded = [name2list(name) for name in names]
    lengths = torch.LongTensor([length for _, length in encoded])

    padded = torch.zeros(len(encoded), lengths.max()).long()
    for row, ((codes, _), length) in enumerate(zip(encoded, lengths)):
        padded[row, :length] = torch.LongTensor(codes)

    lengths, order = lengths.sort(dim=0, descending=True)
    return padded[order], lengths, countries.long()[order]
def train():
    """Run one training epoch over ``train_loader``, updating ``classifier`` in place.

    Prints progress every 10 batches. It reads the module-level ``epoch``,
    which the loop under ``__main__`` sets, for the log line.

    Returns:
        Mean per-batch loss over the epoch (0.0 if the loader yields nothing).
    """
    total_loss = 0.0
    seen = 0
    n_batches = 0
    for i, (names, countries) in enumerate(train_loader, 1):
        inputs, seq_lengths, target = make_tensors(names, countries)
        output = classifier(inputs, seq_lengths)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Count real samples: the last batch may be smaller than BATCH_SIZE.
        seen += len(inputs)
        n_batches = i
        if i % 10 == 0:
            print(f'[{time_since(start_time)}] Epoch {epoch} ', end='')
            print(f'[{seen}/{len(train_dataset)}] ', end='')
            # criterion averages over the batch, so total_loss is a sum of
            # per-batch means; divide by the number of batches, not samples.
            print(f'loss={total_loss / i}')
    return total_loss / n_batches if n_batches else 0.0
def test():
    """Evaluate ``classifier`` on ``test_loader``.

    Prints the number of correct predictions and the percentage, and returns
    accuracy as a fraction in [0, 1]. Runs under ``torch.no_grad()``.
    """
    correct = 0
    total = len(test_dataset)
    print('evaluating trained model...')
    with torch.no_grad():
        for names, countries in test_loader:
            inputs, seq_lengths, target = make_tensors(names, countries)
            logits = classifier(inputs, seq_lengths)
            _, predicted = logits.max(dim=1)  # index of the highest logit per row
            correct += (predicted == target).sum().item()
    percent = f'{100 * correct / total:.2f}'
    print(f'Test set: Accuracy {correct}/{total} {percent}%')
    return correct / total
if __name__ == "__main__":
    import os

    max_acc = 0
    filename = './model/name_max_acc.txt'
    # Both the score file and the checkpoint live in ./model; writing either one
    # fails with FileNotFoundError if the directory does not exist yet.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    if not os.path.isfile(filename):
        with open(filename, 'w') as f:
            f.write(str(max_acc))
    else:
        with open(filename, 'r') as f:
            stored = [line.strip() for line in f if line.strip()]
        # Use the last non-blank line. A trailing blank line or an empty file
        # would otherwise make float() raise ValueError.
        if stored:
            max_acc = float(stored[-1])
    print('max_acc', max_acc)
    if os.path.exists('./model/name.pkl'):
        # map_location='cpu' lets a checkpoint saved on a GPU machine load here;
        # load_state_dict then copies the tensors into the model's parameters.
        classifier.load_state_dict(torch.load('./model/name.pkl', map_location='cpu'))
    print('Training for %d epochs...' % N_EPOCH)
    acc_list = []
    # `epoch` must stay a module-level name: train() reads it for its log line.
    for epoch in range(1, N_EPOCH + 1):
        train()
        acc = test()
        acc_list.append(acc)
        if acc > max_acc:
            max_acc = acc
            # Save the weights before recording the score, so the score file
            # never claims an accuracy whose checkpoint failed to write.
            torch.save(classifier.state_dict(), './model/name.pkl')
            with open(filename, 'w') as f:
                f.write(str(max_acc))
            print("update model")

    epochs = np.arange(1, len(acc_list) + 1, 1)
    plt.plot(epochs, np.array(acc_list))
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.grid()
    plt.show()