import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------
# CONFIG
# -----------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# --- data / checkpoint paths ---
documents_folder = "documents"  # directory scanned for *.txt training files
save_path = "chat_docs_model.pth"  # destination of the trained state_dict

# --- batching ---
batch_size = 32  # sequences per optimizer step
block_size = 128  # context length (number of positional embeddings)

# --- model size ---
n_embd = 128  # embedding / hidden width
n_head = 4  # attention heads per layer
n_layer = 4  # number of transformer layers
dropout = 0.2

# --- optimization ---
learning_rate = 0.0003
max_iters = 2_000

# -----------------------------
# 1️⃣ Combine all documents
# -----------------------------
all_text = ""
for filename in os.listdir(documents_folder):
    if filename.endswith(".txt"):
        with open(os.path.join(documents_folder, filename), "r", encoding="utf-8") as f:
            all_text += f.read() + "\n"

with open("training_text.txt", "w", encoding="utf-8") as f:
    f.write(all_text)

print(f"Combined {len(os.listdir(documents_folder))} files into training_text.txt")

# -----------------------------
# 2️⃣ Tokenization
# -----------------------------
text = all_text

# Character-level vocabulary: each distinct character of the corpus, sorted so
# that the character -> id assignment is deterministic.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = dict(zip(chars, range(vocab_size)))
itos = dict(enumerate(chars))


def encode(s):
    """Return the token ids for string ``s``; KeyError on characters outside ``chars``."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the token ids in ``l``."""
    return "".join(itos[tok] for tok in l)


data = torch.tensor(encode(text), dtype=torch.long)

# First 90% of the tokens for training, the remaining 10% held out.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample ``batch_size`` random windows of ``block_size`` tokens.

    Args:
        split: ``"train"`` draws from ``train_data``; any other value draws
            from ``val_data``.

    Returns:
        ``(x, y)`` moved to ``device``, each of shape
        ``(batch_size, block_size)``. ``y`` is the window starting one token
        later than ``x``, i.e. the next-token targets.
    """
    source = val_data if split != "train" else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s : s + block_size] for s in starts])
    targets = torch.stack([source[s + 1 : s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)

# -----------------------------
# 3️⃣ Mini GPT Model
# -----------------------------
class GPT(nn.Module):
    """Minimal character-level, decoder-only (GPT-style) transformer.

    Hyperparameters are read from the module-level config: ``vocab_size``,
    ``block_size``, ``n_embd``, ``n_head``, ``n_layer`` and ``dropout``.

    Each layer is an ``nn.TransformerEncoderLayer`` run with a causal
    attention mask, so position ``t`` only attends to positions ``<= t``.
    Without that mask the model sees the very token it is trained to predict.
    """

    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        # batch_first=True because forward() feeds (B, T, C) tensors. With the
        # PyTorch default (batch_first=False) dim 0 is treated as the sequence
        # axis, so attention would mix unrelated samples of the batch.
        # A ModuleList (rather than Sequential) lets forward() hand the causal
        # mask to every layer; parameter names ("blocks.0. ...") are unchanged.
        self.blocks = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=n_embd,
                    nhead=n_head,
                    dim_feedforward=4 * n_embd,
                    dropout=dropout,
                    batch_first=True,
                )
                for _ in range(n_layer)
            ]
        )
        self.ln = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if ``targets`` is given, the loss.

        Args:
            idx: integer token ids of shape (B, T) with T <= block_size.
            targets: optional integer ids of shape (B, T), compared
                position-wise against the logits.

        Returns:
            ``(logits, loss)``: logits of shape (B, T, vocab_size); loss is the
            mean cross-entropy over all B*T positions, or None without targets.

        Raises:
            ValueError: if T exceeds block_size (no positional embedding exists).
        """
        B, T = idx.shape
        if T > block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {block_size}")
        positions = torch.arange(T, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)
        # Additive float mask: -inf strictly above the diagonal removes
        # attention to future positions; 0 elsewhere leaves scores unchanged.
        causal_mask = torch.triu(
            torch.full((T, T), float("-inf"), device=idx.device, dtype=x.dtype),
            diagonal=1,
        )
        for layer in self.blocks:
            x = layer(x, src_mask=causal_mask)
        x = self.ln(x)
        logits = self.head(x)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` and return it.

        Each step feeds at most the last ``block_size`` tokens, samples the
        next token from the softmax of the final position's logits, and
        concatenates it. Runs under ``torch.no_grad()`` so no autograd graph
        accumulates across steps. The caller chooses train/eval mode.

        Args:
            idx: integer token ids of shape (B, T) with T >= 1.
            max_new_tokens: number of tokens to sample.

        Returns:
            Tensor of shape (B, T + max_new_tokens).
        """
        with torch.no_grad():
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -block_size:]  # crop to the positional range
                logits, _ = self(idx_cond)
                probs = F.softmax(logits[:, -1, :], dim=-1)
                next_idx = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_idx), dim=1)
        return idx

model = GPT()
# nn.Module.to moves the parameters in place and returns the module itself.
model = model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

# -----------------------------
# 4️⃣ Training Loop
# -----------------------------
print("Starting training...")
model.train()  # explicit: dropout must be active while training
# `step` rather than `iter`, which would shadow the builtin for the whole module.
for step in range(max_iters):
    xb, yb = get_batch("train")
    _, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # Log every 200 steps and always the final one.
    if step % 200 == 0 or step == max_iters - 1:
        print(f"Step {step}, Loss {loss.item():.4f}")

torch.save(model.state_dict(), save_path)
print(f"Training finished! Model saved to {save_path}")

# -----------------------------
# 5️⃣ Chat with memory
# -----------------------------
model.eval()  # disable dropout for generation
conversation_history = ""

print("\nYou can now chat with your model! Type 'exit' to quit.")
while True:
    try:
        prompt = input("You: ")
    except EOFError:  # Ctrl-D / closed stdin ends the session instead of crashing
        break
    if prompt.strip().lower() == "exit":
        break
    # encode() raises KeyError for any character absent from the training text
    # (possibly including those of "User:"/"Bot:"), so keep only in-vocab ones.
    turn = "User: " + prompt + "\nBot: "
    conversation_history += "".join(ch for ch in turn if ch in stoi)
    context = torch.tensor([encode(conversation_history)], dtype=torch.long, device=device)
    with torch.no_grad():  # inference only; don't build an autograd graph
        output = model.generate(context, 200)
    # Keep only the newly generated tokens (everything after the context).
    response = decode(output[0, context.shape[1]:].tolist())
    # The model may run past its own reply and invent the user's next turn;
    # keep only the bot's part so it does not pollute the conversation memory.
    response = response.split("\nUser:", 1)[0]
    print("Bot:", response.strip())
    conversation_history += response + "\n"