intelaide/intelaide-backend/python/bkup_create_embeds_from_files_mapping.py

#!/usr/bin/env python3
import os
import re
import glob
import faiss
import numpy as np
import torch
import logging
from transformers import AutoTokenizer, AutoModel
from nltk.tokenize import sent_tokenize
import argparse
import pickle
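
# Note: sent_tokenize() relies on NLTK's "punkt" tokenizer data (newer NLTK
# releases may look for "punkt_tab"). If that data is missing on this host,
# chunking will raise a LookupError; it can be fetched once with
# nltk.download('punkt').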
# Prevent segmentation faults by limiting OpenMP threads
os.environ["OMP_NUM_THREADS"] = "1"
# Configure logging to file
logging.basicConfig(
    filename='/var/log/intelaide_python.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)


def count_tokens(text):
    """Count tokens using simple whitespace splitting.

    Replace with a more advanced tokenizer if needed.
    """
    return len(text.split())
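
# Caveat: whitespace counts typically underestimate the model's subword token
# count, so a "512-token" chunk measured here can still exceed the 512-token
# limit of the embedding tokenizer below and be silently truncated at embed
# time. Using a smaller max_tokens leaves a safety margin if full coverage of
# each chunk matters.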


def chunk_text_hybrid(text, max_tokens=512, overlap=20):
    """
    Hybrid chunking with header awareness:
    1. Splits the text into sections on markdown headers (lines starting with
       '##' through '#####'). Each section includes the header and its
       associated content up to the next header.
    2. For each section:
       - If the token count is <= max_tokens, the section is used as a chunk.
       - Otherwise, the section is split into paragraphs.
       - For paragraphs exceeding max_tokens, sentence tokenization is used
         to further split into chunks while retaining an overlap (by token count)
         for contextual continuity.
    """
    # Split text into sections based on header lines (levels ## through #####).
    sections = re.split(r'(?=^#{2,5}\s)', text, flags=re.MULTILINE)
    chunks = []
    for section in sections:
        section = section.strip()
        if not section:
            continue
        # If the entire section is within the token limit, add it as a chunk.
        if count_tokens(section) <= max_tokens:
            chunks.append(section)
        else:
            # Further split the section into paragraphs.
            paragraphs = [p.strip() for p in section.split("\n\n") if p.strip()]
            for para in paragraphs:
                if count_tokens(para) <= max_tokens:
                    chunks.append(para)
                else:
                    # The paragraph is too long; split it using sentence tokenization.
                    sentences = sent_tokenize(para)
                    current_chunk = []
                    current_length = 0
                    for sentence in sentences:
                        sentence_token_count = count_tokens(sentence)
                        # Flush the current chunk when adding this sentence would exceed
                        # the limit. Skipping the flush while the chunk is still empty
                        # avoids emitting empty chunks for single oversized sentences.
                        if current_length + sentence_token_count > max_tokens and current_chunk:
                            chunks.append(" ".join(current_chunk))
                            # Create overlap: retain the last sentences until the
                            # desired overlap (in tokens) is reached.
                            overlap_sentences = []
                            token_sum = 0
                            for sent in reversed(current_chunk):
                                token_sum += count_tokens(sent)
                                overlap_sentences.insert(0, sent)
                                if token_sum >= overlap:
                                    break
                            current_chunk = overlap_sentences.copy()
                            current_length = sum(count_tokens(sent) for sent in current_chunk)
                        current_chunk.append(sentence)
                        current_length += sentence_token_count
                    if current_chunk:
                        chunks.append(" ".join(current_chunk))
    return chunks
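
# Illustrative sketch of how chunk_text_hybrid behaves (the sample text and the
# small limits are hypothetical, chosen only to make the splitting visible):
#
#   sample = "## Intro\nA short section.\n\n## Details\n" + "One two three four five. " * 60
#   for i, c in enumerate(chunk_text_hybrid(sample, max_tokens=100, overlap=10)):
#       print(i, count_tokens(c))
#
# The "Intro" section fits under the limit and is kept whole; the oversized
# "Details" section is split into roughly 100-token chunks whose trailing
# sentences are repeated at the start of the next chunk to give about 10 tokens
# of overlap.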


def load_documents_from_files(file_list):
    """
    Loads the specified .txt or .md files.
    Splits each file into chunks using the hybrid chunking function.
    Returns a list of dictionaries with keys 'source' and 'text'.
    """
    documents = []
    for file_path in file_list:
        if not (file_path.endswith('.txt') or file_path.endswith('.md')):
            logging.warning(f"CREATE: Skipping unsupported file (expected .txt or .md): {file_path}")
            continue
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read()
            chunks = chunk_text_hybrid(content)
            for chunk in chunks:
                documents.append({
                    'source': os.path.basename(file_path),
                    'text': chunk
                })
        except Exception as e:
            logging.error(f"CREATE: Error reading {file_path}: {e}")
    return documents
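
# Each returned entry pairs a chunk with the file it came from, e.g. (the
# filename below is hypothetical):
#   {'source': 'runbook.md', 'text': '## Restarting the service\n...'}
# Note that 'source' keeps only the basename, so files with identical names in
# different directories are indistinguishable in the mapping.
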
# ============================================================
# Step 2: Generate Embeddings and Build FAISS HNSW Index (Cosine Similarity via Normalization and Inner Product)
# ============================================================
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
logging.info("CREATE: Loading embedding model...")
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
model = AutoModel.from_pretrained(embedding_model_name)


def embed_text(text):
    """
    Tokenizes and embeds the provided text using the transformer model.
    Returns a numpy array of the embedding.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    # Mean pooling of the last hidden state as a simple embedding
    embeddings = model_output.last_hidden_state.mean(dim=1).squeeze()
    return embeddings.numpy()
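
# Query-side sketch (not executed here): to search the index that main() builds,
# the query vector must be normalized the same way as the document vectors so
# that inner product equals cosine similarity. Assuming `index` and `documents`
# have been loaded back from the .bin and _mapping.pkl files this script writes:
#
#   q = embed_text("how do I restart the service?").astype('float32').reshape(1, -1)
#   q /= np.linalg.norm(q, axis=1, keepdims=True)
#   scores, ids = index.search(q, 5)
#   top_chunks = [documents[i] for i in ids[0] if i != -1]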


def main():
    parser = argparse.ArgumentParser(
        description="Generate embeddings from specified text files and create a FAISS index using a hybrid chunking approach."
    )
    parser.add_argument("--files", nargs='+', default=[],
                        help="One or more .txt or .md files to process.")
    parser.add_argument("--doc_dir", default=None,
                        help="Optional: Path to a directory containing *.txt files to process.")
    parser.add_argument("--index_output", default="faiss_index.bin",
                        help="Output file for the FAISS index (default: faiss_index.bin).")
    args = parser.parse_args()

    # Combine files specified via --files and those from --doc_dir (if provided)
    files = args.files.copy()
    if args.doc_dir:
        # Get all .txt files from the directory
        dir_files = glob.glob(os.path.join(args.doc_dir, "*.txt"))
        files.extend(dir_files)
    if not files:
        logging.error("CREATE: No input files provided via --files or --doc_dir.")
        raise SystemExit(1)

    documents = load_documents_from_files(files)
    logging.info(f"CREATE: Loaded {len(documents)} document chunks.")
    if not documents:
        logging.error("CREATE: No document chunks were produced from the input files.")
        raise SystemExit(1)

    logging.info("CREATE: Generating embeddings for document chunks...")
    document_embeddings = []
    for doc in documents:
        emb = embed_text(doc['text'])
        document_embeddings.append(emb)
    document_embeddings = np.array(document_embeddings, dtype='float32')

    # Normalize embeddings so that inner product equals cosine similarity
    norms = np.linalg.norm(document_embeddings, axis=1, keepdims=True)
    document_embeddings = document_embeddings / norms

    # Build the FAISS HNSW index using inner product (cosine similarity on normalized vectors)
    dimension = document_embeddings.shape[1]
    logging.info(f"CREATE: Building FAISS index for {document_embeddings.shape[0]} vectors with dimension {dimension} using cosine similarity...")
    index = faiss.IndexHNSWFlat(dimension, 40, faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efConstruction = 32  # Construction-time parameter; higher values build a better graph but more slowly.
    index.add(document_embeddings)
    index.hnsw.efSearch = 60  # Search-time parameter; higher values improve recall at the cost of query latency.
    faiss.write_index(index, args.index_output)
    logging.info(f"CREATE: FAISS index created with {index.ntotal} vectors and saved to {args.index_output}.")

    # Save the mapping data (document metadata) to a file.
    base, _ = os.path.splitext(args.index_output)
    mapping_filename = f"{base}_mapping.pkl"
    try:
        with open(mapping_filename, "wb") as f:
            pickle.dump(documents, f)
        logging.info(f"CREATE: Mapping data saved to {mapping_filename}.")
    except Exception as e:
        logging.error(f"CREATE: Error saving mapping data to {mapping_filename}: {e}")


if __name__ == "__main__":
    main()
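
# Example invocations (file and directory names below are illustrative):
#   python3 bkup_create_embeds_from_files_mapping.py --files notes.md readme.txt
#   python3 bkup_create_embeds_from_files_mapping.py --doc_dir /srv/docs --index_output docs_index.bin
# The second form writes docs_index.bin plus docs_index_mapping.pkl next to it.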