# src/train.py
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from dataset import get_dataloaders
from model import ColonCancerModel
from utils import plot_confusion_matrix
from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt

# Compute device, chosen once at import time: CUDA when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Absolute project paths, anchored one level above the directory containing
# this file so they do not depend on the current working directory.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
data_dir = os.path.join(project_root, 'data')
results_dir = os.path.join(project_root, 'results', 'metrics')
model_dir = os.path.join(project_root, 'models')

# Output directories are created eagerly at import; exist_ok makes reruns a no-op.
os.makedirs(results_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    If ``optimizer`` is given, the model is put in train mode and its weights
    are updated after every batch. Otherwise the model is put in eval mode and
    gradients are disabled (validation).

    Labels are cast to float and given a trailing dimension to match the
    logits for ``BCEWithLogitsLoss``. This presumes the model emits one logit
    per sample, i.e. outputs shaped (N, 1); confirm against ``model.py``.
    A sample counts as correct when ``sigmoid(logit) > 0.5`` equals its label.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device).float().unsqueeze(1)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # Weight the per-batch mean loss by batch size so the epoch mean is per-sample.
            running_loss += loss.item() * inputs.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return running_loss / total, correct / total


def _plot_accuracy(history, save_path):
    """Save the per-epoch train/validation accuracy curves to ``save_path``.

    Epochs are numbered from 1 on the x-axis to match the console log.
    """
    epochs = range(1, len(history['train_acc']) + 1)
    fig = plt.figure()
    try:
        plt.plot(epochs, history['train_acc'], label="Train Accuracy")
        plt.plot(epochs, history['val_acc'], label="Val Accuracy")
        plt.legend()
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.title("Training and Validation Accuracy")
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def _predict(model, loader):
    """Return ``(labels, predictions)`` as flat lists of 0/1 ints for ``loader``."""
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            # Use reshape(-1), not squeeze(). squeeze() collapses a size-1 final
            # batch into a 0-d array, which list.extend() cannot iterate.
            all_preds.extend((torch.sigmoid(outputs) > 0.5).reshape(-1).int().tolist())
            all_labels.extend(labels.reshape(-1).int().tolist())
    return all_labels, all_preds


def train_model(data_dir, num_epochs=10, batch_size=8, lr=1e-4):
    """Train ``ColonCancerModel`` as a binary classifier, then evaluate it.

    Each epoch trains on the training split and scores the validation split.
    Whenever validation accuracy improves, the weights are written to
    ``<model_dir>/mobilenet_model.pth``. After training:

    - the best-validation weights are restored;
    - train/val accuracy curves are saved to ``<results_dir>/accuracy.png``;
    - a test-set confusion matrix is saved to
      ``<results_dir>/confusion_matrix.png``.

    Args:
        data_dir: Dataset root forwarded to ``get_dataloaders``.
        num_epochs: Number of training epochs.
        batch_size: Batch size forwarded to ``get_dataloaders``.
        lr: Adam learning rate.

    Returns:
        ``(model, history)``. ``model`` holds the best-validation weights
        (untrained weights if ``num_epochs`` is 0). ``history`` maps
        ``train_loss``, ``val_loss``, ``train_acc`` and ``val_acc`` to
        per-epoch lists.
    """
    train_loader, val_loader, test_loader = get_dataloaders(data_dir, batch_size=batch_size)

    model = ColonCancerModel().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    checkpoint_path = os.path.join(model_dir, 'mobilenet_model.pth')

    # Start below any achievable accuracy so the first epoch always checkpoints.
    # With 0.0 and a strict '>', a model stuck at 0% was never saved.
    best_val_acc = float('-inf')
    best_state = None
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, optimizer)
        epoch_val_loss, epoch_val_acc = _run_epoch(model, val_loader, criterion)
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)
        history['val_loss'].append(epoch_val_loss)
        history['val_acc'].append(epoch_val_acc)

        print(f"Epoch {epoch+1}/{num_epochs}: Train loss {epoch_loss:.4f}, Train acc {epoch_acc:.4f}, "
              f"Val loss {epoch_val_loss:.4f}, Val acc {epoch_val_acc:.4f}")

        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            torch.save(model.state_dict(), checkpoint_path)
            # Keep a detached in-memory copy so the best weights can be restored
            # after training; later epochs keep mutating the live parameters.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # Evaluate (and return) the checkpointed best-validation weights rather than
    # whatever the final epoch happened to produce.
    if best_state is not None:
        model.load_state_dict(best_state)

    _plot_accuracy(history, os.path.join(results_dir, 'accuracy.png'))

    all_labels, all_preds = _predict(model, test_loader)
    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent from the test
    # split, so it always lines up with the two class names below.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    # NOTE(review): assumes label 0 = benign and 1 = adenocarcinoma; confirm in dataset.py.
    plot_confusion_matrix(cm, classes=['benign', 'adenocarcinoma'], save_path=os.path.join(results_dir, 'confusion_matrix.png'))

    return model, history

if __name__ == "__main__":
    # Script entry point: train on the project's data/ directory with the default hyperparameters.
    model, history = train_model(data_dir=data_dir)
