#!/usr/bin/env python3
import faiss
import pickle
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

# Configuration: paths to your FAISS index and mapping file
faiss_index_path = '/root/intelaide-backend/documents/user_3/assistant_2/faiss_index.bin'
mapping_path = '/root/intelaide-backend/documents/user_3/assistant_2/faiss_index_mapping.pkl'

# Embedding model settings
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
model = AutoModel.from_pretrained(embedding_model_name)

def embed_text(text):
    """
    Tokenizes and embeds the provided text using the transformer model.
    Returns a numpy array of the embedding.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    # Mean pooling of the last hidden state as a simple embedding
    embeddings = model_output.last_hidden_state.mean(dim=1).squeeze()
    return embeddings.numpy()

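# The mean above averages over padding tokens as well. That is harmless for
# a single un-padded query, but for batched input a mask-aware mean matches
# how sentence-transformers models are pooled. A sketch (embed_text_masked
# is a hypothetical helper, not used by this script):
def embed_text_masked(text):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    mask = inputs['attention_mask'].unsqueeze(-1).float()        # (batch, seq, 1)
    summed = (model_output.last_hidden_state * mask).sum(dim=1)  # sum real tokens only
    counts = mask.sum(dim=1).clamp(min=1e-9)                     # real-token count per row
    return (summed / counts).squeeze().numpy()
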
def normalize_embedding(embedding):
    """
    Normalizes a numpy embedding vector to unit length.
    """
    norm = np.linalg.norm(embedding)
    if norm == 0:
        return embedding  # guard against division by zero for an all-zero vector
    return embedding / norm

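# An equivalent alternative (not used below) is FAISS's built-in helper,
# which L2-normalizes the rows of a 2-D float32 array in place:
#   vec = embedding.astype('float32').reshape(1, -1)
#   faiss.normalize_L2(vec)
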
def main():
    # Load the FAISS index
    index = faiss.read_index(faiss_index_path)
    total_vectors = index.ntotal
    print(f"Loaded FAISS index with {total_vectors} vectors.")

    # Increase efSearch so the graph traversal explores more candidates
    # (adjust as needed; only HNSW indexes expose this parameter)
    if hasattr(index, 'hnsw'):
        index.hnsw.efSearch = total_vectors  # or a value high enough to cover your dataset

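    # Note: HNSW search is approximate; even with efSearch == ntotal it is not
    # guaranteed to visit every vector. If a provably exhaustive scan is
    # required, one option (a sketch, assuming the index supports
    # reconstruct_n) is to rebuild a flat index from the stored vectors:
    #   flat = faiss.IndexFlatL2(index.d)
    #   flat.add(index.reconstruct_n(0, index.ntotal))
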
    # Load the mapping file (list of document chunks with metadata)
    with open(mapping_path, 'rb') as f:
        mapping = pickle.load(f)

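    # Sanity check (assumes the mapping is a list parallel to the index, one
    # entry per vector, which the lookup loop below relies on):
    if len(mapping) != total_vectors:
        print(f"Warning: mapping has {len(mapping)} entries but the index holds {total_vectors} vectors.")
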
    # Define the query string
    query = "Tell me about Heather Herzig"

    # Embed and normalize the query. With unit-length vectors, inner-product
    # scores equal cosine similarity (assuming the indexed vectors were
    # normalized the same way when the index was built).
    query_embedding = embed_text(query).astype('float32')
    query_embedding = normalize_embedding(query_embedding).reshape(1, -1)

    # Set k equal to the total number of vectors
    k = total_vectors
    distances, indices = index.search(query_embedding, k)

    # Count total results from the index search
    total_results = len(indices[0])
    print(f"Total results returned by index.search: {total_results}")

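    # FAISS pads unfilled result slots with index -1, and the meaning of
    # `distances` depends on the metric: lower is better for faiss.METRIC_L2,
    # higher is better for faiss.METRIC_INNER_PRODUCT (index.metric_type
    # reports which one the index uses).
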
    # Filter results for those containing "Herzig" and count them
    herzig_count = 0
    print("Search Results (only chunks containing 'Herzig'):")
    for dist, idx in zip(distances[0], indices[0]):
        # Guard both ends of the range: FAISS uses -1 for empty result slots,
        # and a negative idx would silently index the mapping from the end
        if 0 <= idx < len(mapping):
            doc = mapping[idx]
            chunk_text = doc.get('text', "")
            if "Herzig" in chunk_text:
                herzig_count += 1
                print(f"Chunk Index: {idx}")
                print(f"Distance: {dist}")
                print(f"Source File: {doc.get('source')}")
                print("Chunk Text:")
                print(chunk_text)
                print("-" * 40)
        else:
            print(f"Index {idx} is out of range in the mapping file.")

    print(f"Total results with 'Herzig' in the chunk text: {herzig_count}")

if __name__ == "__main__":
    main()