# intelaide/intelaide-backend/python/ai_gen_rag_mapping_postgresql.py

import os

# Set OMP_NUM_THREADS before importing libraries that initialize OpenMP
# (torch, transformers); this prevents segfaults from conflicting runtimes.
os.environ["OMP_NUM_THREADS"] = "1"

import numpy as np
import torch
import logging
import nltk
import psycopg2
from transformers import AutoTokenizer, AutoModel
from ollama import Client
from string import Template
from nltk.tokenize import sent_tokenize

# sent_tokenize requires the NLTK "punkt" tokenizer data; fetch it if missing.
nltk.download('punkt', quiet=True)

# Configure logging
logging.basicConfig(filename='rag_system.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
# =============================
# Step 1: Load and Chunk Documents with Overlap
# =============================
def chunk_text(text, max_tokens=512, overlap=10):
    """
    Split text into smaller chunks with an overlap to preserve context.
    Tokenizes by sentences and assembles them into chunks not exceeding
    max_tokens (counted as whitespace-separated words), carrying roughly
    `overlap` tokens of trailing sentences into the next chunk.
    """
    sentences = sent_tokenize(text)
    chunks = []
    current_chunk = []
    current_length = 0
    for sentence in sentences:
        token_count = len(sentence.split())
        # Only close the current chunk if it is non-empty; otherwise a single
        # oversized sentence would emit an empty chunk.
        if current_length + token_count > max_tokens and current_chunk:
            chunks.append(" ".join(current_chunk))
            # Create overlap: keep the last few sentences to carry over context
            if overlap > 0:
                overlap_sentences = []
                token_sum = 0
                for sent in reversed(current_chunk):
                    token_sum += len(sent.split())
                    overlap_sentences.insert(0, sent)
                    if token_sum >= overlap:
                        break
                current_chunk = overlap_sentences
                current_length = sum(len(s.split()) for s in current_chunk)
            else:
                current_chunk = []
                current_length = 0
        current_chunk.append(sentence)
        current_length += token_count
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
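
# Illustrative sketch (hypothetical input), showing how the overlap carries
# trailing sentences into the next chunk:
#
#   chunks = chunk_text("One two three. Four five six. Seven eight nine.",
#                       max_tokens=8, overlap=3)
#   # -> ["One two three. Four five six.",
#   #     "Four five six. Seven eight nine."]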

def load_documents(directory='./documents'):
    """
    Loads all .txt files from the specified directory.
    For each file, returns document chunks along with the source file name.
    """
    documents = []
    for filename in os.listdir(directory):
        if filename.endswith('.txt'):
            file_path = os.path.join(directory, filename)
            try:
                with open(file_path, 'r', encoding='utf-8') as file:
                    content = file.read()
                chunks = chunk_text(content)
                for chunk in chunks:
                    documents.append({
                        'source': filename,
                        'text': chunk
                    })
            except Exception as e:
                logging.error(f"Error reading {file_path}: {e}")
    return documents
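
# Illustrative result shape (hypothetical file names): each entry pairs a
# chunk with the file it came from, e.g.
#
#   [{'source': 'notes.txt', 'text': 'First chunk of notes...'},
#    {'source': 'notes.txt', 'text': 'Second chunk of notes...'}]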

# =============================
# Step 2: Embedding Model Setup with Normalization
# =============================
embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
# Read the Hugging Face token from the environment; never hardcode credentials.
access_token = os.environ.get("HF_TOKEN")

print("Loading embedding model...")
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, token=access_token)
model = AutoModel.from_pretrained(embedding_model_name, token=access_token)
model.eval()  # disable dropout for deterministic inference

def embed_text(text):
    """
    Tokenizes, embeds, and then normalizes the given text.
    Normalization uses the L2 norm so that the resulting vector is a unit vector.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    # Compute mean pooling of the last hidden state as the embedding
    embeddings = model_output.last_hidden_state.mean(dim=1).squeeze()
    # Normalize the embedding to a unit vector (L2 normalization)
    norm = torch.norm(embeddings, p=2)
    normalized_embedding = embeddings / norm
    return normalized_embedding.numpy()
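
# Quick sanity check (illustrative): all-mpnet-base-v2 produces 768-dim
# embeddings, matching the vector(768) column defined below, and the L2
# normalization above makes them unit-length.
#
#   emb = embed_text("hello world")
#   assert emb.shape == (768,)
#   assert abs(np.linalg.norm(emb) - 1.0) < 1e-5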

# =============================
# Step 3: PostgreSQL pgvector Setup and Document Insertion
# =============================
# PostgreSQL connection parameters; adjust these to your environment.
DB_NAME = "yourdbname"
DB_USER = "yourusername"
DB_PASSWORD = "yourpassword"
DB_HOST = "localhost"
DB_PORT = "5432"

def get_connection():
    return psycopg2.connect(
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        host=DB_HOST,
        port=DB_PORT
    )

def create_table():
    """
    Creates the document_chunks table with a pgvector column for embeddings.
    The vector dimension is 768 to match the all-mpnet-base-v2 embeddings.
    """
    conn = get_connection()
    cur = conn.cursor()
    cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
    cur.execute("""
        CREATE TABLE IF NOT EXISTS document_chunks (
            id SERIAL PRIMARY KEY,
            source TEXT,
            text_chunk TEXT,
            embedding vector(768)
        );
    """)
    conn.commit()
    cur.close()
    conn.close()
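
# Optional sketch (not in the original schema): for large corpora, an
# approximate-nearest-neighbor index speeds up the `<#>` ordering used in
# Step 4. pgvector provides an HNSW index with an inner-product operator
# class, e.g.:
#
#   CREATE INDEX IF NOT EXISTS idx_document_chunks_embedding
#   ON document_chunks USING hnsw (embedding vector_ip_ops);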

def insert_documents(documents):
    """
    Inserts document chunks along with their normalized embeddings into the
    PostgreSQL database.
    """
    conn = get_connection()
    cur = conn.cursor()
    for doc in documents:
        try:
            emb = embed_text(doc['text'])
            # pgvector accepts the '[x1, x2, ...]' text form; cast explicitly.
            emb_str = str(emb.tolist())
            insert_query = """
                INSERT INTO document_chunks (source, text_chunk, embedding)
                VALUES (%s, %s, %s::vector);
            """
            cur.execute(insert_query, (doc['source'], doc['text'], emb_str))
        except Exception as e:
            logging.error(f"Error inserting document chunk into DB: {e}")
            conn.rollback()
    conn.commit()
    cur.close()
    conn.close()

def is_table_empty():
    """
    Checks if the document_chunks table is empty.
    """
    conn = get_connection()
    cur = conn.cursor()
    cur.execute("SELECT COUNT(*) FROM document_chunks;")
    count = cur.fetchone()[0]
    cur.close()
    conn.close()
    return count == 0

# Initialize the table
create_table()

# If no documents exist in the database, load and insert them.
if is_table_empty():
    documents = load_documents()
    print(f"Loaded {len(documents)} document chunks from files.")
    print("Inserting document chunks into PostgreSQL database...")
    insert_documents(documents)
else:
    print("Document chunks already exist in the PostgreSQL database.")

# =============================
# Step 4: Retrieval from PostgreSQL via Inner Product (Cosine Similarity)
# =============================
def retrieve_from_db(query, k=3):
    """
    Retrieves the top k most similar document chunks from PostgreSQL.
    The `<#>` operator computes the negative inner product; because all
    embeddings are L2-normalized, ordering by it ascending is equivalent
    to ranking by cosine similarity.
    """
    query_embedding = str(embed_text(query).tolist())
    conn = get_connection()
    cur = conn.cursor()
    sql = """
        SELECT id, source, text_chunk, embedding
        FROM document_chunks
        ORDER BY embedding <#> %s::vector
        LIMIT %s;
    """
    cur.execute(sql, (query_embedding, k))
    results = cur.fetchall()
    retrieved_docs = []
    for row in results:
        retrieved_docs.append({
            "id": row[0],
            "source": row[1],
            "text": row[2],
            "embedding": row[3]
        })
    cur.close()
    conn.close()
    return retrieved_docs
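
# Illustrative usage (hypothetical query string):
#
#   docs = retrieve_from_db("What does the onboarding document cover?", k=3)
#   for d in docs:
#       print(d["source"], d["text"][:80])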

class Retriever:
    def __init__(self, retrieve_func):
        self.retrieve_func = retrieve_func

    def retrieve(self, query, k=3):
        try:
            retrieved_docs = self.retrieve_func(query, k)
            logging.info(f"Query: {query} | Retrieved {len(retrieved_docs)} docs")
            logging.info(f"Docs: {retrieved_docs}")
            return retrieved_docs
        except Exception as e:
            logging.error(f"Error retrieving documents: {e}")
            return []

retriever = Retriever(retrieve_from_db)

# =============================
# Step 5: LLM Integration (Llama 3.2 3B via Ollama)
# =============================
print("Initializing llama3.2:3b model...")
try:
    llm = Client()
except Exception as e:
    logging.error(f"Error initializing Ollama Client: {e}")
    exit(1)

prompt_template = Template("""
You are an AI assistant that provides answers using the given context.

Context:
$context

Question:
$question

If the context does not contain an answer, clearly state "I don't know."
Provide a well-structured response.

Answer:
""")

def answer_query(question):
    # Retrieve document chunks for the query using PostgreSQL
    context_chunks = retriever.retrieve(question)
    if not context_chunks:
        return "I'm sorry, I couldn't find relevant information."
    # Include the source file names in the context so answers can be traced
    combined_context = "\n\n".join(
        f"From {doc['source']}:\n{doc['text']}" for doc in context_chunks
    )
    prompt = prompt_template.substitute(context=combined_context, question=question)
    try:
        response = llm.generate(prompt=prompt, model="llama3.2:3b")
        generated_text = getattr(response, 'response', "I'm sorry, I couldn't generate a response.").strip()
        return generated_text
    except Exception as e:
        logging.error(f"Error generating response: {e}")
        return "I'm sorry, I couldn't generate a response at this time."

# =============================
# Step 6: Run Enhanced RAG System
# =============================
if __name__ == "__main__":
    print("\n=== Enhanced RAG System with PostgreSQL pgvector (Cosine Similarity) ===")
    print("Type 'exit', 'quit', or 'bye' to terminate.\n")
    while True:
        try:
            user_question = input("\n\n----------------\n Enter your question: ")
            if user_question.lower() in ['exit', 'quit', 'bye']:
                print("Exiting. Goodbye!")
                break
            answer = answer_query(user_question)
            print("\n\nAnswer:", answer, "\n")
        except KeyboardInterrupt:
            print("\nExiting. Goodbye!")
            break
        except Exception as e:
            logging.error(f"Unexpected error: {e}")
            print("An unexpected error occurred.")