Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.lucene.sandbox.codecs.faiss;

import static java.lang.foreign.ValueLayout.ADDRESS;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
import static java.lang.foreign.ValueLayout.JAVA_INT;
import static java.lang.foreign.ValueLayout.JAVA_LONG;
Expand All @@ -32,8 +33,6 @@
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.Locale;
import org.apache.lucene.index.FloatVectorValues;
Expand Down Expand Up @@ -221,16 +220,22 @@ public static MemorySegment createIndex(

// Allocate docs in native memory
MemorySegment docs = temp.allocate(JAVA_FLOAT, (long) size * dimension);
FloatBuffer docsBuffer = docs.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer();
long docsOffset = 0;
long perDocByteSize = dimension * JAVA_FLOAT.byteSize();

// Allocate ids in native memory
MemorySegment ids = temp.allocate(JAVA_LONG, size);
LongBuffer idsBuffer = ids.asByteBuffer().order(ByteOrder.nativeOrder()).asLongBuffer();
int idsIndex = 0;

KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
for (int i = iterator.nextDoc(); i != NO_MORE_DOCS; i = iterator.nextDoc()) {
idsBuffer.put(oldToNewDocId.apply(i));
docsBuffer.put(floatVectorValues.vectorValue(iterator.index()));
int id = oldToNewDocId.apply(i);
ids.setAtIndex(JAVA_LONG, idsIndex, id);
idsIndex++;

float[] vector = floatVectorValues.vectorValue(iterator.index());
MemorySegment.copy(vector, 0, docs, JAVA_FLOAT, docsOffset, vector.length);
docsOffset += perDocByteSize;
}

// Train index
Expand All @@ -254,18 +259,12 @@ private static long writeBytes(
inputPointer = inputPointer.reinterpret(size);

if (size <= BUFFER_SIZE) { // simple case, avoid buffering
byte[] bytes = new byte[(int) size];
inputPointer.asSlice(0, size).asByteBuffer().order(ByteOrder.nativeOrder()).get(bytes);
output.writeBytes(bytes, bytes.length);
output.writeBytes(inputPointer.toArray(JAVA_BYTE), (int) size);
} else { // copy buffered number of bytes repeatedly
byte[] bytes = new byte[BUFFER_SIZE];
for (long offset = 0; offset < size; offset += BUFFER_SIZE) {
int length = (int) Math.min(size - offset, BUFFER_SIZE);
inputPointer
.asSlice(offset, length)
.asByteBuffer()
.order(ByteOrder.nativeOrder())
.get(bytes, 0, length);
MemorySegment.copy(inputPointer, JAVA_BYTE, offset, bytes, 0, length);
output.writeBytes(bytes, length);
}
}
Expand All @@ -282,21 +281,13 @@ private static long readBytes(
if (size <= BUFFER_SIZE) { // simple case, avoid buffering
byte[] bytes = new byte[(int) size];
input.readBytes(bytes, 0, bytes.length);
outputPointer
.asSlice(0, bytes.length)
.asByteBuffer()
.order(ByteOrder.nativeOrder())
.put(bytes);
MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, 0, bytes.length);
} else { // copy buffered number of bytes repeatedly
byte[] bytes = new byte[BUFFER_SIZE];
for (long offset = 0; offset < size; offset += BUFFER_SIZE) {
int length = (int) Math.min(size - offset, BUFFER_SIZE);
input.readBytes(bytes, 0, length);
outputPointer
.asSlice(offset, length)
.asByteBuffer()
.order(ByteOrder.nativeOrder())
.put(bytes, 0, length);
MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, offset, length);
}
}
return numItems;
Expand Down Expand Up @@ -411,8 +402,7 @@ public static void indexSearch(
};

// Allocate queries in native memory
MemorySegment queries = temp.allocate(JAVA_FLOAT, query.length);
queries.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer().put(query);
MemorySegment queries = temp.allocateFrom(JAVA_FLOAT, query);

// Faiss knn search
int k = knnCollector.k();
Expand Down Expand Up @@ -458,13 +448,9 @@ public static void indexSearch(
idsPointer);
}

// Retrieve scores
float[] distances = new float[k];
distancesPointer.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer().get(distances);

// Retrieve ids
long[] ids = new long[k];
idsPointer.asByteBuffer().order(ByteOrder.nativeOrder()).asLongBuffer().get(ids);
// Retrieve scores and ids
float[] distances = distancesPointer.toArray(JAVA_FLOAT);
long[] ids = idsPointer.toArray(JAVA_LONG);

// Record hits
for (int i = 0; i < k; i++) {
Expand Down
Loading