import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaLLM
from langchain.chains import RetrievalQA

# Step 1: Load PDF into documents (chunking happens in create_vectorstore)
def load_pdf(path):
    """Load a PDF file into LangChain documents.

    Args:
        path: Filesystem path of the PDF to read.

    Returns:
        The documents produced by ``PyPDFLoader.load()``. They are not chunked
        here; ``create_vectorstore`` applies its own text splitter.
    """
    # ``load()`` rather than ``load_and_split()``: the latter ran an extra,
    # default-sized split that create_vectorstore immediately re-split, and
    # LangChain documents load_and_split as deprecated.
    loader = PyPDFLoader(path)
    return loader.load()

# Step 2: Create FAISS vector store from documents
def create_vectorstore(
    docs,
    chunk_size=500,
    chunk_overlap=50,
    model_name="sentence-transformers/all-MiniLM-L6-v2",
):
    """Split *docs* into chunks, embed them, and index them in a FAISS store.

    Args:
        docs: LangChain documents to index.
        chunk_size: Maximum chunk size passed to the text splitter.
        chunk_overlap: Overlap between adjacent chunks passed to the splitter.
        model_name: HuggingFace embedding model identifier.

    Returns:
        A FAISS vector store built from the chunks.

    Raises:
        ValueError: If splitting yields no chunks (e.g. an image-only PDF with
            no extractable text), rather than handing FAISS an empty list.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    texts = splitter.split_documents(docs)
    # Checked before loading the embedding model so the failure is cheap and clear.
    if not texts:
        raise ValueError("No extractable text found in the document.")
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return FAISS.from_documents(texts, embedding=embeddings)

# Step 3: Load Ollama LLM
def load_llm(model="llama3"):
    """Return an ``OllamaLLM`` client for *model* (defaults to ``"llama3"``)."""
    return OllamaLLM(model=model)

# Step 4: Build Retrieval QA Chain
def create_qa_chain(llm, vectordb, k=3):
    """Build a "stuff"-type RetrievalQA chain over *vectordb*.

    Args:
        llm: Language model that generates answers from the retrieved context.
        vectordb: Vector store exposing ``as_retriever``.
        k: Number of chunks retrieved per question; with the "stuff" chain
            type they are all placed into a single prompt.

    Returns:
        The configured ``RetrievalQA`` chain.
    """
    retriever = vectordb.as_retriever(search_kwargs={"k": k})
    return RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)

# Handle PDF processing
def process_pdf(file_path):
    """Gradio callback: build a QA chain from the uploaded PDF.

    Args:
        file_path: Value of the ``gr.File`` component; ``None`` when the user
            clicks "Process" without uploading anything.

    Returns:
        A ``(status_message, qa_chain)`` tuple. ``qa_chain`` is ``None`` when
        no file was given or processing failed.
    """
    # Without this guard, PyPDFLoader(None) fails with an unhelpful message.
    if file_path is None:
        return "⚠️ Please upload a PDF first.", None
    try:
        docs = load_pdf(file_path)
        vectordb = create_vectorstore(docs)
        llm = load_llm()
        qa_chain = create_qa_chain(llm, vectordb)
    except Exception as e:
        # UI boundary: report any failure in the status box instead of crashing.
        return f"❌ Error while processing PDF: {e}", None
    return "✅ PDF processed. Ready to ask questions.", qa_chain

# Handle Q&A interaction
def answer_question(question, qa_chain):
    """Gradio callback: answer *question* with the chain stored in session state.

    Args:
        question: Text from the question box; ``None`` or whitespace-only
            input is rejected with a warning.
        qa_chain: Chain produced by ``process_pdf``; ``None`` until a PDF has
            been processed successfully.

    Returns:
        The answer, or a user-facing warning/error message.
    """
    if qa_chain is None:
        return "⚠️ Please upload and process a PDF first."
    # ``not question`` also covers None, which would crash on ``.strip()``.
    if not question or not question.strip():
        return "⚠️ Please enter a question."

    try:
        result = qa_chain.invoke({"query": question})
        # RetrievalQA returns a dict with the answer under "result"; pass
        # anything else through unchanged.
        return result["result"] if isinstance(result, dict) else result
    except Exception as e:
        # UI boundary: show the failure in the answer box rather than raising.
        return f"❌ Error while answering: {e}"

# Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 📄 Ask Questions About Your PDF")
    gr.Markdown("Built with **Ollama + LLaMA 3 + HuggingFace + FAISS**")

    # Session state holding the chain returned by process_pdf (None until set).
    qa_chain_state = gr.State()

    pdf_input = gr.File(label="📎 Upload a PDF", file_types=[".pdf"])
    upload_button = gr.Button("📤 Process PDF")
    status_output = gr.Textbox(label="Status", lines=2)

    user_question = gr.Textbox(label="💬 Ask a question about the PDF")
    ask_button = gr.Button("🤖 Get Answer")
    answer_output = gr.Textbox(label="📘 Answer", lines=4)

    upload_button.click(fn=process_pdf, inputs=pdf_input, outputs=[status_output, qa_chain_state])

    # Clicking the button or pressing Enter in the question box both ask.
    ask_event = dict(
        fn=answer_question,
        inputs=[user_question, qa_chain_state],
        outputs=answer_output,
    )
    ask_button.click(**ask_event)
    user_question.submit(**ask_event)

if __name__ == "__main__":
    # Launch only when run as a script so importing this module has no side effects.
    demo.launch()
