import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import time

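################## ここを作成する ##################
# Implement this section: define the data loaders used below —
# train_loader, val_loader and test_loader — each yielding (images, labels)
# batches. The model head has 200 outputs, so labels must be class indices
# in [0, 200).

##################################################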

# Training hyperparameters
learning_rate = 5e-4
num_epochs = 20
num_classes = 200  # number of output units of the classification head

# Model, loss function, and optimizer setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the classification head with num_classes outputs directly, instead of
# re-invoking nn.Linear.__init__ on the existing fc module (which only works by
# relying on nn.Module internals).
model = torchvision.models.resnet18(num_classes=num_classes)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=2e-4, betas=(0.9, 0.999))

# Training
def train_model():
    """Train the module-level ``model`` and keep the weights with the best validation accuracy.

    Uses the module-level ``model``, ``criterion``, ``optimizer``, ``device``,
    ``num_epochs``, ``train_loader`` and ``val_loader``. After each epoch the
    model is scored on ``val_loader`` via ``test_model``. A copy of the weights
    with the highest validation accuracy seen so far is kept and loaded back
    into ``model`` when training finishes.
    """
    best_val_accuracy = float("-inf")
    best_state_dict = None
    for epoch in range(num_epochs):
        # test_model() puts the model into eval mode, so training mode must be
        # re-enabled at the start of every epoch (BatchNorm/dropout behavior).
        model.train()
        total_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        val_accuracy = test_model(val_loader)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # state_dict() returns tensors sharing storage with the live
            # parameters; clone them so later optimizer steps cannot overwrite
            # the saved "best" snapshot.
            best_state_dict = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}, Val accuracy: {val_accuracy:.4f}, Best val accuracy: {best_val_accuracy:.4f}")
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

# Evaluation
def test_model(dl):
    """Return the top-1 accuracy (in percent) of the module-level ``model`` on ``dl``.

    Every ``(images, labels)`` batch from ``dl`` is moved to ``device`` and run
    through ``model`` without gradient tracking. The arg-max over dim 1 of the
    outputs is compared with the labels. The model's previous train/eval mode
    is restored afterwards.

    Args:
        dl: iterable of ``(images, labels)`` batches (presumably a
            ``DataLoader`` — confirm against the loaders defined above).

    Returns:
        Accuracy as a float in ``[0, 100]``.

    Raises:
        ValueError: if ``dl`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in dl:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the caller's mode so evaluating mid-training does not leave
        # the model silently stuck in eval mode.
        model.train(was_training)
    if total == 0:
        raise ValueError("test_model: data loader yielded no samples")
    return 100 * correct / total

# Entry point
if __name__ == "__main__":
    # perf_counter() is monotonic and high-resolution, unlike time.time(), so
    # system clock adjustments cannot distort the measured training duration.
    start = time.perf_counter()
    train_model()
    train_time = time.perf_counter() - start
    acc = test_model(test_loader)
    print(f"Test Accuracy: {acc:.2f}%, Training time: {train_time:.2f}sec")
