# -*- coding: utf-8 -*-
"""Untitled3.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1NcCcdsrb1Ui2sOxzKeKA2cKVZ3cWVEzO
"""

# -*- coding: utf-8 -*-
"""
Démonstration de Fine-tuning d'un modèle de classification d'images sur Google Colab.
Ce script illustre comment adapter un modèle de vision pré-entraîné (ViT)
pour une tâche spécifique de classification (ici, les fleurs du jeu de données oxford_flowers102).

Il est conçu pour être exécuté dans un environnement Google Colab avec accélération GPU
pour des performances optimales.
"""

# 1. Installation des bibliothèques nécessaires
# Ces commandes installent les packages Python requis. Le flag '-q' (quiet)
# minimise la sortie pour une meilleure lisibilité dans Colab.
# - 'transformers': Bibliothèque principale pour les modèles pré-entraînés et le Trainer.
# - 'datasets': Pour charger et manipuler les jeux de données (Oxford Flowers 102).
# - 'accelerate': Aide à rendre l'entraînement plus efficace sur différents matériels.
# - 'evaluate': Pour calculer les métriques de performance.
# - 'torchvision': Fournit des utilitaires pour les jeux de données et transformations d'images avec PyTorch.
!pip install -q transformers datasets accelerate evaluate torchvision

# 2. Importation des bibliothèques
# Importe les classes et fonctions essentielles pour le fine-tuning.
import torch # Bibliothèque de calcul tensoriel principale
from datasets import load_dataset, Image # Pour charger les jeux de données d'images
from transformers import AutoImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor # Transformations d'images (non directement utilisées ici mais dans le flow initial)
import evaluate # Pour les métriques d'évaluation
import numpy as np # Pour les opérations numériques (utilisé dans compute_metrics)
from PIL import Image as PILImage # Pour manipuler les images (utilisé pour convertir en RGB)

print("Bibliothèques importées avec succès.")

# 3. Chargement du jeu de données
# Le jeu de données 'dpdl-benchmark/oxford_flowers102' est téléchargé depuis le
# Hugging Face Hub. Il contient des images de 102 espèces de fleurs différentes.
# Le jeu de données est automatiquement divisé en 'train' (entraînement), 'test' (test), et 'validation'.
print("Chargement du jeu de données oxford_flowers102...")
dataset = load_dataset("dpdl-benchmark/oxford_flowers102")

# Sélection des sous-ensembles 'train' et 'test' pour l'entraînement et l'évaluation.
train_dataset = dataset["train"]
test_dataset = dataset["test"]

# Création des mappings entre les identifiants numériques des classes et leurs noms.
# C'est crucial pour interpréter les prédictions du modèle.
labels = train_dataset.features["label"].names # Noms des 102 espèces de fleurs
label2id = {label: i for i, label in enumerate(labels)} # Mapping Nom -> ID
id2label = {i: label for i, label in enumerate(labels)} # Mapping ID -> Nom

print(f"Nombre de classes disponibles dans le jeu de données : {len(labels)}")
print(f"Les 5 premières classes (exemples) : {labels[:5]}")
print(f"Taille du jeu de données d'entraînement : {len(train_dataset)} images")
print(f"Taille du jeu de données de test : {len(test_dataset)} images")

# 4. Prétraitement des images
# Cette étape prépare les images pour qu'elles soient compatibles avec le modèle ViT.
# Le modèle ViT attend des images d'une taille spécifique, normalisées d'une certaine manière.
checkpoint = "google/vit-base-patch16-224" # Modèle ViT que nous allons fine-tuner
# AutoImageProcessor charge le processeur d'images associé au modèle pré-entraîné.
# Il gère automatiquement les opérations comme le redimensionnement et la normalisation.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

# Définir la fonction de prétraitement des images appliquée à chaque exemple du jeu de données.
def preprocess_images(examples):
    # Convertit les images Pillow en format RGB pour s'assurer de la cohérence.
    images = [image.convert("RGB") for image in examples["image"]]
    # L'image_processor prend une liste d'images et les prépare (redimensionnement, normalisation, conversion en tenseurs PyTorch).
    # 'return_tensors="pt"' indique de renvoyer des tenseurs PyTorch.
    inputs = image_processor(images, return_tensors="pt")
    # Ajoute les labels (étiquettes) au dictionnaire pour que le Trainer puisse les utiliser.
    inputs["labels"] = examples["label"]
    return inputs

