import torch
import torch.nn as nn
from nltk.tokenize import word_tokenize
from nltk.stem.porter import PorterStemmer
import numpy as np

# Single shared Porter stemmer instance, reused by stem() for every word.
stemmer = PorterStemmer()

# -------------------------------
# Preprocessing functions
# -------------------------------
def tokenize(sentence):
    """Split *sentence* into a list of word and punctuation tokens.

    Thin wrapper around NLTK's ``word_tokenize``.
    """
    tokens = word_tokenize(sentence)
    return tokens

def stem(word):
    """Return the Porter stem of *word* after lower-casing it.

    Uses the module-level ``stemmer`` instance.
    """
    lowered = word.lower()
    return stemmer.stem(lowered)

def bag_of_words(tokenized_sentence, all_words):
    """Encode a tokenized sentence as a binary bag-of-words vector.

    Every token is passed through ``stem``; position ``i`` of the result is
    set to 1.0 when ``all_words[i]`` equals one of those stems, else 0.0.

    Args:
        tokenized_sentence: iterable of token strings (e.g. from ``tokenize``).
        all_words: ordered vocabulary; its order defines the vector layout.
            NOTE(review): entries are compared directly against stemmed
            tokens, so this presumably holds stemmed words — confirm against
            the training script.

    Returns:
        A 1-D ``np.ndarray`` of dtype float32 and length ``len(all_words)``.
    """
    # A set gives O(1) membership tests; the previous list made the loop
    # below O(len(all_words) * len(tokens)).
    sentence_words = {stem(w) for w in tokenized_sentence}
    bag = np.zeros(len(all_words), dtype=np.float32)
    for idx, w in enumerate(all_words):
        if w in sentence_words:
            bag[idx] = 1.0
    return bag

# -------------------------------
# Load model and metadata
# -------------------------------
FILE = "course_chatbot.pth"
# map_location="cpu": the model below is built on the CPU, so load the tensors
# there too. This also lets a checkpoint saved on a GPU machine load on a
# CPU-only one (without it torch.load fails trying to restore CUDA tensors).
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
data = torch.load(FILE, map_location="cpu")

# Metadata saved alongside the weights.
all_words = data["all_words"]  # vocabulary passed to bag_of_words
answers = data["answers"]  # indexed by the model's predicted class index
input_size = data["input_size"]
hidden_size = data["hidden_size"]
output_size = data["output_size"]

class ChatDataset(nn.Module):
    """Two-layer feed-forward classifier: Linear -> ReLU -> Linear.

    Despite its name this is the network, not a dataset; the class name is
    left unchanged to preserve the existing interface. The attribute names
    ``fc1``, ``relu`` and ``fc2`` must also stay as they are: they determine
    the state-dict keys (``fc1.weight``, ``fc2.bias``, ...) that
    ``load_state_dict`` matches against the saved checkpoint.

    Args:
        input_size: length of the input feature vector.
        hidden_size: width of the hidden layer.
        output_size: number of output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return unnormalized class scores (logits) for ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# Rebuild the network with the dimensions stored in the checkpoint, restore
# its trained weights, then switch it to evaluation mode for inference.
model = ChatDataset(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)
model.load_state_dict(data["model_state"])
model.eval()

# -------------------------------
# Chat loop with confidence
# -------------------------------
# Minimum softmax probability the top class needs before its answer is shown.
CONFIDENCE_THRESHOLD = 0.7
# Reply used when the model is not confident enough.
FALLBACK_ANSWER = "I’m not sure about that. Please visit 8149996597 for help."

print("Bot is ready! Type 'quit' to exit.")

while True:
    try:
        sentence = input("You: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / exhausted piped stdin, or Ctrl-C: leave the loop cleanly
        # instead of dying with a traceback.
        print()
        break

    # Tolerate surrounding whitespace and any capitalization of "quit".
    command = sentence.strip().lower()
    if command == "quit":
        break
    if not command:
        # Blank line: nothing to classify, just prompt again.
        continue

    X = bag_of_words(tokenize(sentence), all_words)
    X_tensor = torch.from_numpy(X).float()

    # Pure inference: no gradients needed, so skip autograd bookkeeping.
    with torch.no_grad():
        output = model(X_tensor)

    # X_tensor is a single 1-D (unbatched) sample, so classes lie on dim 0.
    probabilities = torch.softmax(output, dim=0)
    confidence, predicted_index = torch.max(probabilities, dim=0)

    # Only answer when the model is sufficiently sure; otherwise fall back.
    if confidence.item() > CONFIDENCE_THRESHOLD:
        answer = answers[predicted_index.item()]
    else:
        answer = FALLBACK_ANSWER

    print("Bot:", answer)