Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import org.apache.lucene.index.SegmentWriteState;

/**
* A Faiss-based format to create and search vector indexes, using {@link LibFaissC} to interact
* A Faiss-based format to create and search vector indexes, using {@link FaissLibrary} to interact
* with the native library.
*
* <p>The Faiss index is configured using its flexible <a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,8 @@
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION;
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT;
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_START;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_MMAP;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_READ_ONLY;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexRead;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexSearch;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -44,7 +38,6 @@
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataAccessHint;
Expand All @@ -61,16 +54,13 @@
final class FaissKnnVectorsReader extends KnnVectorsReader {
private final FlatVectorsReader rawVectorsReader;
private final IndexInput data;
private final Map<String, IndexEntry> indexMap;
private final Arena arena;
private final Map<String, FaissLibrary.Index> indexMap;
private boolean closed;

public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader)
throws IOException {
this.rawVectorsReader = rawVectorsReader;
this.indexMap = new HashMap<>();
this.arena = Arena.ofShared();
this.closed = false;

List<FieldMeta> fieldMetaList = new ArrayList<>();
String metaFileName =
Expand Down Expand Up @@ -125,9 +115,11 @@ public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVector
CodecUtil.retrieveChecksum(data);

for (FieldMeta fieldMeta : fieldMetaList) {
if (indexMap.put(fieldMeta.fieldInfo.name, loadField(data, arena, fieldMeta)) != null) {
throw new CorruptIndexException("Duplicate field: " + fieldMeta.fieldInfo.name, meta);
if (indexMap.containsKey(fieldMeta.name)) {
throw new CorruptIndexException("Duplicate field: " + fieldMeta.name, meta);
}
IndexInput indexInput = data.slice(fieldMeta.name, fieldMeta.offset, fieldMeta.length);
indexMap.put(fieldMeta.name, FaissLibrary.INSTANCE.readIndex(indexInput));
}
} catch (Throwable t) {
IOUtils.closeWhileSuppressingExceptions(t, this);
Expand All @@ -150,21 +142,7 @@ private static FieldMeta parseNextField(IndexInput meta, SegmentReadState state)
long dataOffset = meta.readLong();
long dataLength = meta.readLong();

return new FieldMeta(fieldInfo, dataOffset, dataLength);
}

@SuppressWarnings("restricted") // TODO: encapsulate the unsafeness into the LibFaissC
private static IndexEntry loadField(IndexInput data, Arena arena, FieldMeta fieldMeta)
throws IOException {
int ioFlags = FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY;

// Read index into memory
MemorySegment indexPointer =
indexRead(data.slice(fieldMeta.fieldInfo.name, fieldMeta.offset, fieldMeta.length), ioFlags)
// Ensure timely cleanup
.reinterpret(arena, LibFaissC::freeIndex);

return new IndexEntry(indexPointer, fieldMeta.fieldInfo.getVectorSimilarityFunction());
return new FieldMeta(fieldInfo.name, dataOffset, dataLength);
}

@Override
Expand All @@ -188,9 +166,9 @@ public ByteVectorValues getByteVectorValues(String field) {

@Override
public void search(String field, float[] vector, KnnCollector knnCollector, Bits acceptDocs) {
IndexEntry entry = indexMap.get(field);
if (entry != null) {
indexSearch(entry.indexPointer, entry.function, vector, knnCollector, acceptDocs);
FaissLibrary.Index index = indexMap.get(field);
if (index != null) {
index.search(vector, knnCollector, acceptDocs);
}
}

Expand All @@ -210,12 +188,16 @@ public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
@Override
public void close() throws IOException {
if (closed == false) {
// Close all indexes
for (FaissLibrary.Index index : indexMap.values()) {
index.close();
}
indexMap.clear();

IOUtils.close(rawVectorsReader, data);
closed = true;
IOUtils.close(rawVectorsReader, arena::close, data, indexMap::clear);
}
}

private record FieldMeta(FieldInfo fieldInfo, long offset, long length) {}

private record IndexEntry(MemorySegment indexPointer, VectorSimilarityFunction function) {}
private record FieldMeta(String name, long offset, long length) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,8 @@
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME;
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION;
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_MMAP;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_READ_ONLY;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.createIndex;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexWrite;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -43,7 +37,6 @@
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
Expand Down Expand Up @@ -154,26 +147,23 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
}
}

@SuppressWarnings("restricted") // TODO: encapsulate the unsafeness into the LibFaissC
private void writeFloatField(
FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IntToIntFunction oldToNewDocId)
throws IOException {
int number = fieldInfo.number;
meta.writeInt(number);

// Write index to temp file and deallocate from memory
try (Arena temp = Arena.ofConfined()) {
VectorSimilarityFunction function = fieldInfo.getVectorSimilarityFunction();
MemorySegment indexPointer =
createIndex(description, indexParams, function, floatVectorValues, oldToNewDocId)
// Ensure timely cleanup
.reinterpret(temp, LibFaissC::freeIndex);

int ioFlags = FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY;
try (FaissLibrary.Index index =
FaissLibrary.INSTANCE.createIndex(
description,
indexParams,
fieldInfo.getVectorSimilarityFunction(),
floatVectorValues,
oldToNewDocId)) {

// Write index
long dataOffset = data.getFilePointer();
indexWrite(indexPointer, data, ioFlags);
index.write(data);
long dataLength = data.getFilePointer() - dataOffset;

meta.writeLong(dataOffset);
Expand Down Expand Up @@ -233,7 +223,7 @@ public int size() {

@Override
public FloatVectorValues copy() {
return new BufferedFloatVectorValues(floats, dimension, docIdSet);
throw new AssertionError("Should not be called");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.sandbox.codecs.faiss;

import java.io.Closeable;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.IntToIntFunction;

/**
* Minimal interface to create and query Faiss indexes.
*
* @lucene.experimental
*/
interface FaissLibrary {
FaissLibrary INSTANCE = lookup();

// TODO: Use the SIMD-optimized builds of the Faiss library (e.g. AVX2, AVX-512, SVE) where available
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, what does this TODO mean again? Does "vectorized" mean "SIMD instructions"? This term (vector) is overloaded! https://youtu.be/fVq4_HhBK8Y

Maybe clarify it to `// TODO: use SIMD Faiss API versions where available` or so (if that's what it really means)? I think this is necessary because the Faiss build process produces a separate dynamic library for each SIMD target (AVX2 vs. AVX-512, etc.)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sorry, I meant using SIMD instructions wherever available.

Today, the shared library of Faiss' C API (libfaiss_c.so) is linked to the non-SIMD version of the base library (libfaiss.so) by default, but a user can still "point" to the correct SIMD version by changing its dependencies using:

# patchelf --replace-needed OLD_DEPENDENCY NEW_DEPENDENCY SHARED_LIBRARY
patchelf --replace-needed libfaiss.so libfaiss_{avx2,avx512,sve}.so libfaiss_c.so

However, we'd ideally want to do this automatically (either by proposing a change to upstream Faiss, or by some other mechanism on the Lucene side) -- but I wasn't sure how to do it right now.

I'll update the comment soon!

String NAME = "faiss_c";
String VERSION = "1.11.0";

private static FaissLibrary lookup() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have that dynamic code in the main branch? Please replace it with a simple `return new FaissLibraryNativeImpl()`.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, updated

// Resolve the implementation class by name through a MethodHandles.Lookup instead of
// referencing it directly.
final MethodHandles.Lookup lookup = MethodHandles.lookup();

final Class<?> cls;
try {
cls = lookup.findClass("org.apache.lucene.sandbox.codecs.faiss.FaissLibraryNativeImpl");
} catch (ClassNotFoundException | IllegalAccessException e) {
throw new LinkageError("FaissLibraryNativeImpl class is missing or inaccessible", e);
}

// Find the no-argument constructor of the implementation class.
final MethodHandle constr;
try {
constr = lookup.findConstructor(cls, MethodType.methodType(void.class));
} catch (NoSuchMethodException | IllegalAccessException e) {
throw new LinkageError("FaissLibraryNativeImpl constructor is missing or inaccessible", e);
}

// Instantiate the implementation. Unchecked exceptions and errors raised by the constructor
// propagate unchanged; any other Throwable is reported as an AssertionError.
try {
return (FaissLibrary) constr.invoke();
} catch (RuntimeException | Error e) {
throw e;
} catch (Throwable t) {
throw new AssertionError("Should not throw checked exceptions", t);
}
}

/**
* A single index obtained from {@link FaissLibrary}, which can be searched, written to an output
* and closed.
*
* <p>NOTE(review): closing presumably releases whatever the implementation holds for this index;
* confirm against the implementation.
*/
interface Index extends Closeable {
/**
* Searches this index with the given query vector.
*
* @param query the query vector
* @param knnCollector collector handed to the implementation (presumably receives the hits)
* @param acceptDocs bits handed to the implementation (presumably restricts the documents that
*     may match; TODO confirm how {@code null} is treated)
*/
void search(float[] query, KnnCollector knnCollector, Bits acceptDocs);

/**
* Writes this index to the given output.
*
* @param output destination for the index
*/
void write(IndexOutput output);
}

/**
* Creates a new index from the given vectors.
*
* @param description index description passed to the implementation (presumably a Faiss index
*     factory string; TODO confirm)
* @param indexParams index parameters passed to the implementation
* @param function the vector similarity function of the field being indexed
* @param floatVectorValues the vectors to index
* @param oldToNewDocId mapping applied to document ids (presumably from the ids in {@code
*     floatVectorValues} to the ids stored in the index; TODO confirm)
* @return the created index; the caller is responsible for closing it
*/
Index createIndex(
String description,
String indexParams,
VectorSimilarityFunction function,
FloatVectorValues floatVectorValues,
IntToIntFunction oldToNewDocId);

/**
* Reads an index from the given input.
*
* @param input the input to read the index from
* @return the index that was read; the caller is responsible for closing it
*/
Index readIndex(IndexInput input);
}
Loading
Loading