Files
intelaide/intelaide-backend/python/bkup_faiss_search_mapping.py
2026-01-20 04:54:10 +00:00

309 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
import os
import faiss
import string
import numpy as np
import torch
import logging
import pickle # For loading the mapping
from transformers import AutoTokenizer, AutoModel
from string import Template
import argparse
# Import NLTK stopwords and download if needed
import nltk
from nltk.corpus import stopwords
# nltk.download('stopwords')
# Load the English stopword list; fetch the NLTK corpus once if it is missing
# (a no-op when it is already installed).
try:
    stop_words = set(stopwords.words('english'))
except LookupError:
    nltk.download('stopwords', quiet=True)
    stop_words = set(stopwords.words('english'))

# Prevent segmentation faults by running native thread pools single-threaded.
# The env var may not reach OpenMP runtimes that were already initialised by
# the torch/faiss imports above (it still applies to child processes), so the
# torch and faiss pools are also pinned explicitly.
os.environ["OMP_NUM_THREADS"] = "1"
torch.set_num_threads(1)
if hasattr(faiss, "omp_set_num_threads"):
    faiss.omp_set_num_threads(1)

# Configure logging to file
logging.basicConfig(
    filename='/var/log/intelaide_python.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# ============================================================
# Step 2: Generate Embeddings and Set Up FAISS Index
# (Cosine Similarity via Normalization and Inner Product)
# ============================================================
# Hugging Face checkpoint that supplies both the tokenizer and the encoder.
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"

logging.info("*******************************************")
logging.info("SEARCH: Loading embedding model...")
logging.info("*******************************************")

# Module-level globals read by embed_text(); the tokenizer is loaded first.
tokenizer, model = (
    AutoTokenizer.from_pretrained(embedding_model_name),
    AutoModel.from_pretrained(embedding_model_name),
)
def embed_text(text):
    """
    Tokenize and embed ``text`` with the module-level transformer model.

    The last hidden layer is mean-pooled over tokens, weighted by the
    tokenizer's attention mask so padding tokens (present only when a list of
    texts is passed) do not dilute the average. For a single string the mask
    is all ones, so the result equals a plain mean. Input longer than 512
    tokens is truncated. The vector is not normalised here; normalisation is
    left to the caller (Retriever.retrieve normalises the query embedding).

    Args:
        text: A string, or a list of strings.

    Returns:
        numpy.ndarray: the pooled embedding. ``squeeze()`` drops the batch
        axis when it has length 1, so a single string yields a 1-D vector.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    token_embeddings = model_output.last_hidden_state  # (batch, seq_len, hidden)
    # (batch, seq_len, 1) mask in the embedding dtype so it broadcasts.
    mask = inputs['attention_mask'].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    # clamp avoids division by zero for a degenerate all-padding row.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    embeddings = (summed / counts).squeeze()
    return embeddings.numpy()
# ------------------------------------------------------------
# Helper function: Compute keyword score using precomputed query info
# ------------------------------------------------------------
def compute_keyword_score(chunk, query_tokens, persons, boost=1.0):
    """
    Score how well ``chunk`` matches an already-preprocessed query.

    All matching is case-insensitive:
      * +1 for each entry of ``query_tokens`` found as a substring anywhere in
        the chunk (presence only - repeated occurrences are not counted).
      * +5 when the raw chunk starts with "##" and at least 65% of the query
        tokens equal one of its first 10 whitespace-separated words
        (lowercased, surrounding punctuation stripped).
      * +5 for each entry of ``persons`` found as a substring of the chunk.

    Args:
        chunk: Chunk text to score.
        query_tokens: Lowercased query keywords (see preprocess_query).
        persons: PERSON entity strings extracted from the query.
        boost: Multiplier applied to the final score.

    Returns:
        ``boost`` times the accumulated score.
    """
    lowered = chunk.lower()
    score = sum(1 for tok in query_tokens if tok in lowered)

    # Header bonus: the heading words are compared by equality, not substring.
    if chunk.startswith("##") and query_tokens:
        header_words = [w.lower().strip(string.punctuation) for w in chunk.split()[:10]]
        hits = sum(1 for tok in query_tokens if tok in header_words)
        if hits / len(query_tokens) >= 0.65:
            logging.info("#*#*#*#*#*#*#*#*#")
            logging.info(f"Header Line meets bonus criteria with {hits} of {len(query_tokens)} tokens present.")
            logging.info(f"Chunk with header: {chunk}")
            logging.info("#*#*#*#*#*#*#*#*#")
            score += 5

    # Named-person bonus, one per matching name.
    for name in persons:
        if name.lower() in lowered:
            logging.info(f"PERSON boost: Found '{name}' in {lowered}.")
            score += 5

    return boost * score
# ------------------------------------------------------------
# Utility function: Precompute query tokens and PERSON entities
# ------------------------------------------------------------
def preprocess_query(query):
    """
    Precompute keyword tokens and PERSON entities for ``query``.

    Returns:
        tuple (query_tokens, persons):
          query_tokens -- whitespace-split words, lowercased, with surrounding
              punctuation stripped; English stopwords and words that were pure
              punctuation (empty after stripping) are dropped.
          persons -- names that NLTK's named-entity chunker labels PERSON, in
              query order. Empty (with a logged warning) when the NLTK
              tokenizer/tagger/chunker models are not installed, so keyword
              and vector search still work.
    """
    query_tokens = []
    for raw in query.split():
        token = raw.lower().strip(string.punctuation)
        # A pure-punctuation word such as "-" strips to "". The empty string
        # is a substring of every chunk and equals the stripped "##" header
        # word, so keeping it would inflate every keyword/header score.
        if token and token not in stop_words:
            query_tokens.append(token)

    # Extract PERSON entities from the query using NLTK.
    persons = []
    try:
        tokens = nltk.word_tokenize(query)
        pos_tags = nltk.pos_tag(tokens)
        named_entities = nltk.ne_chunk(pos_tags)
    except LookupError as e:
        logging.warning(f"PERSON extraction skipped, NLTK resource missing: {e}")
        return query_tokens, persons
    logging.info(f"Trying to find a person in query: {tokens}")
    for subtree in named_entities:
        if isinstance(subtree, nltk.Tree) and subtree.label() == 'PERSON':
            person_name = " ".join(token for token, pos in subtree.leaves())
            persons.append(person_name)
            logging.info(f"Found a person in query: {person_name}")
    return query_tokens, persons
# ============================================================
# Step 3: Enhanced Retriever Using Cosine Similarity and Hybrid Scoring
# ============================================================
class Retriever:
    """
    Hybrid (vector + keyword) retriever over a FAISS index.

    ``documents`` is a list parallel to the index: every id returned by a
    FAISS search is used directly as a position in it. Each entry is a dict
    with at least ``'text'`` and ``'source'`` keys. The header walk-back in
    retrieve() presumes chunks of one source are stored contiguously and in
    reading order - TODO confirm against the indexing script.
    """

    # Hybrid score = ALPHA * FAISS score + BETA * keyword score. The FAISS
    # score is treated as cosine similarity, which holds only if the index
    # stores L2-normalised vectors searched by inner product (TODO confirm).
    ALPHA = 0.8
    BETA = 0.2
    # Added to every chunk from the source with the highest average score.
    SOURCE_BONUS = 0.1
    # Number of best-scoring chunks that seed the returned context.
    TOP_N = 5

    def __init__(self, index, embed_func, documents):
        self.index = index
        self.embed_func = embed_func
        self.documents = documents

    def retrieve(self, query, k=None):
        """
        Return the chunks most relevant to ``query``, in document order.

        Every indexed chunk is scored, the best-average source gets
        SOURCE_BONUS, the TOP_N highest-scoring chunks are kept, and each kept
        chunk that does not start with a "##" header is extended backwards
        through preceding chunks of the same source up to and including the
        nearest header chunk.

        Args:
            query: Query text.
            k: Accepted for API compatibility but ignored; all chunks are
               always searched.

        Returns:
            List of chunk dicts sorted by index; ``[]`` when there are no
            documents or any error occurs (the error is logged).
        """
        try:
            # Precompute the query tokens and PERSON entities once.
            query_tokens, persons = preprocess_query(query)
            if not self.documents:
                logging.info("SEARCH: No documents loaded - nothing to retrieve")
                return []
            distances, indices = self._search_all(query)
            results = self._score_candidates(distances, indices, query_tokens, persons)
            results = self._boost_best_source(results)
            top_results = sorted(results, key=lambda r: r[2], reverse=True)[:self.TOP_N]
            final_top_docs = self._expand_to_headers(top_results)
            logging.info("*************************************************")
            logging.info(f"SEARCH: FINAL STEP for Query: {query} - Returning {len(final_top_docs)} docs after re-ranking")
            logging.info("*************************************************")
            return final_top_docs
        except Exception as e:
            # Request boundary: log with traceback and return "no context"
            # rather than crashing the caller.
            logging.exception(f"SEARCH: Error retrieving documents: {e}")
            return []

    def _search_all(self, query):
        """Embed and normalise ``query``; return FAISS (scores, ids) for all chunks."""
        k = len(self.documents)
        # Widen the HNSW beam to the whole index so the search is as
        # exhaustive as possible; non-HNSW indexes have no ``hnsw`` attribute.
        if hasattr(self.index, 'hnsw'):
            self.index.hnsw.efSearch = self.index.ntotal
        query_embedding = self.embed_func(query)
        norm = np.linalg.norm(query_embedding)
        if norm > 0:
            query_embedding = query_embedding / norm
        query_embedding = np.array([query_embedding], dtype='float32')
        distances, indices = self.index.search(query_embedding, k)
        return distances[0], indices[0]

    def _score_candidates(self, distances, indices, query_tokens, persons):
        """Return (index, doc, hybrid_score) for each valid FAISS hit."""
        results = []
        for idx, dist in zip(indices, distances):
            # FAISS pads unfilled result slots with id -1; without the lower
            # bound that would silently select self.documents[-1].
            if not 0 <= idx < len(self.documents):
                continue
            idx = int(idx)
            doc = self.documents[idx]
            text = doc['text']
            kw_score = compute_keyword_score(text, query_tokens, persons, boost=1.0)
            combined_score = (self.ALPHA * dist) + (self.BETA * kw_score)
            text_lower = text.lower()
            # Targeted debug logging for chunks mentioning these terms.
            if ("herzig" in text_lower) or ("cpc tenets" in text_lower):
                text_to_log = text.replace('\n', ' ')
                logging.info("=================================================")
                logging.info(f"SEARCH: SHOWING herzig or cpc tenets CHUNKS: idx: {idx} | source: {doc['source']} | cosine_sim: {dist:.4f} | keyword_score: {kw_score} | hybrid_score: {combined_score:.4f} | chunk: {text_to_log}")
                logging.info("=================================================")
            results.append((idx, doc, combined_score))
        return results

    def _boost_best_source(self, results):
        """Log per-source average scores and add SOURCE_BONUS to the best source's chunks."""
        totals = {}
        counts = {}
        for _, doc, score in results:
            source = doc['source']
            totals[source] = totals.get(source, 0.0) + float(score)
            counts[source] = counts.get(source, 0) + 1
        averages = sorted(
            ((source, totals[source] / counts[source]) for source in totals),
            key=lambda x: x[1],
            reverse=True,
        )
        for source, avg_score in averages:
            logging.info(f"SOURCE AVERAGE: {source} - Average Combined Score: {avg_score:.4f}")
        if not averages:
            return results
        best_source = averages[0][0]
        return [
            (idx, doc, score + self.SOURCE_BONUS) if doc['source'] == best_source else (idx, doc, score)
            for idx, doc, score in results
        ]

    @staticmethod
    def _has_header(text):
        """True if ``text``, ignoring leading whitespace, starts with "##" (covers ###, #### too)."""
        return text.lstrip().startswith("##")

    def _expand_to_headers(self, top_results):
        """Add preceding same-source chunks back to the nearest header; return docs sorted by index."""
        selected = {idx: doc for idx, doc, _ in top_results}
        for idx, doc, _ in top_results:
            if self._has_header(doc['text']):
                continue
            source = doc['source']
            current = idx - 1
            while current >= 0:
                previous = self.documents[current]
                # Stop at a source boundary.
                if previous['source'] != source:
                    break
                selected.setdefault(current, previous)
                # The header chunk itself is included, then the walk stops.
                if self._has_header(previous['text']):
                    break
                current -= 1
        return [selected[i] for i in sorted(selected)]
# ============================================================
# Step 4: Prompt Template and Answer Function
# ============================================================
# Prompt text handed back to the caller (answer_query fills it in).
# string.Template placeholders:
#   $context  -- concatenated texts of the retrieved chunks
#   $question -- the user's query
# Template.substitute() raises KeyError if either placeholder is not supplied.
prompt_template = Template("""
You are a friendly AI assistant that provides answers using the given context.
If the context does not contain an answer, clearly state "I dont know.". Do not try to expand any abbreviations.
Provide a well-structured response.
Context:
------------------------
$context
Question:
$question
Answer:
""")
def answer_query(question, retriever, k=None):
    """
    Build the LLM prompt for ``question`` from retrieved context.

    No model is invoked here: the return value is ``prompt_template`` filled
    with the retrieved chunk texts, or a fixed apology string when retrieval
    returns nothing.

    Args:
        question: The user's query.
        retriever: Object exposing ``retrieve(question, k)`` that returns a
            list of dicts with a ``'text'`` key.
        k: Passed through to ``retriever.retrieve``.

    Returns:
        The prompt string, or the apology message.
    """
    chunks = retriever.retrieve(question, k)
    if chunks:
        # Each chunk is wrapped in a leading newline and a trailing blank line.
        context = "".join(f"\n{chunk['text']}\n\n" for chunk in chunks)
        return prompt_template.substitute(context=context, question=question)
    return "I'm sorry, I couldn't find relevant information."
# ============================================================
# Main Execution: Process Query and Output Retrieved Context
# ============================================================
if __name__ == "__main__":
    import sys  # only the CLI entry point needs sys (exit status, stderr)

    def _fail(message):
        """Log ``message``, echo it to stderr for the calling process, and exit with status 1."""
        logging.error(message)
        print(message, file=sys.stderr)
        sys.exit(1)

    parser = argparse.ArgumentParser(
        description="FAISS search for relevant document snippets using a persisted mapping file with cosine similarity."
    )
    parser.add_argument("--faiss_index_path", required=True,
                        help="Path to the FAISS index file")
    parser.add_argument("--query", required=True, help="The query text")
    # NOTE: Retriever.retrieve() currently ignores k and scores every chunk.
    parser.add_argument("--k", type=int, default=15,
                        help="Number of nearest neighbors to retrieve (default: 15)")
    args = parser.parse_args()

    if not os.path.exists(args.faiss_index_path):
        _fail(f"SEARCH: FAISS index file not found at {args.faiss_index_path}")
    index = faiss.read_index(args.faiss_index_path)

    # The mapping file is expected next to the index file.
    mapping_file = os.path.join(os.path.dirname(args.faiss_index_path), "faiss_index_mapping.pkl")
    if not os.path.exists(mapping_file):
        _fail(f"SEARCH: Mapping file not found at {mapping_file}")
    # SECURITY: pickle.load can execute arbitrary code; the mapping file must
    # come from a trusted writer (presumably this project's indexer - confirm).
    with open(mapping_file, "rb") as f:
        documents = pickle.load(f)

    # Initialize retriever with the loaded index and mapping
    retriever = Retriever(index, embed_text, documents)
    prompt_output = answer_query(args.query, retriever, args.k)
    # Output the generated prompt (which includes the retrieved context)
    print(prompt_output)