+ 개발

파이토치(Pytorch): RNN 기반 이름 분류기 구현

AI.Logger 2024. 9. 22. 00:15
  • 수업 내용 리마인드 및 아카이빙 목적의 업로드


 

이번 글에서는 PyTorch를 활용해 이름 분류기를 만들어 볼 거예요. 이 실습에서는 이름 데이터를 사용해 특정 언어를 예측하는 모델을 만들어 보고, 학습한 후 모델의 성능을 평가해보겠습니다. 과정을 차근차근 따라가면 어렵지 않게 이해할 수 있으니, 함께 진행해봐요.

1. 필요한 라이브러리 불러오기

우선 실습에 필요한 라이브러리들을 불러올게요. GPU를 사용 가능한 경우, GPU를 사용해 학습 속도를 높일 수 있습니다.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, random_split
from torchinfo import summary

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split

import pandas as pd
import unicodedata
import string
import random
from tqdm import tqdm

# 시드 설정
torch.manual_seed(10)
torch.cuda.manual_seed(10)

# CUDA 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

 

2. 데이터 로드

이번 실습에서는 Pickle 파일 형식으로 저장된 데이터를 사용해요. Pickle은 파이썬의 데이터 직렬화 형식으로, 데이터를 쉽게 저장하고 불러올 수 있습니다.

import pickle

with open('name.pkl', 'rb') as f:
    selected_data = pickle.load(f)

category_labels_selected = [category for _, category in selected_data]
selected_category_names = list(pd.Series(category_labels_selected).unique())
print(selected_category_names)

 

3. 데이터 전처리 및 데이터셋 구성

이름 데이터를 사용하기 위해서는 유니코드ASCII로 변환해주고, 각 문자를 원핫 인코딩 방식으로 변환할 거예요. 이를 위해 커스텀 데이터셋을 정의합니다.

# 모든 가능한 문자 정의
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn' and c in all_letters
    )

def lineToTensor(line):
    tensor = torch.zeros(len(line), n_letters)
    for li, letter in enumerate(line):
        tensor[li][all_letters.find(letter)] = 1
    return tensor

# 커스텀 데이터셋 정의
class NameDataset(Dataset):
    def __init__(self, data, category_names):
        self.data = data
        self.category_names = category_names

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        line, category = self.data[idx]
        category_tensor = torch.tensor(self.category_names.index(category), dtype=torch.long)
        line = unicodeToAscii(line)
        line_tensor = lineToTensor(line)
        return line_tensor, category_tensor

 

4. 데이터 분할 및 로더 생성

데이터를 훈련, 검증, 테스트 세트로 나눈 후 DataLoader를 사용해 모델에 입력할 배치를 만들어줍니다.

train_data, test_val_data, train_labels, test_val_labels = train_test_split(
    selected_data, category_labels_selected, test_size=0.3, stratify=category_labels_selected, random_state=10
)

val_data, test_data, val_labels, test_labels = train_test_split(
    test_val_data, test_val_labels, test_size=0.5, stratify=test_val_labels, random_state=10
)

# 커스텀 데이터셋과 데이터로더 정의
train_dataset = NameDataset(train_data, selected_category_names)
val_dataset = NameDataset(val_data, selected_category_names)
test_dataset = NameDataset(test_data, selected_category_names)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

 

5. RNN 모델 정의

이제 RNN(순환 신경망)을 정의할 차례입니다. 이 모델은 입력된 이름의 시퀀스를 처리하여 해당 이름의 카테고리(언어)를 예측합니다.

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.rnn = nn.RNN(input_size, hidden_size, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = torch.zeros(self.n_layers, x.size(0), self.hidden_size).to(x.device)
        out, hidden = self.rnn(x, hidden)
        out = out[:, -1, :]  # 마지막 타임스텝의 출력 사용
        out = self.fc(out)
        return out

# 모델 생성
input_size = n_letters
hidden_size = 128
output_size = len(selected_category_names)

model = RNN(input_size, hidden_size, output_size).to(device)

 

6. 모델 학습

훈련 루프를 돌리면서 모델을 학습하고, 검증 손실이 가장 낮을 때의 모델을 저장합니다.

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(model, dataloader, criterion, optimizer):
    model.train()
    total_loss = 0
    for lines, categories in dataloader:
        lines, categories = lines.to(device), categories.to(device)
        optimizer.zero_grad()
        output = model(lines)
        loss = criterion(output, categories)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for lines, categories in dataloader:
            lines, categories = lines.to(device), categories.to(device)
            output = model(lines)
            loss = criterion(output, categories)
            total_loss += loss.item()
    return total_loss / len(dataloader)

# 훈련 루프
n_epochs = 100
best_valid_loss = float('inf')

for epoch in range(1, n_epochs + 1):
    train_loss = train(model, train_loader, criterion, optimizer)
    valid_loss = evaluate(model, val_loader, criterion)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best_model.pth')

    print(f"Epoch {epoch}/{n_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}")

 

7. 모델 평가 및 예측

최적의 모델을 저장한 후, 테스트 세트에서 모델의 성능을 평가합니다. 또한 혼동행렬을 시각화하여 모델의 예측 성능을 확인해보겠습니다.

# 저장된 모델 로드
model.load_state_dict(torch.load('best_model.pth'))

# 테스트 세트에서 성능 평가
test_loss = evaluate(model, test_loader, criterion)
print(f"Test Loss: {test_loss:.4f}")

# 혼동행렬 시각화
cm = confusion_matrix(all_labels, all_preds, normalize='true')

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt=".2f", cmap='YlGnBu', xticklabels=selected_category_names, yticklabels=selected_category_names)
plt.title('Normalized Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()