#!/usr/bin/env python3
import os
import sys
import faiss
import string
import numpy as np
import torch
import logging
import pickle  # For loading the mapping
from transformers import AutoTokenizer, AutoModel
from string import Template
import argparse

import nltk
from nltk.corpus import stopwords

# Download the stopwords corpus on first run if it is missing. (ne_chunk below
# also needs 'punkt', 'averaged_perceptron_tagger', 'maxent_ne_chunker', and 'words'.)
try:
    stop_words = set(stopwords.words('english'))
except LookupError:
    nltk.download('stopwords', quiet=True)
    stop_words = set(stopwords.words('english'))

# Limit OpenMP to a single thread; mixing FAISS and PyTorch threading
# is a known source of segmentation faults.
os.environ["OMP_NUM_THREADS"] = "1"

# Configure logging to file (requires write access to /var/log).
logging.basicConfig(
    filename='/var/log/intelaide_python.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# ============================================================
# Step 2: Generate Embeddings and Set Up FAISS Index
# (Cosine Similarity via Normalization and Inner Product)
# ============================================================
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"

logging.info("*******************************************")
logging.info("SEARCH: Loading embedding model...")
logging.info("*******************************************")
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
model = AutoModel.from_pretrained(embedding_model_name)

def embed_text(text):
    """
    Tokenizes and embeds the provided text using the transformer model.
    Returns a numpy array of the embedding.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    embeddings = model_output.last_hidden_state.mean(dim=1).squeeze()
    return embeddings.numpy()

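# A minimal sketch (an optional variant, not used by the pipeline above): the
# plain mean over last_hidden_state also averages padding positions when texts
# are embedded in batches. An attention-mask-weighted mean avoids that:
def embed_text_masked(text):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    mask = inputs['attention_mask'].unsqueeze(-1).float()        # (batch, seq, 1)
    summed = (model_output.last_hidden_state * mask).sum(dim=1)  # ignore padding
    counts = mask.sum(dim=1).clamp(min=1e-9)                     # real tokens per text
    return (summed / counts).squeeze().numpy()
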
# ------------------------------------------------------------
# Helper function: Compute keyword score using precomputed query info
# ------------------------------------------------------------
def compute_keyword_score(chunk, query_tokens, persons, boost=1.0):
    """
    Computes the keyword score for a given chunk using precomputed query tokens
    and PERSON entities.
    """
    keyword_score = 0
    chunk_lower = chunk.lower()

    # Basic scoring: count each query token that occurs anywhere in the chunk.
    # (Substring containment, so "art" also matches inside "part".)
    for token in query_tokens:
        if token in chunk_lower:
            keyword_score += 1

    # Bonus: if the chunk starts with "##" and at least 65% of query tokens
    # appear in its first 10 words, add 5.
    if chunk.startswith("##"):
        first_10_words = chunk.split()[:10]
        first_10_words_lower = [word.lower().strip(string.punctuation) for word in first_10_words]
        tokens_in_header = sum(1 for token in query_tokens if token in first_10_words_lower)
        if query_tokens and (tokens_in_header / len(query_tokens)) >= 0.65:
            logging.info("#*#*#*#*#*#*#*#*#")
            logging.info(f"Header line meets bonus criteria with {tokens_in_header} of {len(query_tokens)} tokens present.")
            logging.info(f"Chunk with header: {chunk}")
            logging.info("#*#*#*#*#*#*#*#*#")
            keyword_score += 5

    # Additional bonus: add 5 for each query PERSON mentioned in the chunk.
    for person in persons:
        if person.lower() in chunk_lower:
            logging.info(f"PERSON boost: Found '{person}' in {chunk_lower}.")
            keyword_score += 5

    return boost * keyword_score

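# Worked example (illustrative): for chunk "## CPC Tenets overview" with
# query_tokens ["cpc", "tenets"] and no persons, both tokens match (+2) and
# the header bonus fires since 2/2 >= 0.65 (+5), giving a score of 7.0.
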
# ------------------------------------------------------------
# Utility function: Precompute query tokens and PERSON entities
# ------------------------------------------------------------
def preprocess_query(query):
    # Precompute query tokens (after stripping punctuation and removing stopwords).
    query_tokens = [
        token.lower().strip(string.punctuation)
        for token in query.split()
        if token.lower().strip(string.punctuation) not in stop_words
    ]
    # Extract PERSON entities from the query using NLTK.
    tokens = nltk.word_tokenize(query)
    pos_tags = nltk.pos_tag(tokens)
    named_entities = nltk.ne_chunk(pos_tags)
    persons = []
    logging.info(f"Trying to find a person in query: {tokens}")
    for subtree in named_entities:
        if isinstance(subtree, nltk.Tree) and subtree.label() == 'PERSON':
            person_name = " ".join(token for token, pos in subtree.leaves())
            persons.append(person_name)
            logging.info(f"Found a person in query: {person_name}")
    return query_tokens, persons

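# Illustrative example (NER output depends on the installed NLTK models):
# preprocess_query("What did John Smith say about the CPC tenets?") would
# yield roughly (['john', 'smith', 'say', 'cpc', 'tenets'], ['John Smith']).
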
# ============================================================
# Step 3: Enhanced Retriever Using Cosine Similarity and Hybrid Scoring
# ============================================================
class Retriever:
    def __init__(self, index, embed_func, documents):
        self.index = index
        self.embed_func = embed_func
        self.documents = documents

    def retrieve(self, query, k=None):
        try:
            # Precompute the query tokens and PERSON entities once.
            query_tokens, persons = preprocess_query(query)

            # Score all document chunks: the k argument is deliberately
            # overridden, and the re-ranking below picks the final top 5.
            k = len(self.documents)

            # Increase efSearch so the HNSW index explores more candidates.
            total_vectors = self.index.ntotal
            if hasattr(self.index, 'hnsw'):
                self.index.hnsw.efSearch = total_vectors

            # Generate the query embedding and normalize it so the inner
            # product returned by FAISS equals cosine similarity.
            query_embedding = self.embed_func(query)
            norm = np.linalg.norm(query_embedding)
            if norm > 0:
                query_embedding = query_embedding / norm
            query_embedding = np.array([query_embedding], dtype='float32')

            # Retrieve the k nearest neighbors.
            distances, indices = self.index.search(query_embedding, k)

            # Weights for hybrid scoring.
            alpha = 0.8  # weight for cosine similarity (from FAISS)
            beta = 0.2   # weight for keyword score
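            # Note (not enforced here): cosine similarity lies in [-1, 1] while
            # the keyword score is unbounded, so strongly keyword-matched chunks
            # can dominate; min-max normalizing both signals before mixing would
            # make alpha/beta behave as true proportions.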

            # Build a list of (index, document, combined_score) tuples.
            results = []
            for idx, dist in zip(indices[0], distances[0]):
                # FAISS pads missing results with -1, so guard both bounds.
                if 0 <= idx < len(self.documents):
                    doc = self.documents[idx]
                    # Log the full chunk if it mentions "Herzig" or "CPC Tenets"; otherwise a preview.
                    is_tracked = ("herzig" in doc['text'].lower()) or ("cpc tenets" in doc['text'].lower())
                    if is_tracked:
                        text_to_log = doc['text'].replace('\n', ' ')
                    else:
                        text_to_log = doc['text'][:200]  # preview first 200 characters

                    # Compute keyword score using precomputed query_tokens and persons.
                    kw_score = compute_keyword_score(doc['text'], query_tokens, persons, boost=1.0)
                    # Compute the combined hybrid score for the chunk.
                    combined_score = (alpha * dist) + (beta * kw_score)
                    if is_tracked:
                        logging.info("=================================================")
                        logging.info(f"SEARCH: SHOWING herzig or cpc tenets CHUNKS: idx: {idx} | source: {doc['source']} | cosine_sim: {dist:.4f} | keyword_score: {kw_score} | hybrid_score: {combined_score:.4f} | chunk: {text_to_log}")
                        logging.info("=================================================")
                    results.append((idx, doc, combined_score))

            # --- Compute and log the average combined_score per source, descending ---
            source_totals = {}
            source_counts = {}
            for idx, doc, score in results:
                source = doc['source']
                source_totals[source] = source_totals.get(source, 0.0) + float(score)
                source_counts[source] = source_counts.get(source, 0) + 1

            averages = [(source, source_totals[source] / source_counts[source]) for source in source_totals]
            averages_sorted = sorted(averages, key=lambda x: x[1], reverse=True)

            for source, avg_score in averages_sorted:
                logging.info(f"SOURCE AVERAGE: {source} - Average Combined Score: {avg_score:.4f}")

            # --- Add 0.1 to the score of every chunk from the highest-average source ---
            if averages_sorted:
                highest_source = averages_sorted[0][0]
                results = [
                    (idx, doc, score + 0.1) if doc['source'] == highest_source else (idx, doc, score)
                    for idx, doc, score in results
                ]

            # Sort the results by combined_score descending and keep the top 5.
            results_sorted = sorted(results, key=lambda x: x[2], reverse=True)
            top_results = results_sorted[:5]

            # --- Build the final top_docs list and prepare logging info ---
            top_docs = []
            final_logging = []  # list of (index, doc, score) tuples for logging
            for idx, doc, score in top_results:
                top_docs.append(doc)
                final_logging.append((idx, doc, score))
                # Optionally, the chunk following each hit could be appended
                # as extra context:
                # next_idx = idx + 1
                # if next_idx < len(self.documents):
                #     top_docs.append(self.documents[next_idx])
                #     final_logging.append((next_idx, self.documents[next_idx], None))

            # === Enhancement: Include preceding chunks for non-header texts
            # that belong to the same source ===
            def has_header(text):
                # "##" also covers the deeper "###"/"####" markers.
                return text.lstrip().startswith("##")

            # Build a dict keyed by index from final_logging for easier merging.
            final_docs_dict = {idx: doc for idx, doc, _ in final_logging}

            # For each retrieved document whose text does not start with a header,
            # walk backwards through same-source chunks until one with a header
            # is found, adding each along the way.
            for idx, doc, _ in final_logging:
                if not has_header(doc['text']):
                    current_source = doc['source']
                    current_idx = idx - 1
                    while current_idx >= 0:
                        # Stop if the preceding document is from a different source.
                        if self.documents[current_idx]['source'] != current_source:
                            break
                        # Add the document if not already added.
                        if current_idx not in final_docs_dict:
                            final_docs_dict[current_idx] = self.documents[current_idx]
                        # Stop walking back once a header chunk is reached.
                        if has_header(self.documents[current_idx]['text']):
                            break
                        current_idx -= 1

            # Sort the final documents by their indices in ascending order so
            # the chunks read in document order.
            final_top_docs = [final_docs_dict[i] for i in sorted(final_docs_dict.keys())]

            logging.info("*************************************************")
            logging.info(f"SEARCH: FINAL STEP for Query: {query} - Returning {len(final_top_docs)} docs after re-ranking")
            logging.info("*************************************************")

            return final_top_docs
        except Exception as e:
            logging.error(f"SEARCH: Error retrieving documents: {e}", exc_info=True)
            return []

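# ------------------------------------------------------------
# For reference, a minimal sketch (an assumption -- the real build script is
# not part of this file) of how a compatible HNSW index and mapping file
# could be produced. Cosine similarity comes from L2-normalizing the vectors
# and searching an inner-product index.
# ------------------------------------------------------------
def build_index_sketch(chunks, index_path, mapping_path):
    """chunks: list of {'text': ..., 'source': ...} dicts, one per chunk."""
    vectors = np.array([embed_text(c['text']) for c in chunks], dtype='float32')
    faiss.normalize_L2(vectors)  # in-place; inner product now equals cosine
    index = faiss.IndexHNSWFlat(vectors.shape[1], 32, faiss.METRIC_INNER_PRODUCT)
    index.add(vectors)  # row i of the index corresponds to chunks[i]
    faiss.write_index(index, index_path)
    with open(mapping_path, "wb") as f:
        pickle.dump(chunks, f)
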
# ============================================================
# Step 4: Prompt Template and Answer Function
# ============================================================
prompt_template = Template("""
You are a friendly AI assistant that provides answers using the given context.
If the context does not contain an answer, clearly state "I don’t know." Do not try to expand any abbreviations.
Provide a well-structured response.

Context:
------------------------
$context

Question:
$question

Answer:
""")

def answer_query(question, retriever, k=None):
    """Retrieves context for the question and returns the filled-in prompt."""
    context_chunks = retriever.retrieve(question, k)
    if not context_chunks:
        return "I'm sorry, I couldn't find relevant information."
    # Alternative that labels each chunk with its source:
    # combined_context = "\n\n".join([f"From {doc['source']}:\n{doc['text']}" for doc in context_chunks])
    combined_context = "".join([f"\n{doc['text']}\n\n" for doc in context_chunks])
    prompt = prompt_template.substitute(context=combined_context, question=question)
    return prompt

# ============================================================
# Main Execution: Process Query and Output Retrieved Context
# ============================================================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="FAISS search for relevant document snippets using a persisted mapping file with cosine similarity."
    )
    parser.add_argument("--faiss_index_path", required=True,
                        help="Path to the FAISS index file")
    parser.add_argument("--query", required=True, help="The query text")
    parser.add_argument("--k", type=int, default=15,
                        help="Number of nearest neighbors to retrieve (default: 15)")
    args = parser.parse_args()

    if not os.path.exists(args.faiss_index_path):
        logging.error(f"SEARCH: FAISS index file not found at {args.faiss_index_path}")
        sys.exit(1)
    index = faiss.read_index(args.faiss_index_path)

    # Determine the mapping file path (assumed to live in the same directory
    # as the index file).
    mapping_file = os.path.join(os.path.dirname(args.faiss_index_path), "faiss_index_mapping.pkl")
    if not os.path.exists(mapping_file):
        logging.error(f"SEARCH: Mapping file not found at {mapping_file}")
        sys.exit(1)
    with open(mapping_file, "rb") as f:
        documents = pickle.load(f)

    # Initialize the retriever with the loaded index and mapping.
    retriever = Retriever(index, embed_text, documents)
    prompt_output = answer_query(args.query, retriever, args.k)

    # Output the generated prompt (which includes the retrieved context).
    print(prompt_output)
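
# Example invocation (script and data paths are illustrative):
#   python3 search.py --faiss_index_path /data/faiss.index \
#       --query "What are the CPC tenets?" --k 15
# Note: --k is parsed but retrieve() currently scores every chunk (see above).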