import os
import faiss
import numpy as np
import torch
import logging
import pickle
from transformers import AutoTokenizer, AutoModel
from ollama import Client
from string import Template
from nltk.tokenize import sent_tokenize

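# Note: sent_tokenize depends on NLTK's Punkt tokenizer data. If it is missing,
# the first call raises a LookupError naming the resource to fetch, e.g.:
#   import nltk; nltk.download('punkt')
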
# Limit OpenMP to a single thread; loading FAISS and PyTorch together can
# otherwise segfault due to conflicting OpenMP runtimes.
os.environ["OMP_NUM_THREADS"] = "1"

# Configure logging
logging.basicConfig(filename='rag_system.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# =============================
# Step 1: Load and Chunk Documents with Overlap
# =============================

def chunk_text(text, max_tokens=512, overlap=10):
    """
    Split text into smaller chunks with an overlap to preserve context.

    The text is tokenized into sentences, which are then assembled into chunks
    of at most max_tokens whitespace-delimited tokens, carrying roughly
    `overlap` tokens over between consecutive chunks.
    """
    sentences = sent_tokenize(text)
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        tokens = sentence.split()
        token_count = len(tokens)
        # Close the current chunk before it would exceed max_tokens (the extra
        # check on current_chunk avoids emitting an empty chunk when a single
        # sentence is longer than max_tokens).
        if current_length + token_count > max_tokens and current_chunk:
            # Append the current chunk
            chunks.append(" ".join(current_chunk))
            # Create overlap: keep the last few sentences to carry over context
            if overlap > 0:
                overlap_sentences = []
                token_sum = 0
                # Walk backwards over current_chunk until the desired token overlap is reached
                for sent in reversed(current_chunk):
                    token_sum += len(sent.split())
                    overlap_sentences.insert(0, sent)
                    if token_sum >= overlap:
                        break
                current_chunk = overlap_sentences.copy()
                current_length = sum(len(s.split()) for s in current_chunk)
            else:
                current_chunk = []
                current_length = 0
        current_chunk.append(sentence)
        current_length += token_count

    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks

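# Illustrative example (hypothetical input, small limits chosen for clarity):
#   chunk_text("First sentence. Second sentence. Third sentence.",
#              max_tokens=4, overlap=2)
#   -> ['First sentence. Second sentence.', 'Second sentence. Third sentence.']
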
def load_documents(directory='./documents'):
    """
    Load every .txt file in the given directory, chunk it, and return a list
    of dictionaries, one per chunk, each with 'source' (the originating
    filename) and 'text' (the chunk contents).
    """
    documents = []
    for filename in os.listdir(directory):
        if filename.endswith('.txt'):
            file_path = os.path.join(directory, filename)
            try:
                with open(file_path, 'r', encoding='utf-8') as file:
                    content = file.read()  # Optionally, clean up headers or metadata here
                chunks = chunk_text(content)
                for chunk in chunks:
                    documents.append({
                        'source': filename,
                        'text': chunk
                    })
            except Exception as e:
                logging.error(f"Error reading {file_path}: {e}")
    return documents

# =============================
# Step 2: Embedding Model and FAISS Index Setup
# =============================

# Embedding model. The Hugging Face token is read from the environment (the
# variable name HF_TOKEN is chosen here) instead of being hard-coded; for this
# public model it can simply be left unset.
embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
access_token = os.environ.get("HF_TOKEN")

print("Loading embedding model...")
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, token=access_token)
model = AutoModel.from_pretrained(embedding_model_name, token=access_token)

def embed_text(text):
    # Tokenize and embed the text (max_length covers the chunk size)
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    # Compute a simple mean pooling of the last hidden state as the embedding
    embeddings = model_output.last_hidden_state.mean(dim=1).squeeze()
    return embeddings.numpy()

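# Optional variant (a sketch, not used by the pipeline above): mean pooling
# weighted by the attention mask, so padding tokens do not dilute the embedding.
# For single, unpadded inputs as used here it gives the same result as the plain
# mean; the helper name embed_text_masked is illustrative.
def embed_text_masked(text):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    mask = inputs['attention_mask'].unsqueeze(-1).float()        # (batch, seq_len, 1)
    summed = (model_output.last_hidden_state * mask).sum(dim=1)  # sum over real tokens only
    counts = mask.sum(dim=1).clamp(min=1e-9)                     # number of real tokens
    return (summed / counts).squeeze().numpy()
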
# Filenames for persistence
INDEX_FILE = "faiss_index_hnsw.bin"
MAPPING_FILE = "documents_mapping.pkl"

# Check if both the index and mapping exist.
if os.path.exists(INDEX_FILE) and os.path.exists(MAPPING_FILE):
    # Load persisted FAISS index and mapping
    index = faiss.read_index(INDEX_FILE)
    with open(MAPPING_FILE, "rb") as f:
        documents = pickle.load(f)
    print(f"Loaded FAISS index with {index.ntotal} vectors and mapping with {len(documents)} entries.")
else:
    # Load documents from the directory
    documents = load_documents()
    print(f"Loaded {len(documents)} document chunks.")

    # Generate embeddings for each document chunk
    print("Generating embeddings for document chunks...")
    document_embeddings = []
    for doc in documents:
        emb = embed_text(doc['text'])
        document_embeddings.append(emb)
    document_embeddings = np.array(document_embeddings, dtype='float32')

    # Create FAISS HNSW index
    dimension = document_embeddings.shape[1]
    index = faiss.IndexHNSWFlat(dimension, 32)
    index.add(document_embeddings)

    # Save the FAISS index to disk
    faiss.write_index(index, INDEX_FILE)
    # Persist the mapping using pickle
    with open(MAPPING_FILE, "wb") as f:
        pickle.dump(documents, f)
    print(f"FAISS HNSW index created with {index.ntotal} vectors and mapping saved with {len(documents)} entries.")

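# Note: IndexHNSWFlat ranks neighbours by L2 distance, and the embeddings above
# are not normalized, so this is not cosine similarity. If cosine-style ranking
# is preferred, one option (a sketch, not applied here) is to normalize vectors
# before indexing and before querying:
#   faiss.normalize_L2(document_embeddings)   # in-place, rows become unit length
# after which L2 ordering matches cosine-similarity ordering.
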
# =============================
# Step 3: Enhanced Retriever
# =============================

class Retriever:
    def __init__(self, index, embed_func, documents):
        self.index = index
        self.embed_func = embed_func
        self.documents = documents

    def retrieve(self, query, k=3):
        try:
            query_embedding = self.embed_func(query)
            distances, indices = self.index.search(np.array([query_embedding], dtype='float32'), k)
            retrieved_docs = []
            for i in indices[0]:
                # FAISS pads missing neighbours with -1, so check both bounds
                if 0 <= i < len(self.documents):
                    retrieved_docs.append(self.documents[i])
            logging.info(f"Query: {query} | Retrieved {len(retrieved_docs)} docs")
            logging.info(f"Docs: {retrieved_docs}")
            return retrieved_docs
        except Exception as e:
            logging.error(f"Error retrieving documents: {e}")
            return []

retriever = Retriever(index, embed_text, documents)

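# Example (illustrative query string): retriever.retrieve("What does the report conclude?", k=3)
# returns up to three {'source': ..., 'text': ...} dictionaries, nearest first.
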
# =============================
# Step 4: LLM Integration (Llama-3.2-3B)
# =============================

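# Assumes a local Ollama server is running at its default address
# (http://localhost:11434) and that the model has been pulled beforehand,
# e.g. with `ollama pull llama3.2:3b`.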
print("Initializing Llama3.2:3B model...")
|
||
try:
|
||
llm = Client()
|
||
except Exception as e:
|
||
logging.error(f"Error initializing Ollama Client: {e}")
|
||
exit(1)
|
||
|
||
prompt_template = Template("""
You are an AI assistant that provides answers using the given context.

Context:
$context

Question:
$question

If the context does not contain an answer, clearly state "I don’t know."
Provide a well-structured response.

Answer:
""")

def answer_query(question):
    # Retrieve document chunks for the query
    context_chunks = retriever.retrieve(question)
    if not context_chunks:
        return "I'm sorry, I couldn't find relevant information."

    # Include the source file names in the context (also useful for debugging)
    combined_context = "\n\n".join([f"From {doc['source']}:\n{doc['text']}" for doc in context_chunks])
    prompt = prompt_template.substitute(context=combined_context, question=question)

    try:
        response = llm.generate(prompt=prompt, model="llama3.2:3b")
        generated_text = getattr(response, 'response', "I'm sorry, I couldn't generate a response.").strip()
        return generated_text
    except Exception as e:
        logging.error(f"Error generating response: {e}")
        return "I'm sorry, I couldn't generate a response at this time."

# =============================
# Step 5: Run Enhanced RAG System
# =============================

if __name__ == "__main__":
    print("\n=== Enhanced RAG System ===")
    print("Type 'exit', 'quit', or 'bye' to terminate.\n")
    while True:
        try:
            user_question = input("\n\n----------------\n Enter your question: ")
            if user_question.lower() in ['exit', 'quit', 'bye']:
                print("Exiting. Goodbye!")
                break
            answer = answer_query(user_question)
            print("\n\nAnswer:", answer, "\n")
        except KeyboardInterrupt:
            print("\nExiting. Goodbye!")
            break
        except Exception as e:
            logging.error(f"Unexpected error: {e}")
            print("An unexpected error occurred.")