#!/usr/bin/env python3
import os
import re
import sys
import glob
import faiss
import numpy as np
import torch
import logging
from transformers import AutoTokenizer, AutoModel
from nltk.tokenize import sent_tokenize  # requires NLTK's 'punkt' tokenizer data to be downloaded
import argparse
import pickle

# Prevent segmentation faults by limiting OpenMP threads
os.environ["OMP_NUM_THREADS"] = "1"

# Configure logging to file
logging.basicConfig(
    filename='/var/log/intelaide_python.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
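
# Example invocation (script name and paths are illustrative, not taken from the original):
#   python3 create_faiss_index.py --doc_dir ./docs --index_output faiss_index.bin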


def count_tokens(text):
    """Count tokens using simple whitespace splitting.
    Replace with a more advanced tokenizer if needed.
    """
    return len(text.split())
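
# Possible refinement (untested sketch, not part of the original logic): whitespace
# splitting only approximates the embedding model's subword tokens; a closer but
# slower count could reuse the Hugging Face tokenizer loaded below, e.g.
# len(tokenizer.tokenize(text)).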


def chunk_text_hybrid(text, max_tokens=512, overlap=20):
    """
    Hybrid chunking with header awareness:

    1. Splits the text into sections based on Markdown headers (lines starting
       with '##' through '#####'). Each section includes the header and its
       associated content up to the next header.
    2. For each section:
       - If the token count is <= max_tokens, the section is used as a chunk.
       - Otherwise, the section is split into paragraphs.
       - For paragraphs exceeding max_tokens, sentence tokenization is used to
         further split into chunks while retaining an overlap (by token count)
         for contextual continuity.
    """
    # Split text into sections based on header lines ("##" through "#####").
    sections = re.split(r'(?=^#{2,5}\s)', text, flags=re.MULTILINE)
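    # The lookahead split keeps each header at the start of its own section; e.g.
    # "intro\n## A\nbody\n### B\nmore" yields ["intro\n", "## A\nbody\n", "### B\nmore"]
    # (the example text here is illustrative).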
    chunks = []

    for section in sections:
        section = section.strip()
        if not section:
            continue

        # If the entire section is within the token limit, add it as a chunk.
        if count_tokens(section) <= max_tokens:
            chunks.append(section)
        else:
            # Further split the section into paragraphs.
            paragraphs = [p.strip() for p in section.split("\n\n") if p.strip()]
            for para in paragraphs:
                if count_tokens(para) <= max_tokens:
                    chunks.append(para)
                else:
                    # The paragraph is too long; split it using sentence tokenization.
                    sentences = sent_tokenize(para)
                    current_chunk = []
                    current_length = 0
                    for sentence in sentences:
                        sentence_token_count = count_tokens(sentence)
                        if current_length + sentence_token_count > max_tokens:
                            # Finalize and save the current chunk (guard against an
                            # empty chunk when a single sentence already exceeds max_tokens).
                            if current_chunk:
                                chunk = " ".join(current_chunk)
                                chunks.append(chunk)
                            # Create overlap: retain the last sentences to meet the desired overlap.
                            overlap_sentences = []
                            token_sum = 0
                            for sent in reversed(current_chunk):
                                token_sum += count_tokens(sent)
                                overlap_sentences.insert(0, sent)
                                if token_sum >= overlap:
                                    break
                            current_chunk = overlap_sentences.copy()
                            current_length = sum(count_tokens(sent) for sent in current_chunk)
                        current_chunk.append(sentence)
                        current_length += sentence_token_count
                    if current_chunk:
                        chunks.append(" ".join(current_chunk))
    return chunks
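
# Illustrative usage (the file name is hypothetical):
#   with open("notes.md", "r", encoding="utf-8") as f:
#       pieces = chunk_text_hybrid(f.read(), max_tokens=512, overlap=20)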


def load_documents_from_files(file_list):
    """
    Loads the specified .txt or .md files.
    Splits each file into chunks using the hybrid chunking function.
    Returns a list of dictionaries with keys 'source' and 'text'.
    """
    documents = []
    for file_path in file_list:
        if not (file_path.endswith('.txt') or file_path.endswith('.md')):
            logging.warning(f"CREATE: Skipping unsupported file (expected .txt or .md): {file_path}")
            continue
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read()
            chunks = chunk_text_hybrid(content)
            for chunk in chunks:
                documents.append({
                    'source': os.path.basename(file_path),
                    'text': chunk
                })
        except Exception as e:
            logging.error(f"CREATE: Error reading {file_path}: {e}")
    return documents


# ============================================================
# Step 2: Generate Embeddings and Build FAISS HNSW Index
# (Cosine similarity via normalization and inner product)
# ============================================================
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"

logging.info("CREATE: Loading embedding model...")
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
model = AutoModel.from_pretrained(embedding_model_name)


def embed_text(text):
    """
    Tokenizes and embeds the provided text using the transformer model.
    Returns a numpy array of the embedding.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        model_output = model(**inputs)
    # Mean pooling of the last hidden state as a simple embedding
    embeddings = model_output.last_hidden_state.mean(dim=1).squeeze()
    return embeddings.numpy()
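
# Note: the mean pooling above averages over every token position. Texts are
# embedded one at a time here, so no padding tokens are present and this matches
# mask-weighted pooling; if batching were introduced, weighting by
# inputs['attention_mask'] would be the safer (untested) variant.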


def main():
    parser = argparse.ArgumentParser(
        description="Generate embeddings from specified text files and create a FAISS index using a hybrid chunking approach."
    )
    parser.add_argument("--files", nargs='+', default=[],
                        help="One or more .txt or .md files to process.")
    parser.add_argument("--doc_dir", default=None,
                        help="Optional: Path to a directory containing *.txt files to process.")
    parser.add_argument("--index_output", default="faiss_index.bin",
                        help="Output file for the FAISS index (default: faiss_index.bin).")
    args = parser.parse_args()

    # Combine files specified via --files and those from --doc_dir (if provided)
    files = args.files.copy()
    if args.doc_dir:
        # Get all .txt files from the directory
        dir_files = glob.glob(os.path.join(args.doc_dir, "*.txt"))
        files.extend(dir_files)

    if not files:
        logging.error("CREATE: No input files provided via --files or --doc_dir.")
        sys.exit(1)

    documents = load_documents_from_files(files)
    logging.info(f"CREATE: Loaded {len(documents)} document chunks.")

    logging.info("CREATE: Generating embeddings for document chunks...")
    document_embeddings = []
    for doc in documents:
        emb = embed_text(doc['text'])
        document_embeddings.append(emb)
    document_embeddings = np.array(document_embeddings, dtype='float32')

    # Normalize embeddings so that inner product equals cosine similarity
    norms = np.linalg.norm(document_embeddings, axis=1, keepdims=True)
    document_embeddings = document_embeddings / norms

    # Build the FAISS HNSW index using inner product (cosine similarity with normalized vectors)
    dimension = document_embeddings.shape[1]
    logging.info(f"CREATE: Building FAISS index for {document_embeddings.shape[0]} vectors with dimension {dimension} using cosine similarity...")
    index = faiss.IndexHNSWFlat(dimension, 40, faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efConstruction = 32  # Construction-time quality/speed trade-off; adjust if needed.
    index.add(document_embeddings)
    index.hnsw.efSearch = 60  # Search-time parameter; higher values improve recall at the cost of query speed.

    faiss.write_index(index, args.index_output)
    logging.info(f"CREATE: FAISS index created with {index.ntotal} vectors and saved to {args.index_output}.")

    # Save the mapping data (document metadata) to a file.
    base, _ = os.path.splitext(args.index_output)
    mapping_filename = f"{base}_mapping.pkl"
    try:
        with open(mapping_filename, "wb") as f:
            pickle.dump(documents, f)
        logging.info(f"CREATE: Mapping data saved to {mapping_filename}.")
    except Exception as e:
        logging.error(f"CREATE: Error saving mapping data to {mapping_filename}: {e}")


if __name__ == "__main__":
    main()