# Appliquer la fonction de prétraitement aux jeux de données d'entraînement et de test.
# 'set_transform' applique cette fonction de manière "paresseuse" (lazy),
# c'est-à-dire que les images sont prétraitées uniquement lorsqu'elles sont accédées.
print("Application des transformations aux jeux de données (création des 'pixel_values')...")
train_dataset.set_transform(preprocess_images)
test_dataset.set_transform(preprocess_images)

# 5. Définition du modèle
# Chargement du modèle Vision Transformer (ViT) pré-entraîné.
# 'ViTForImageClassification' est une tête de classification ajoutée au modèle ViT de base.
print(f"Chargement du modèle pré-entraîné {checkpoint} pour la classification d'images...")
model = ViTForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels), # Définit le nombre de classes pour notre tâche (102 fleurs)
    id2label=id2label,       # Mappe les IDs aux noms de labels
    label2id=label2id,       # Mappe les noms de labels aux IDs
    ignore_mismatched_sizes=True # Permet d'ignorer la taille de la tête de classification
                                 # du modèle pré-entraîné (qui était de 1000 pour ImageNet)
                                 # pour la remplacer par notre tête de 102 classes.
)
# Le message "Some weights of ViTForImageClassification were not initialized..." à l'exécution
# est normal ici, car la tête de classification est réinitialisée pour s'adapter à nos 102 classes.

# 6. Définition des arguments d'entraînement (TrainingArguments)
# Configure les hyperparamètres et les options d'entraînement.
print("Définition des arguments d'entraînement...")
training_args = TrainingArguments(
    output_dir="./flower_classifier", # Répertoire où les checkpoints du modèle et les métriques seront sauvegardés.
    per_device_train_batch_size=16,   # Nombre d'images traitées par GPU (ou CPU) par batch pour l'entraînement.
                                      # Colab avec GPU permet souvent des batch_size plus élevés.
    per_device_eval_batch_size=16,    # Nombre d'images traitées par GPU (ou CPU) par batch pour l'évaluation.
    eval_strategy="epoch",            # La stratégie d'évaluation : ici, évaluation à la fin de chaque époque.
    num_train_epochs=3,               # Nombre total d'époques d'entraînement. Réduit pour une démo rapide.
    logging_steps=10,                 # Fréquence à laquelle les logs d'entraînement sont affichés.
    save_strategy="epoch",            # La stratégie de sauvegarde : ici, sauvegarde du modèle à la fin de chaque époque.
                                      # Doit correspondre à 'eval_strategy' si 'load_best_model_at_end' est True.
    save_total_limit=2,               # Garde les 2 meilleurs checkpoints du modèle pour économiser de l'espace.
    remove_unused_columns=False,      # Important pour conserver la colonne 'image' après le set_transform.
    push_to_hub=False,                # Ne pas pousser le modèle vers Hugging Face Hub.
    load_best_model_at_end=True,      # Charge le meilleur modèle (basé sur 'metric_for_best_model') à la fin de l'entraînement.
    metric_for_best_model="accuracy", # La métrique utilisée pour déterminer le "meilleur" modèle.
    report_to="none",                 # Désactive les intégrations de suivi de l'entraînement (comme Weights & Biases) pour simplifier.
    gradient_accumulation_steps=2,    # Permet de simuler un batch plus grand en accumulant les gradients.
                                      # Si per_device_train_batch_size=16 et gradient_accumulation_steps=2,
                                      # cela simule un batch de (16*2) = 32. Utile si la VRAM est limitée.
    fp16=True,                        # Active l'entraînement en précision mixte (float16) pour accélérer et réduire la consommation de mémoire GPU.
)

# 7. Définition de la fonction de métriques
# Cette fonction est utilisée par le Trainer pour calculer les métriques de performance
# du modèle sur le jeu de données d'évaluation.
metric = evaluate.load("accuracy") # Charge la métrique d' précision (accuracy)

