# ImageNet class 対応表
# https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/

import torch
from torchvision import models, transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2



## 事前学習済みモデルをロード，畳み込み層の情報を抽出
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()
module_list = []
count = 0
for m in model.named_modules():
    if type(m[1]) == torch.nn.Conv2d:
        module_list.append(m[1])
        print('conv_layer_idx:', count, '\t', m[1])
        count += 1
    if type(m[1]) == torch.nn.ReLU:
        m[1].inplace=False



## 画像データ（1枚）をロード
# JPG画像のパス
image_path = "sample3.jpg"

# 画像をPillow形式で読み込む
image = Image.open(image_path)
print('画像サイズ（縦x横）', image.width, image.height)

# 画像の前処理（リサイズ、Tensor変換、正規化を含む）
transform = transforms.Compose([
    transforms.Resize((224, 224)),      # ResNetの入力サイズにリサイズ
    transforms.ToTensor(),              # Tensorに変換（[0, 1]の範囲にスケーリング）
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],     # RGBの平均値
        std=[0.229, 0.224, 0.225]       # RGBの標準偏差
    ),
])

# 画像をTensorに変換
input_tensor = transform(image).unsqueeze(0)  # バッチ次元を追加



## 推論を実行する関数
def infer(input_tensor):
    with torch.no_grad():
        outputs = model(input_tensor)

    # ソフトマックスで確率を計算
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

    # 最も確率が高いクラスを取得
    top1_prob, top1_class = probabilities.max(0)

    print(f"Predicted class: {top1_class.item()}, Probability: {top1_prob.item():.4f}")

    return top1_class.item()



## Grad-CAM を実行する関数
def gradcam(model, input_tensor, conv_layer_idx, class_idx):
    # この関数をは，画像データ（テンソル），Grad-CAMを実行する層，対象のクラスを受け取り，Grad-CAMを実行し，
    # 結果のヒートマップ画像および，本画像にヒートマップを重ねたものを

    # conv_layer_idxで指定した層における入力inpと，inpに対する勾配grad_inpを取り出すための準備
    def fhook(module, inp, out):
        module.inp = inp[0].detach()

    def bhook(module, grad_inp, grad_out):
        module.grad_inp = grad_inp[0].detach()

    mod = module_list[conv_layer_idx]
    mod.register_forward_hook(fhook)
    mod.register_full_backward_hook(bhook)

    # ここの処理（Grad-CAMの処理）を完成させてください．
    out = model(input_tensor)
    out[0,class_idx].backward()
    with torch.no_grad():
        weight = mod.grad_inp.mean(-2, keepdim=True).mean(-1, keepdim=True)
        heatmap = (mod.inp * weight).mean(1)
        heatmap = torch.relu(heatmap)
    
    # 結果画像を元のサイズに戻して保存
    heatmap = transforms.Resize((image.height, image.width))(heatmap)
    heatmap = heatmap / heatmap.max()
    heatmap = heatmap.cpu().numpy()
    input_tensor = transforms.Resize((image.height, image.width))(input_tensor)
    input_tensor = input_tensor.squeeze(0).cpu().numpy()
    imagexheatmap = input_tensor * heatmap
    imagexheatmap = imagexheatmap.transpose(1,2,0)
    imagexheatmap = imagexheatmap[:,:,[2,1,0]]
    heatmap = heatmap.transpose(1,2,0)
    cv2.imwrite('imagexheatmap.png', imagexheatmap * 255)
    cv2.imwrite('heatmap.png', heatmap * 255)



## 推論，Grad-CAMの実行
top1_class = infer(input_tensor)
gradcam(model, input_tensor, 19, top1_class)

# gradcam()の引数を色々変えてみたり，入力画像を変えてみたりして，色々試しましょう！
