import copy
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

################## Implement this section ##################

from torchvision import datasets, transforms
from torch.utils.data.dataset import Subset
from torch.utils.data import DataLoader

# Deterministic preprocessing for validation/test images: convert to a
# tensor, resize, take a central 128-pixel crop, then normalize each channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(144),
    transforms.CenterCrop(128),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
])

# Training preprocessing: same resize and normalization, but with light
# random augmentation (random crop, horizontal flip, small rotation, mild
# color jitter) in place of the fixed center crop.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(144),
    transforms.RandomCrop(128),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(2),
    transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
])

# Two views over the same image folder: `ds` applies the deterministic
# transform, `ds_train` applies the augmenting one. Because both read the
# same root, an index refers to the same image in either view.
ds = datasets.ImageFolder(root='./data/CUB_200_2011/images', transform=transform)
ds_train = datasets.ImageFolder(root='./data/CUB_200_2011/images', transform=transform_train)
print("Number of images", len(ds))

# Fixed-seed random split: the first `num_train` shuffled indices are used for
# training, the next `num_val` for validation, and the remainder for testing.
num_train = 6000
num_val = 1000
torch.manual_seed(2025)
idx = torch.randperm(len(ds))
idx_train = idx[: num_train]
idx_val = idx[num_train : num_train + num_val]
idx_test = idx[num_train + num_val :]

# Subset is handed plain Python ints so that every sample lookup indexes the
# underlying dataset with an int rather than a 0-dim tensor.
ds_train = Subset(ds_train, idx_train.tolist())
ds_val = Subset(ds, idx_val.tolist())
ds_test = Subset(ds, idx_test.tolist())

print("Number of training images", len(ds_train))
print("Number of validation images", len(ds_val))
print("Number of test images", len(ds_test))
print("First 10 indices of training dataset", idx_train[:10])

batch_size = 64
num_workers = 4
# All loaders share `num_workers`, so changing the setting above actually
# takes effect. Only the training loader shuffles.
train_loader = DataLoader(ds_train, batch_size=batch_size, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(ds_val, batch_size=batch_size, num_workers=num_workers, shuffle=False)
test_loader = DataLoader(ds_test, batch_size=batch_size, num_workers=num_workers, shuffle=False)

##################################################

# Training hyperparameters
learning_rate = 5e-4
num_epochs = 20
# Number of output classes produced by the classifier head.
num_classes = 200

# Model, loss function, and optimizer setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ResNet-18 built with torchvision's default arguments (no weights requested).
model = torchvision.models.resnet18()
# Replace the final fully connected layer with a fresh one sized for
# `num_classes`, instead of re-running __init__ on the existing module.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=2e-4, betas=(0.9, 0.999))

# Training
def train_model():
    """Train the global ``model`` and keep the best-validation weights.

    Runs ``num_epochs`` passes over ``train_loader``, stepping ``optimizer``
    on the ``criterion`` loss for every batch. After each epoch the accuracy
    on ``val_loader`` is measured with ``test_model`` and a one-line summary
    is printed. When training finishes, the weights from the epoch with the
    highest validation accuracy are loaded back into ``model``.

    Returns:
        None. ``model`` is updated in place.
    """
    best_val_accuracy = 0
    # state_dict() returns references to the live parameter tensors, so the
    # snapshot must be a deep copy; otherwise the "best" weights would
    # silently follow every later optimizer step.
    best_state_dict = copy.deepcopy(model.state_dict())
    for epoch in range(num_epochs):
        # test_model() switches the model to eval mode, so training mode has
        # to be re-enabled at the start of every epoch, not just once.
        model.train()
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        val_accuracy = test_model(val_loader)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_state_dict = copy.deepcopy(model.state_dict())
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}, Val accuracy: {val_accuracy:.4f}, Best val accuracy: {best_val_accuracy:.4f}")
    # Restore the weights of the best-performing epoch.
    model.load_state_dict(best_state_dict)

# Evaluation
def test_model(dl):
    """Return the top-1 accuracy, in percent, of the global ``model`` on ``dl``.

    Args:
        dl: Iterable yielding ``(images, labels)`` batches; each batch is
            moved to ``device`` before the forward pass.

    Returns:
        The percentage of samples whose highest-scoring class equals the
        label, or ``0.0`` when ``dl`` yields no samples.

    Note:
        The model is switched to eval mode and left in that mode on return.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dl:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty loader instead of dividing by zero.
    if total == 0:
        return 0.0
    return 100 * correct / total

# Entry point
if __name__ == "__main__":
    # perf_counter() is a monotonic clock, so the measured training time is
    # not affected by system clock adjustments during the run.
    start = time.perf_counter()
    train_model()
    tm = time.perf_counter() - start
    # Final evaluation on the held-out test split.
    acc = test_model(test_loader)
    print(f"Test Accuracy: {acc:.2f}%, Training time: {tm:.2f}sec")