def compute_metrics(eval_pred):
    # eval_pred contient les prédictions (logits) et les vrais labels.
    predictions, labels = eval_pred
    # Convertit les logits en prédictions de classe (l'indice du logit le plus élevé).
    predictions = np.argmax(predictions, axis=1)
    # Calcule la précision.
    return metric.compute(predictions=predictions, references=labels)

# 8. Création du Trainer
# Le Trainer est une classe fournie par Hugging Face pour simplifier l'entraînement
# des modèles Pytorch. Il gère la boucle d'entraînement, l'évaluation, la sauvegarde, etc.
print("Création de l'objet Trainer...")
trainer = Trainer(
    model=model,             # Le modèle à entraîner.
    args=training_args,      # Les arguments d'entraînement définis ci-dessus.
    train_dataset=train_dataset, # Jeu de données pour l'entraînement.
    eval_dataset=test_dataset,   # Jeu de données pour l'évaluation.
    tokenizer=image_processor,   # Le processeur d'images est utilisé comme tokenizer par le Trainer.
    compute_metrics=compute_metrics, # La fonction pour calculer les métriques.
)

# 9. Entraînement du modèle
# Lance le processus de fine-tuning. Cela prendra un certain temps en fonction
# du nombre d'époques et de la performance de votre GPU Colab.
print("Début de l'entraînement du modèle...")
trainer.train()
print("Entraînement terminé.")

# 10. Évaluation du modèle
# Évalue le modèle final sur le jeu de données de test pour obtenir les métriques de performance.
print("Évaluation finale du modèle sur le jeu de données de test...")
metrics = trainer.evaluate(test_dataset)
print(f"Métriques d'évaluation finales : {metrics}")

# 11. Sauvegarde du modèle fine-tuné
# Sauvegarde le modèle (architecture + poids) et son processeur d'images (pour le prétraitement)
# localement. Ces fichiers peuvent être rechargés plus tard pour des inférences.
trainer.save_model("./fine_tuned_flower_model")
image_processor.save_pretrained("./fine_tuned_flower_model")
print("Modèle fine-tuné et son processeur d'images sauvegardés dans './fine_tuned_flower_model'.")

# 12. Exemple d'inférence (utilisation du modèle fine-tuné pour une prédiction)
# Démontre comment utiliser le modèle entraîné sur une nouvelle image.

print("\n--- Exemple d'inférence avec image téléversée ---")

# Importation spécifique pour l'upload de fichiers dans Colab
from google.colab import files
import io # Pour lire les fichiers téléversés

# Demander à l'utilisateur de téléverser une image
print("Veuillez téléverser une image de fleur pour la prédiction.")
uploaded = files.upload()

# Traiter l'image téléversée
for fn in uploaded.keys():
    print(f"Image téléversée : {fn}")
    # Lire l'image en utilisant PIL
    image_bytes = uploaded[fn]
    uploaded_image = PILImage.open(io.BytesIO(image_bytes))

    # Prétraiter l'image pour l'inférence en utilisant le processeur d'images sauvegardé.
    # Assurez-vous que l'image est en RGB
    inputs = image_processor(uploaded_image.convert("RGB"), return_tensors="pt")

    # Déplace les tenseurs vers le même appareil que le modèle (GPU si disponible).
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Effectue la prédiction avec le modèle fine-tuné.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits # Les logits sont les scores bruts prédits par le modèle pour chaque classe.

    # Calculer les probabilités (confiance) en utilisant softmax
    probabilities = torch.softmax(logits, dim=-1)[0]

    # Obtenir l'ID de la classe prédite et son niveau de confiance
    predicted_label_id = torch.argmax(probabilities).item()
    predicted_confidence = probabilities[predicted_label_id].item()
    predicted_label_name = id2label[predicted_label_id] # Nom de la fleur prédite

    print(f"\nPrédiction du modèle : {predicted_label_name}")
    print(f"Niveau de confiance : {predicted_confidence:.2f}") # Afficher la confiance avec 2 décimales

    # Optionnel : Afficher l'image téléversée
    # import matplotlib.pyplot as plt
    # plt.imshow(uploaded_image)
    # plt.title(f"Prédiction: {predicted_label_name} ({predicted_confidence:.2f})")
    # plt.axis('off')
    # plt.show()

print("\n--- Fin de l'exemple d'inférence ---")