import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

###################### Complete this section 3 #######################
import argparse
parser = argparse.ArgumentParser()
# Focusing parameter gamma for the focal loss; give it a default so the script
# also runs when --gamma is omitted (otherwise args.gamma would be None).
parser.add_argument('--gamma', type=float, default=2.0)
args = parser.parse_args()
###################################################################
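# Example invocation (the script filename below is only a placeholder):
#   python this_script.py --gamma 2.0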

# Hyperparameters
torch.manual_seed(2025)
batch_size = 256
learning_rate = 1e-4
num_epochs = 50

# CIFAR-10 dataset and data loader preparation
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(32, 4),        # random 32x32 crop with 4-pixel padding
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),  # normalization
])
# No random augmentation at evaluation time
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

###################### Complete this section 1 #######################

# Indices of test samples whose label is 0 or 1
indices = [i for i, t in enumerate(test_dataset.targets) if t < 2]
test_dataset = Subset(test_dataset, indices)

# Indices of training samples with label 0 (collected first) and label 1 (collected second)
indices = [i for i, t in enumerate(train_dataset.targets) if t == 0]
indices += [i for i, t in enumerate(train_dataset.targets) if t == 1]
# CIFAR-10 has 5000 training images per class, so keeping indices[4900:] leaves only
# 100 class-0 samples against all 5000 class-1 samples: a heavily imbalanced training set.
train_dataset = Subset(train_dataset, indices[4900:])

###################################################################
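# Optional sanity check (a sketch, not executed here): count the classes that
# remain in the subsampled training set.
#   from collections import Counter
#   print(Counter(train_dataset.dataset.targets[i] for i in train_dataset.indices))
#   # expected: Counter({1: 5000, 0: 100})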

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
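
# Optional shape check (a sketch): each training batch should be a tensor of
# shape (batch_size, 3, 32, 32) with integer labels in {0, 1}.
#   images, labels = next(iter(train_loader))
#   print(images.shape, labels.unique())  # e.g. torch.Size([256, 3, 32, 32]) tensor([0, 1])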

# Model: a small ResNet-style CNN for 32x32 RGB inputs with a two-class output head

class ResNet(torch.nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(16)
        self.pool1 = torch.nn.MaxPool2d(2,2)
        self.bigblock2 = self.make_bigblock(16, 32, 5)
        self.pool2 = torch.nn.MaxPool2d(2,2)
        self.bigblock3 = self.make_bigblock(32, 64, 5)
        self.fc = torch.nn.Linear(64, 2)
    
    def make_bigblock(self, in_channels, out_channels, num_blocks):
        # Stack num_blocks residual blocks; only the last block changes the channel count.
        blocks = []
        for _ in range(num_blocks - 1):
            blocks.append(block(in_channels, in_channels))
        blocks.append(block(in_channels, out_channels))
        return torch.nn.Sequential(*blocks)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        x = self.bigblock2(x)
        x = self.pool2(x)
        x = self.bigblock3(x)
        x = torch.mean(x, dim=(2, 3))  # global average pooling over the spatial dimensions
        x = self.fc(x)
        return x

# Basic residual block: two 3x3 convolutions with batch normalization; the shortcut
# is the identity when the channel count is unchanged, otherwise a 1x1 convolution.
class block(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(block, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, in_channels, 3, 1, 1, bias=False)
        self.conv2 = torch.nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(in_channels)
        self.bn2 = torch.nn.BatchNorm2d(out_channels)
        if in_channels == out_channels:
            self.shortcut = torch.nn.Sequential()
        else:
            self.shortcut = torch.nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)

    def forward(self, x):
        t = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.shortcut(x) + self.conv2(t)))
        return x
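
# Quick shape check (a sketch): a block that changes the channel count keeps the
# spatial size, e.g.
#   b = block(16, 32)
#   print(b(torch.randn(1, 16, 16, 16)).shape)  # torch.Size([1, 32, 16, 16])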

###################### Complete this section 2 #######################
# Focal loss: cross-entropy weighted by (1 - p) ** gamma for the true class, which
# down-weights easy, well-classified examples so training focuses on the rare class.
class CRITERION(nn.Module):
    def __init__(self, gamma=2):
        super(CRITERION, self).__init__()
        self.gamma = gamma

    def forward(self, y, t):
        # With gamma = 0 this reduces to torch.nn.functional.cross_entropy(y, t).
        t = torch.nn.functional.one_hot(t, y.shape[1])  # (N,) -> (N, C) one-hot targets
        y = torch.softmax(y, dim=1)                     # logits -> class probabilities
        loss = -torch.mean(torch.sum(t * (1 - y) ** self.gamma * torch.log(y), dim=1), dim=0)
        return loss

###################################################################
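# Minimal sanity check (a sketch): with gamma = 0 the focal loss should match
# the standard cross-entropy loss.
#   logits = torch.randn(8, 2)
#   targets = torch.randint(0, 2, (8,))
#   assert torch.allclose(CRITERION(gamma=0.0)(logits, targets),
#                         torch.nn.functional.cross_entropy(logits, targets))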

# Model, loss function, and optimizer setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet().to(device)
criterion = CRITERION(args.gamma)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=2e-5)

# Training loop
def train_model():
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

# Test loop
def test_model():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

# Run
if __name__ == "__main__":
    train_model()
    test_model()
