Final version for thomaswue #674

Merged
merged 1 commit into from
Jan 31, 2024
16 changes: 14 additions & 2 deletions prepare_thomaswue.sh
@@ -20,7 +20,19 @@ sdk use java 21.0.2-graal 1>&2
 
 # ./mvnw clean verify removes target/ and will re-trigger native image creation.
 if [ ! -f target/CalculateAverage_thomaswue_image ]; then
-NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -H:TuneInlinerExploration=1 -march=native --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_thomaswue\$Scanner"
-# Use -H:MethodFilter=CalculateAverage_thomaswue.* -H:Dump=:2 -H:PrintGraph=Network for IdealGraphVisualizer graph dumping.
+
+# Performance tuning flags, optimization level 3, maximum inlining exploration, and compile for the architecture where the native image is generated.
+NATIVE_IMAGE_OPTS="-O3 -H:TuneInlinerExploration=1 -march=native"
+
+# Need to enable preview for accessing the raw address of the foreign memory access API.
+# Initializing the Scanner to make sure the unsafe access object is known as a non-null compile time constant.
+NATIVE_IMAGE_OPTS="$NATIVE_IMAGE_OPTS --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_thomaswue\$Scanner"
+
+# There is no need for garbage collection and therefore also no safepoints required.
+NATIVE_IMAGE_OPTS="$NATIVE_IMAGE_OPTS --gc=epsilon -H:-GenLoopSafepoints"
+
+# Uncomment the following line for outputting the compiler graph to the IdealGraphVisualizer
+# NATIVE_IMAGE_OPTS="$NATIVE_IMAGE_OPTS -H:MethodFilter=CalculateAverage_thomaswue.* -H:Dump=:2 -H:PrintGraph=Network"
+
 native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_thomaswue_image dev.morling.onebrc.CalculateAverage_thomaswue
 fi
234 changes: 100 additions & 134 deletions src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java
@@ -27,9 +27,7 @@
* split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread.
* Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in
* the end.
*
* Runs in 0.40s on an Intel i9-13900K.
*
* Runs in 0.39s on an Intel i9-13900K.
* Credit:
* Quan Anh Mai for branchless number parsing code
* Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea
@@ -103,61 +101,123 @@ private static TreeMap<String, Result> accumulateResults(List<Result>[] allResul
return result;
}

private static Result findResult(long initialWord, long initialPos, Scanner scanner, Result[] results, List<Result> collectedResults) {
private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart, List<Result> collectedResults) {
Result[] results = new Result[HASH_TABLE_SIZE];
while (true) {
long current = counter.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE;
if (current >= fileEnd) {
return;
}

long segmentEnd = nextNewLine(Math.min(fileEnd - 1, current + SEGMENT_SIZE));
long segmentStart;
if (current == fileStart) {
segmentStart = current;
}
else {
segmentStart = nextNewLine(current) + 1;
}

long dist = (segmentEnd - segmentStart) / 3;
long midPoint1 = nextNewLine(segmentStart + dist);

I would like to understand this. Why does dividing the workload into 3 parts and processing them simultaneously benefit performance? I am curious about the technique.

Thank you.

Contributor Author

This way the x86 processor is able to make better use of its processing units and execute more instructions per cycle. I am preparing some more information and a blog post about this.

Thank you. I am looking forward to your blog post. I didn't find it in the resource list. Is it still in progress? I don't want to miss it.
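For readers following this thread, here is a minimal sketch of the general idea, not the code from this PR. It assumes the benefit comes from giving the CPU several independent dependency chains to overlap. The class `InterleavedCursors` and the toy `h * 31 + b` fold are hypothetical stand-ins for the real per-line work.

```java
import java.nio.charset.StandardCharsets;

final class InterleavedCursors {
    // One cursor: each step needs the previous value of h, so the
    // iterations form a single serial dependency chain.
    static long foldRange(byte[] data, int from, int to) {
        long h = 0;
        for (int i = from; i < to; i++) {
            h = h * 31 + data[i];
        }
        return h;
    }

    // Three cursors advanced in the same loop body: the three chains do not
    // depend on each other, so an out-of-order core may overlap their work.
    static long[] foldThreeInterleaved(byte[] data) {
        int third = data.length / 3;
        int i1 = 0, end1 = third;
        int i2 = third, end2 = 2 * third;
        int i3 = 2 * third, end3 = data.length;
        long h1 = 0, h2 = 0, h3 = 0;
        while (i1 < end1 && i2 < end2 && i3 < end3) {
            h1 = h1 * 31 + data[i1++];
            h2 = h2 * 31 + data[i2++];
            h3 = h3 * 31 + data[i3++];
        }
        // Drain whichever cursors still have input left.
        while (i1 < end1) h1 = h1 * 31 + data[i1++];
        while (i2 < end2) h2 = h2 * 31 + data[i2++];
        while (i3 < end3) h3 = h3 * 31 + data[i3++];
        return new long[] { h1, h2, h3 };
    }

    public static void main(String[] args) {
        byte[] data = "Hamburg;12.0\nBulawayo;8.9\nPalembang;38.8\n".getBytes(StandardCharsets.US_ASCII);
        long[] parts = foldThreeInterleaved(data);
        System.out.println(parts[0] + " " + parts[1] + " " + parts[2]);
    }
}
```

The interleaved version computes exactly the same three values as three separate calls to `foldRange`. Any speedup depends on the hardware and on how the compiler schedules the loop, so it should be measured rather than assumed.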

long midPoint2 = nextNewLine(segmentStart + dist + dist);

Scanner scanner1 = new Scanner(segmentStart, midPoint1);
Scanner scanner2 = new Scanner(midPoint1 + 1, midPoint2);
Scanner scanner3 = new Scanner(midPoint2 + 1, segmentEnd);
while (true) {
if (!scanner1.hasNext()) {
break;
}
if (!scanner2.hasNext()) {
break;
}
if (!scanner3.hasNext()) {
break;
}
long word1 = scanner1.getLong();
long word2 = scanner2.getLong();
long word3 = scanner3.getLong();
long delimiterMask1 = findDelimiter(word1);
long delimiterMask2 = findDelimiter(word2);
long delimiterMask3 = findDelimiter(word3);
Result existingResult1 = findResult(word1, delimiterMask1, scanner1, results, collectedResults);
Result existingResult2 = findResult(word2, delimiterMask2, scanner2, results, collectedResults);
Result existingResult3 = findResult(word3, delimiterMask3, scanner3, results, collectedResults);
long number1 = scanNumber(scanner1);
long number2 = scanNumber(scanner2);
long number3 = scanNumber(scanner3);
record(existingResult1, number1);
record(existingResult2, number2);
record(existingResult3, number3);
}

while (scanner1.hasNext()) {
long word = scanner1.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1));
}
while (scanner2.hasNext()) {
long word = scanner2.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2));
}
while (scanner3.hasNext()) {
long word = scanner3.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3));
}
}
}
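The way `parseLoop` hands out work can be sketched in isolation: each worker claims the next segment by atomically advancing a shared counter until the counter passes the end of the input. The sketch below is illustrative only. `SegmentClaimDemo`, its `SEGMENT_SIZE` value, and the start offset of 0 are assumptions, since the PR's constant and thread setup are not visible in this diff.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

final class SegmentClaimDemo {
    // Hypothetical segment size chosen for the demo only.
    static final long SEGMENT_SIZE = 1L << 10;

    // Claims segments until the shared counter passes fileEnd and
    // returns the start offsets this caller obtained.
    static List<Long> claimAll(AtomicLong counter, long fileEnd) {
        List<Long> claimed = new ArrayList<>();
        while (true) {
            // addAndGet returns the value after the addition, so subtracting
            // SEGMENT_SIZE yields the start of the segment just claimed.
            long current = counter.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE;
            if (current >= fileEnd) {
                return claimed;
            }
            claimed.add(current);
        }
    }

    // Runs several workers against one counter and returns all claimed
    // start offsets in ascending order.
    static List<Long> runWorkers(int workers, long fileEnd) {
        AtomicLong counter = new AtomicLong(0);
        List<Long> all = Collections.synchronizedList(new ArrayList<>());
        Thread[] threads = new Thread[workers];
        for (int i = 0; i < workers; i++) {
            threads[i] = new Thread(() -> all.addAll(claimAll(counter, fileEnd)));
            threads[i].start();
        }
        try {
            for (Thread t : threads) {
                t.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        List<Long> sorted = new ArrayList<>(all);
        Collections.sort(sorted);
        return sorted;
    }

    public static void main(String[] args) {
        System.out.println(runWorkers(4, 10_000L));
    }
}
```

Because every claim is a single atomic add, no two workers receive the same segment, and a fast worker simply claims more segments than a slow one. The code in the PR additionally moves each segment boundary to the next newline so that no line is split between workers.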

private static Result findResult(long initialWord, long initialDelimiterMask, Scanner scanner, Result[] results, List<Result> collectedResults) {
Result existingResult;
long word = initialWord;
long pos = initialPos;
long delimiterMask = initialDelimiterMask;
long hash;
long nameAddress = scanner.pos();

// Search for ';', one long at a time. There are two common cases that are specially treated:
// (b) the ';' is found in the first 16 bytes
if (pos != 0) {
if (delimiterMask != 0) {
// Special case for when the ';' is found in the first 8 bytes.
pos = Long.numberOfTrailingZeros(pos) >>> 3;
scanner.add(pos);
word = mask(word, pos);
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash = word;

int index = hashToIndex(hash, results);
existingResult = results[index];

existingResult = results[hashToIndex(hash, results)];
if (existingResult != null && existingResult.lastNameLong == word) {
return existingResult;
}
scanner.setPos(nameAddress + pos);
}
else {
// Special case for when the ';' is found in bytes 9-16.
scanner.add(8);
hash = word;
long prevWord = word;
scanner.add(8);
word = scanner.getLong();
pos = findDelimiter(word);
if (pos != 0) {
pos = Long.numberOfTrailingZeros(pos) >>> 3;
scanner.add(pos);
word = mask(word, pos);
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
int index = hashToIndex(hash, results);
existingResult = results[index];

existingResult = results[hashToIndex(hash, results)];
if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) {
return existingResult;
}
scanner.setPos(nameAddress + pos + 8);
}
else {
// Slow-path for when the ';' could not be found in the first 16 bytes.
scanner.add(8);
hash ^= word;
while (true) {
word = scanner.getLong();
pos = findDelimiter(word);
if (pos != 0) {
pos = Long.numberOfTrailingZeros(pos) >>> 3;
scanner.add(pos);
word = mask(word, pos);
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
break;
}
@@ -204,7 +264,8 @@ private static Result findResult(long initialWord, long initialPos, Scanner scan
private static long nextNewLine(long prev) {
while (true) {
long currentWord = Scanner.UNSAFE.getLong(prev);
long pos = findNewLine(currentWord);
long input = currentWord ^ 0x0A0A0A0A0A0A0A0AL;
long pos = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L;
if (pos != 0) {
prev += Long.numberOfTrailingZeros(pos) >>> 3;
break;
@@ -216,87 +277,11 @@ private static long nextNewLine(long prev) {
return prev;
}

// Main parse loop.
private static Result[] parseLoop(AtomicLong counter, long fileEnd, long fileStart, List<Result> collectedResults) {
Result[] results = new Result[HASH_TABLE_SIZE];

while (true) {
long current = counter.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE;

if (current >= fileEnd) {
return results;
}

long segmentEnd = nextNewLine(Math.min(fileEnd - 1, current + SEGMENT_SIZE));
long segmentStart;
if (current == fileStart) {
segmentStart = current;
}
else {
segmentStart = nextNewLine(current) + 1;
}

long dist = (segmentEnd - segmentStart) / 3;
long midPoint1 = nextNewLine(segmentStart + dist);
long midPoint2 = nextNewLine(segmentStart + dist + dist);

Scanner scanner1 = new Scanner(segmentStart, midPoint1);
Scanner scanner2 = new Scanner(midPoint1 + 1, midPoint2);
Scanner scanner3 = new Scanner(midPoint2 + 1, segmentEnd);
while (true) {
if (!scanner1.hasNext()) {
break;
}
if (!scanner2.hasNext()) {
break;
}
if (!scanner3.hasNext()) {
break;
}

long word1 = scanner1.getLong();
long word2 = scanner2.getLong();
long word3 = scanner3.getLong();
long pos1 = findDelimiter(word1);
long pos2 = findDelimiter(word2);
long pos3 = findDelimiter(word3);
Result existingResult1 = findResult(word1, pos1, scanner1, results, collectedResults);
Result existingResult2 = findResult(word2, pos2, scanner2, results, collectedResults);
Result existingResult3 = findResult(word3, pos3, scanner3, results, collectedResults);
long number1 = scanNumber(scanner1);
long number2 = scanNumber(scanner2);
long number3 = scanNumber(scanner3);
record(existingResult1, number1);
record(existingResult2, number2);
record(existingResult3, number3);
}

while (scanner1.hasNext()) {
long word = scanner1.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1));
}

while (scanner2.hasNext()) {
long word = scanner2.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2));
}

while (scanner3.hasNext()) {
long word = scanner3.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3));
}
}
}

private static long scanNumber(Scanner scanPtr) {
scanPtr.add(1);
long numberWord = scanPtr.getLong();
int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000);
long numberWord = scanPtr.getLongAt(scanPtr.pos() + 1);
int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000L);
long number = convertIntoNumber(decimalSepPos, numberWord);
scanPtr.add((decimalSepPos >>> 3) + 3);
scanPtr.add((decimalSepPos >>> 3) + 4);
return number;
}
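The expression `Long.numberOfTrailingZeros(~numberWord & 0x10101000L)` locates the decimal point without a loop: ASCII digits (0x30 to 0x39) have bit 4 set, while `.` (0x2E) does not, and the mask tests bit 4 of bytes 1, 2 and 3 only. The sketch below reproduces that step on its own. The class `DecimalSeparatorDemo` and the helper `toWord` are hypothetical, and the input is assumed to follow the challenge format of one to two integer digits, a dot and one fractional digit, with an optional leading minus.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

final class DecimalSeparatorDemo {
    // Packs up to 8 ASCII bytes into a long the way a little-endian
    // 8-byte load would see them (first character in the lowest byte).
    static long toWord(String s) {
        byte[] bytes = Arrays.copyOf(s.getBytes(StandardCharsets.US_ASCII), 8);
        return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getLong();
    }

    // Bit position of the first byte among bytes 1..3 whose bit 4 is clear,
    // which is the '.' for inputs in the assumed format.
    static int decimalSepPos(long numberWord) {
        return Long.numberOfTrailingZeros(~numberWord & 0x10101000L);
    }

    // Bytes from the ';' to the start of the next line: the ';' itself,
    // the bytes before the '.', the '.', one fractional digit and the newline.
    static int bytesToNextLine(int decimalSepPos) {
        return (decimalSepPos >>> 3) + 4;
    }

    public static void main(String[] args) {
        for (String s : new String[] { "1.2\n", "12.3\n", "-1.2\n", "-12.3\n" }) {
            int pos = decimalSepPos(toWord(s));
            System.out.println(s.trim() + " -> bit " + pos + ", byte " + (pos >>> 3));
        }
    }
}
```

The `+ 4` in `bytesToNextLine` matches the updated `scanPtr.add(...)` call: the scanner now still points at the `;` when `scanNumber` starts, so one more byte has to be skipped than in the previous version.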

@@ -316,10 +301,6 @@ private static int hashToIndex(long hash, Result[] results) {
return (int) (hashAsInt & (results.length - 1));
}
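The final `& (results.length - 1)` only yields a valid, evenly spread index when the table length is a power of two. A small sketch of that property follows. The class name and the table size of `1 << 17` are assumptions for illustration, since the value of `HASH_TABLE_SIZE` is not shown in this diff.

```java
final class PowerOfTwoIndex {
    // Maps any int hash to [0, tableLength). Correct only when
    // tableLength is a power of two, because then (tableLength - 1)
    // is a mask of all ones in the low bits.
    static int toIndex(int hash, int tableLength) {
        return hash & (tableLength - 1);
    }

    public static void main(String[] args) {
        int tableLength = 1 << 17; // hypothetical size
        for (int hash : new int[] { 0, 1, -1, 123_456_789, Integer.MIN_VALUE }) {
            System.out.println(hash + " -> " + toIndex(hash, tableLength));
        }
    }
}
```

For power-of-two lengths the mask gives the same result as `Math.floorMod(hash, tableLength)`, including for negative hashes, without a division.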

private static long mask(long word, long pos) {
return (word << ((7 - pos) << 3));
}

// Special method, created by Quan Anh Mai, that converts a number given as ASCII text into an integer without branches.
private static long convertIntoNumber(int decimalSepPos, long numberWord) {
int shift = 28 - decimalSepPos;
@@ -337,14 +318,7 @@ private static long convertIntoNumber(int decimalSepPos, long numberWord) {

private static long findDelimiter(long word) {
long input = word ^ 0x3B3B3B3B3B3B3B3BL;
long tmp = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L;
return tmp;
}

private static long findNewLine(long word) {
long input = word ^ 0x0A0A0A0A0A0A0A0AL;
long tmp = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L;
return tmp;
return (input - 0x0101010101010101L) & ~input & 0x8080808080808080L;
}
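`findDelimiter` tests all eight bytes of a word for `;` at once, and the new `findResult` fast path uses the resulting mask both to advance the scanner and to shift away the bytes after the delimiter. The following sketch runs those steps on small strings. `SwarDelimiterDemo`, `toWord`, `delimiterIndex` and `keepThroughDelimiter` are hypothetical names. Only the bit expressions are taken from the diff.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

final class SwarDelimiterDemo {
    // Packs up to 8 ASCII bytes into a long as a little-endian load would.
    static long toWord(String s) {
        byte[] bytes = Arrays.copyOf(s.getBytes(StandardCharsets.US_ASCII), 8);
        return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getLong();
    }

    // XOR turns every ';' byte into 0x00; the subtract-and-mask step then
    // sets the high bit of zero bytes. Bytes above the first match can be
    // flagged spuriously by the borrow, so only the lowest set bit is trusted.
    static long findDelimiter(long word) {
        long input = word ^ 0x3B3B3B3B3B3B3B3BL;
        return (input - 0x0101010101010101L) & ~input & 0x8080808080808080L;
    }

    // Byte index of the first ';' in the word, or -1 if there is none.
    static int delimiterIndex(long word) {
        long mask = findDelimiter(word);
        return mask == 0 ? -1 : Long.numberOfTrailingZeros(mask) >>> 3;
    }

    // Shifts out every byte after the delimiter, so two words that agree
    // up to and including ';' compare equal regardless of what follows.
    static long keepThroughDelimiter(long word, long delimiterMask) {
        int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
        return word << (63 - trailingZeros);
    }

    public static void main(String[] args) {
        long a = toWord("Rome;9.1");
        long b = toWord("Rome;-12");
        System.out.println(delimiterIndex(a));
        System.out.println(keepThroughDelimiter(a, findDelimiter(a)) == keepThroughDelimiter(b, findDelimiter(b)));
    }
}
```

Since the shifted word still contains the `;`, a station name of up to seven bytes is identified by a single `long` comparison, which is what the `existingResult.lastNameLong == word` check in the fast path relies on.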

private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List<Result> collectedResults) {
@@ -357,14 +331,13 @@ private static Result newEntry(Result[] results, long nameAddress, int hash, int
r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8);
}
int remainingShift = (64 - (nameLength + 1 - i) << 3);
long lastWord = (scanner.getLongAt(nameAddress + i) << remainingShift);
r.lastNameLong = lastWord;
r.lastNameLong = (scanner.getLongAt(nameAddress + i) << remainingShift);
r.nameAddress = nameAddress;
collectedResults.add(r);
return r;
}

private static class Result {
private static final class Result {
long lastNameLong, secondLastNameLong;
short min, max;
int count;
@@ -409,9 +382,10 @@ public String calcName() {
}
}

private static class Scanner {
private static final class Scanner {
private static final sun.misc.Unsafe UNSAFE = initUnsafe();
private long pos, end;
private long pos;
private final long end;

private static sun.misc.Unsafe initUnsafe() {
try {
@@ -452,13 +426,5 @@ long getLongAt(long pos) {
byte getByteAt(long pos) {
return UNSAFE.getByte(pos);
}

long getLongAt(long pos, long[] array) {
return UNSAFE.getLong(array, pos + sun.misc.Unsafe.ARRAY_LONG_BASE_OFFSET);
}

void setPos(long l) {
this.pos = l;
}
}
}