import json
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from nltk.tokenize import word_tokenize
from nltk.stem.porter import PorterStemmer
import nltk

# Make sure the NLTK tokenizer data that word_tokenize relies on is present.
# Each package is downloaded only when nltk.data.find cannot locate it.
for _resource_path, _package in (
    ('tokenizers/punkt', 'punkt'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package)
del _resource_path, _package

# Shared Porter stemmer instance used by stem() below.
stemmer = PorterStemmer()

# -------------------------------
# Load your FAQ data
# -------------------------------
# 'courses.json' is resolved relative to the current working directory.
# JSON text is UTF-8; pass the encoding explicitly so non-ASCII FAQ content
# does not depend on the platform's default locale encoding.
with open('courses.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Every later step iterates `data` and indexes item['question'] /
# item['answer'], so fail early with a clear message when the file is not a
# non-empty JSON array (otherwise training crashes later with a confusing
# shape/indexing error).
if not isinstance(data, list) or not data:
    raise ValueError("courses.json must contain a non-empty JSON array of FAQ entries")

# -------------------------------
# Preprocessing
# -------------------------------
def tokenize(sentence):
    """Split *sentence* into tokens with NLTK's ``word_tokenize``.

    Punctuation comes back as separate tokens; callers filter it as needed.
    """
    tokens = word_tokenize(sentence)
    return tokens

def stem(word):
    """Return the Porter stem of *word*, lower-casing it first so matching is case-insensitive."""
    lowered = word.lower()
    return stemmer.stem(lowered)

# (tokens, answer) pairs, one per FAQ entry, in file order.
xy = [(tokenize(item['question']), item['answer']) for item in data]

# Vocabulary: sorted, de-duplicated stems of the alphanumeric question tokens.
# Punctuation tokens (non-alphanumeric) are excluded.
all_words = sorted({
    stem(token)
    for tokens, _ in xy
    for token in tokens
    if token.isalnum()
})

# One answer per FAQ entry, in file order; duplicate answers are kept.
answers = [item['answer'] for item in data]

def bag_of_words(tokenized_sentence, all_words):
    """Encode a tokenized sentence as a binary bag-of-words vector.

    Args:
        tokenized_sentence: Iterable of raw tokens (e.g. the output of
            ``tokenize``); each is passed through ``stem`` before matching.
        all_words: Ordered vocabulary of stems. Position ``i`` of the result
            corresponds to ``all_words[i]``.

    Returns:
        A float32 ``np.ndarray`` of length ``len(all_words)`` holding 1.0
        where that vocabulary stem occurs in the sentence and 0.0 elsewhere.
    """
    # Stem once into a set so each vocabulary lookup is O(1) instead of a
    # linear scan over the sentence (O(S + V) overall rather than O(S * V)).
    sentence_stems = {stem(token) for token in tokenized_sentence}
    bag = np.zeros(len(all_words), dtype=np.float32)
    for idx, vocab_stem in enumerate(all_words):
        if vocab_stem in sentence_stems:
            bag[idx] = 1.0
    return bag

# Feature matrix: one bag-of-words row per FAQ question, in xy order.
X_train = np.array([bag_of_words(tokens, all_words) for tokens, _ in xy])

# Class labels: list.index returns the FIRST position of an equal answer,
# so questions that share an answer text share one label.
y_train = np.array([answers.index(ans) for _, ans in xy])

# -------------------------------
# Define Model
# -------------------------------
class ChatDataset(nn.Module):
    """Two-layer feed-forward classifier: Linear -> ReLU -> Linear.

    NOTE(review): despite the name this is an ``nn.Module`` model, not a
    dataset. The name is left as-is because other code may reference it.

    ``forward`` returns raw logits (no softmax), which is what
    ``nn.CrossEntropyLoss`` expects. In this script ``input_size`` is the
    vocabulary size and ``output_size`` is ``len(answers)``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names (fc1/relu/fc2) define the state_dict keys; keep them
        # stable so saved checkpoints remain loadable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# One input feature per vocabulary stem, one output logit per entry in `answers`.
hidden_size = 8
input_size, output_size = len(all_words), len(answers)

model = ChatDataset(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)

# -------------------------------
# Training
# -------------------------------
X_train_tensor = torch.from_numpy(X_train).float()
# CrossEntropyLoss requires int64 class indices; .long() also normalizes the
# platform-dependent default integer dtype numpy may have produced.
y_train_tensor = torch.from_numpy(y_train).long()

# Hyperparameters kept in one place so the loop bound and the progress log
# cannot drift apart.
num_epochs = 1000
log_every = 100
learning_rate = 0.001

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Each epoch is a single full-batch step over every training example.
for epoch in range(num_epochs):
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % log_every == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# -------------------------------
# Save model and metadata
# -------------------------------
FILE = "course_chatbot.pth"

# Persist the weights together with everything needed to rebuild the model
# and decode its predictions (vocabulary order and the answer list).
torch.save(
    dict(
        model_state=model.state_dict(),
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        all_words=all_words,
        answers=answers,
    ),
    FILE,
)

print(f"Training complete. Model saved to {FILE}")