intelaide/intelaide-backend/python/rag_mapping.py
import os

# Prevent segfaults: limit OpenMP to a single thread. This must be set before
# faiss/numpy/torch are imported, or the OpenMP runtime may ignore it.
os.environ["OMP_NUM_THREADS"] = "1"

import faiss
import numpy as np
import torch
import logging
import pickle
from transformers import AutoTokenizer, AutoModel
from ollama import Client
from string import Template
from nltk.tokenize import sent_tokenize

# Configure logging
logging.basicConfig(filename='rag_system.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
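
# sent_tokenize (used in chunk_text below) relies on NLTK's Punkt tokenizer data,
# which is not bundled with the library. A minimal sketch of a first-run guard,
# not part of the original script (recent NLTK releases name the resource
# 'punkt_tab' rather than 'punkt'):
import nltk
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')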
# =============================
# Step 1: Load and Chunk Documents with Overlap
# =============================
def chunk_text(text, max_tokens=512, overlap=10):
"""
Split text into smaller chunks with an overlap to preserve context.
The function tokenizes by sentences and then assembles them into chunks
not exceeding max_tokens, but with a specified number of overlapping tokens
between consecutive chunks.
"""
sentences = sent_tokenize(text)
chunks = []
current_chunk = []
current_length = 0
for sentence in sentences:
tokens = sentence.split()
token_count = len(tokens)
if current_length + token_count > max_tokens:
# Append the current chunk
chunk = " ".join(current_chunk)
chunks.append(chunk)
# Create overlap: keep the last few sentences to carry over context
if overlap > 0:
overlap_sentences = []
token_sum = 0
# Iterate backwards over current_chunk until we reach the desired token overlap
for sent in reversed(current_chunk):
sent_tokens = sent.split()
token_sum += len(sent_tokens)
overlap_sentences.insert(0, sent)
if token_sum >= overlap:
break
current_chunk = overlap_sentences.copy()
current_length = sum(len(s.split()) for s in current_chunk)
else:
current_chunk = []
current_length = 0
current_chunk.append(sentence)
current_length += token_count
if current_chunk:
chunks.append(" ".join(current_chunk))
return chunks
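
# Illustrative use (hypothetical input file): each returned chunk is a plain string,
# and consecutive chunks carry over at least `overlap` whitespace tokens' worth of
# the previous chunk's trailing sentences, e.g.:
#   chunks = chunk_text(open('documents/report.txt').read(), max_tokens=256, overlap=20)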
def load_documents(directory='./documents'):
"""
Loads all .txt files from the specified directory.
For each file, the function returns a dictionary with keys 'source' and 'text'
for each chunk.
"""
documents = []
for filename in os.listdir(directory):
if filename.endswith('.txt'):
file_path = os.path.join(directory, filename)
try:
with open(file_path, 'r', encoding='utf-8') as file:
content = file.read() # Optionally, you could clean up headers or metadata here
chunks = chunk_text(content)
for chunk in chunks:
documents.append({
'source': filename,
'text': chunk
})
except Exception as e:
logging.error(f"Error reading {file_path}: {e}")
return documents
# =============================
# Step 2: Embedding Model and FAISS Index Setup
# =============================
# Updated embedding model name and token
embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
access_token = "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"
print("Loading embedding model...")
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, token=access_token)
model = AutoModel.from_pretrained(embedding_model_name, token=access_token)
def embed_text(text):
    # Tokenize and embed the text (max_length covers the chunk size)
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    # Compute a simple mean pooling of the last hidden state as the embedding
    embeddings = model_output.last_hidden_state.mean(dim=1).squeeze()
    return embeddings.numpy()
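
# For all-mpnet-base-v2 the mean-pooled embedding is a 768-dimensional float32
# vector (assuming the stock model config). Because a single string is embedded
# per call, padding=True adds no pad tokens, so the plain mean over
# last_hidden_state is not skewed by padding positions.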
# Filenames for persistence
INDEX_FILE = "faiss_index_hnsw.bin"
MAPPING_FILE = "documents_mapping.pkl"
# Check if both the index and mapping exist.
if os.path.exists(INDEX_FILE) and os.path.exists(MAPPING_FILE):
    # Load persisted FAISS index and mapping
    index = faiss.read_index(INDEX_FILE)
    with open(MAPPING_FILE, "rb") as f:
        documents = pickle.load(f)
    print(f"Loaded FAISS index with {index.ntotal} vectors and mapping with {len(documents)} entries.")
else:
    # Load documents from the directory
    documents = load_documents()
    print(f"Loaded {len(documents)} document chunks.")
    # Generate embeddings for each document chunk
    print("Generating embeddings for document chunks...")
    document_embeddings = []
    for doc in documents:
        emb = embed_text(doc['text'])
        document_embeddings.append(emb)
    document_embeddings = np.array(document_embeddings, dtype='float32')
    # Create FAISS HNSW index
    dimension = document_embeddings.shape[1]
    index = faiss.IndexHNSWFlat(dimension, 32)
    index.add(document_embeddings)
    # Save the FAISS index to disk
    faiss.write_index(index, INDEX_FILE)
    # Persist the mapping using pickle
    with open(MAPPING_FILE, "wb") as f:
        pickle.dump(documents, f)
    print(f"FAISS HNSW index created with {index.ntotal} vectors and mapping saved with {len(documents)} entries.")
# =============================
# Step 3: Enhanced Retriever
# =============================
class Retriever:
    def __init__(self, index, embed_func, documents):
        self.index = index
        self.embed_func = embed_func
        self.documents = documents

    def retrieve(self, query, k=3):
        try:
            query_embedding = self.embed_func(query)
            distances, indices = self.index.search(np.array([query_embedding], dtype='float32'), k)
            retrieved_docs = []
            for i in indices[0]:
                # FAISS pads the result with -1 when fewer than k neighbours are
                # found, so only accept valid, in-range indices.
                if 0 <= i < len(self.documents):
                    retrieved_docs.append(self.documents[i])
            logging.info(f"Query: {query} | Retrieved {len(retrieved_docs)} docs")
            logging.info(f"Docs: {retrieved_docs}")
            return retrieved_docs
        except Exception as e:
            logging.error(f"Error retrieving documents: {e}")
            return []
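
# Sketch of what retrieve() returns (hypothetical query): a list of up to k dicts
# of the form {'source': 'somefile.txt', 'text': '...'}, ranked by L2 distance in
# the HNSW index.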
retriever = Retriever(index, embed_text, documents)
# =============================
# Step 4: LLM Integration (Llama-3.2-3B)
# =============================
print("Initializing Llama3.2:3B model...")
try:
    llm = Client()
except Exception as e:
    logging.error(f"Error initializing Ollama Client: {e}")
    exit(1)
prompt_template = Template("""
You are an AI assistant that provides answers using the given context.
Context:
$context
Question:
$question
If the context does not contain an answer, clearly state "I don't know."
Provide a well-structured response.
Answer:
""")
def answer_query(question):
    # Retrieve document chunks for the query
    context_chunks = retriever.retrieve(question)
    if not context_chunks:
        return "I'm sorry, I couldn't find relevant information."
    # Optionally, include the source file names in the context for debugging:
    combined_context = "\n\n".join([f"From {doc['source']}:\n{doc['text']}" for doc in context_chunks])
    prompt = prompt_template.substitute(context=combined_context, question=question)
    try:
        response = llm.generate(prompt=prompt, model="llama3.2:3b")
        generated_text = getattr(response, 'response', "I'm sorry, I couldn't generate a response.").strip()
        return generated_text
    except Exception as e:
        logging.error(f"Error generating response: {e}")
        return "I'm sorry, I couldn't generate a response at this time."
# =============================
# Step 5: Run Enhanced RAG System
# =============================
if __name__ == "__main__":
print("\n=== Enhanced RAG System ===")
print("Type 'exit, quit, or bye' to terminate.\n")
while True:
try:
user_question = input("\n\n----------------\n Enter your question: ")
if user_question.lower() in ['exit', 'quit', 'bye']:
print("Exiting. Goodbye!")
break
answer = answer_query(user_question)
print("\n\nAnswer:", answer, "\n")
except KeyboardInterrupt:
print("\nExiting. Goodbye!")
break
except Exception as e:
logging.error(f"Unexpected error: {e}")
print("An unexpected error occurred.")