import os
import numpy as np
import torch
import logging
import psycopg2
from transformers import AutoTokenizer, AutoModel
from ollama import Client
from string import Template
from nltk.tokenize import sent_tokenize

# Prevent Seg Faults
os.environ["OMP_NUM_THREADS"] = "1"

# Configure logging
logging.basicConfig(filename='rag_system.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
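
# sent_tokenize relies on NLTK's punkt tokenizer data; fetch it once if it is
# missing. (Assumed setup step -- newer NLTK releases may expect 'punkt_tab'
# instead; adjust if your install already ships the data.)
import nltk
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)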

# =============================
# Step 1: Load and Chunk Documents with Overlap
# =============================

def chunk_text(text, max_tokens=512, overlap=10):
    """
    Split text into smaller chunks with an overlap to preserve context.
    Tokenizes by sentences and assembles them into chunks not exceeding max_tokens,
    carrying roughly `overlap` tokens of trailing sentences into the next chunk.
    """
    sentences = sent_tokenize(text)
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        tokens = sentence.split()
        token_count = len(tokens)
        # Only close the current chunk if it is non-empty; a single sentence
        # longer than max_tokens becomes its own chunk rather than an empty one.
        if current_length + token_count > max_tokens and current_chunk:
            chunk = " ".join(current_chunk)
            chunks.append(chunk)
            # Create overlap: keep the last few sentences to carry over context
            if overlap > 0:
                overlap_sentences = []
                token_sum = 0
                for sent in reversed(current_chunk):
                    sent_tokens = sent.split()
                    token_sum += len(sent_tokens)
                    overlap_sentences.insert(0, sent)
                    if token_sum >= overlap:
                        break
                current_chunk = overlap_sentences.copy()
                current_length = sum(len(s.split()) for s in current_chunk)
            else:
                current_chunk = []
                current_length = 0
        current_chunk.append(sentence)
        current_length += token_count
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks

def load_documents(directory='./documents'):
    """
    Loads all .txt files from the specified directory.
    For each file, returns document chunks along with the source file name.
    """
    documents = []
    for filename in os.listdir(directory):
        if filename.endswith('.txt'):
            file_path = os.path.join(directory, filename)
            try:
                with open(file_path, 'r', encoding='utf-8') as file:
                    content = file.read()
                chunks = chunk_text(content)
                for chunk in chunks:
                    documents.append({
                        'source': filename,
                        'text': chunk
                    })
            except Exception as e:
                logging.error(f"Error reading {file_path}: {e}")
    return documents

# =============================
# Step 2: Embedding Model Setup with Normalization
# =============================

embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
# Read the Hugging Face token from the environment (optional for this public
# model) rather than hardcoding a credential in the source.
access_token = os.environ.get("HF_TOKEN")
print("Loading embedding model...")
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, token=access_token)
model = AutoModel.from_pretrained(embedding_model_name, token=access_token)

def embed_text(text):
    """
    Tokenizes, embeds, and then normalizes the given text.
    Normalization is done using L2 norm so that the resulting vector is a unit vector.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    # Compute mean pooling of the last hidden state as the embedding
    embeddings = model_output.last_hidden_state.mean(dim=1).squeeze()
    # Normalize the embedding to a unit vector (L2 normalization)
    norm = torch.norm(embeddings, p=2)
    normalized_embedding = embeddings / norm
    return normalized_embedding.numpy()
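
# Because embed_text() returns unit vectors, cosine similarity reduces to a plain
# dot product. This helper is added purely for illustration (the pipeline below
# does not call it).
def cosine_similarity(a, b):
    """Cosine similarity of two embeddings (dot product of unit vectors)."""
    return float(np.dot(a, b))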

# =============================
# Step 3: PostgreSQL pgvector Setup and Document Insertion
# =============================

# PostgreSQL connection parameters – adjust these to your environment.
DB_NAME = "yourdbname"
DB_USER = "yourusername"
DB_PASSWORD = "yourpassword"
DB_HOST = "localhost"
DB_PORT = "5432"

def get_connection():
    return psycopg2.connect(
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        host=DB_HOST,
        port=DB_PORT
    )

def create_table():
    """
    Creates the document_chunks table with a pgvector column for embeddings.
    The vector dimension is 768 to match the all-mpnet-base-v2 embeddings.
    """
    conn = get_connection()
    cur = conn.cursor()
    cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
    cur.execute("""
        CREATE TABLE IF NOT EXISTS document_chunks (
            id SERIAL PRIMARY KEY,
            source TEXT,
            text_chunk TEXT,
            embedding vector(768)
        );
    """)
    conn.commit()
    cur.close()
    conn.close()
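
# Optional (illustrative sketch, not called by default): an approximate
# nearest-neighbour index keeps retrieval fast as the table grows. This assumes
# pgvector's ivfflat index with the inner-product opclass, matching the <#>
# operator used in retrieve_from_db(); build it after inserting documents and
# tune `lists` for your data volume.
def create_index():
    conn = get_connection()
    cur = conn.cursor()
    cur.execute("""
        CREATE INDEX IF NOT EXISTS document_chunks_embedding_idx
        ON document_chunks USING ivfflat (embedding vector_ip_ops)
        WITH (lists = 100);
    """)
    conn.commit()
    cur.close()
    conn.close()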

def insert_documents(documents):
    """
    Inserts document chunks along with their normalized embeddings into the PostgreSQL database.
    """
    conn = get_connection()
    cur = conn.cursor()
    for doc in documents:
        try:
            emb = embed_text(doc['text'])
            # pgvector accepts the textual '[x, y, ...]' form for vector values,
            # so pass the embedding as a string rather than a raw Python list.
            emb_list = str(emb.tolist())
            insert_query = """
                INSERT INTO document_chunks (source, text_chunk, embedding)
                VALUES (%s, %s, %s);
            """
            cur.execute(insert_query, (doc['source'], doc['text'], emb_list))
        except Exception as e:
            logging.error(f"Error inserting document chunk into DB: {e}")
            conn.rollback()
    conn.commit()
    cur.close()
    conn.close()

def is_table_empty():
    """
    Checks if the document_chunks table is empty.
    """
    conn = get_connection()
    cur = conn.cursor()
    cur.execute("SELECT COUNT(*) FROM document_chunks;")
    count = cur.fetchone()[0]
    cur.close()
    conn.close()
    return count == 0

# Initialize the table
create_table()

# If no documents exist in the database, load and insert them.
if is_table_empty():
    documents = load_documents()
    print(f"Loaded {len(documents)} document chunks from files.")
    print("Inserting document chunks into PostgreSQL database...")
    insert_documents(documents)
else:
    print("Document chunks already exist in the PostgreSQL database.")

# =============================
# Step 4: Retrieval from PostgreSQL using Cosine Similarity
# =============================

def retrieve_from_db(query, k=3):
    """
    Retrieves the top k most similar document chunks from PostgreSQL.
    Uses the `<#>` operator, which computes the negative inner product; because
    the embeddings are L2-normalized, ordering by it is equivalent to ordering
    by cosine similarity.
    """
    query_embedding = str(embed_text(query).tolist())
    conn = get_connection()
    cur = conn.cursor()
    sql = """
        SELECT id, source, text_chunk, embedding
        FROM document_chunks
        ORDER BY embedding <#> %s::vector
        LIMIT %s;
    """
    cur.execute(sql, (query_embedding, k))
    results = cur.fetchall()
    retrieved_docs = []
    for row in results:
        retrieved_docs.append({
            "id": row[0],
            "source": row[1],
            "text": row[2],
            "embedding": row[3]
        })
    cur.close()
    conn.close()
    return retrieved_docs

class Retriever:
    def __init__(self, retrieve_func):
        self.retrieve_func = retrieve_func

    def retrieve(self, query, k=3):
        try:
            retrieved_docs = self.retrieve_func(query, k)
            logging.info(f"Query: {query} | Retrieved {len(retrieved_docs)} docs")
            logging.info(f"Docs: {retrieved_docs}")
            return retrieved_docs
        except Exception as e:
            logging.error(f"Error retrieving documents: {e}")
            return []

retriever = Retriever(retrieve_from_db)
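
# Example (illustrative): retriever.retrieve("What is pgvector?", k=3) returns a
# list of dicts with 'id', 'source', 'text', and 'embedding' keys, or [] on error.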

# =============================
# Step 5: LLM Integration (Llama-3.2-3B)
# =============================

print("Initializing Llama3.2:3B model...")
try:
    llm = Client()
except Exception as e:
    logging.error(f"Error initializing Ollama Client: {e}")
    exit(1)
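
# Note (assumption about the local setup): Client() talks to an Ollama server at
# its default address (http://localhost:11434); pass host=... to reach another
# server. The model must already be available locally, e.g. `ollama pull llama3.2:3b`.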

prompt_template = Template("""
You are an AI assistant that provides answers using the given context.

Context:
$context

Question:
$question

If the context does not contain an answer, clearly state "I don’t know."
Provide a well-structured response.

Answer:
""")

def answer_query(question):
    # Retrieve document chunks for the query using PostgreSQL
    context_chunks = retriever.retrieve(question)
    if not context_chunks:
        return "I'm sorry, I couldn't find relevant information."

    # Optionally, include the source file names in the context:
    combined_context = "\n\n".join([f"From {doc['source']}:\n{doc['text']}" for doc in context_chunks])
    prompt = prompt_template.substitute(context=combined_context, question=question)

    try:
        response = llm.generate(prompt=prompt, model="llama3.2:3b")
        generated_text = getattr(response, 'response', "I'm sorry, I couldn't generate a response.").strip()
        return generated_text
    except Exception as e:
        logging.error(f"Error generating response: {e}")
        return "I'm sorry, I couldn't generate a response at this time."

# =============================
# Step 6: Run Enhanced RAG System
# =============================

if __name__ == "__main__":
    print("\n=== Enhanced RAG System with PostgreSQL pgvector (Cosine Similarity) ===")
    print("Type 'exit', 'quit', or 'bye' to terminate.\n")
    while True:
        try:
            user_question = input("\n\n----------------\n Enter your question: ")
            if user_question.strip().lower() in ['exit', 'quit', 'bye']:
                print("Exiting. Goodbye!")
                break
            answer = answer_query(user_question)
            print("\n\nAnswer:", answer, "\n")
        except KeyboardInterrupt:
            print("\nExiting. Goodbye!")
            break
        except Exception as e:
            logging.error(f"Unexpected error: {e}")
            print("An unexpected error occurred.")