Skip to content

Commit

Permalink
migiate nan float issue on certain platform by introduce ...Raw() var…
Browse files Browse the repository at this point in the history
…iant function to DATrie

On certain platform, nan may be canonicalized to a different value. So
DATrie<float> may need to be very careful to do comparison for check
value. So instead, we introduce Raw() function to directly return the
internal representation int32_t for value comparison. And only decode
when the value is valid.
  • Loading branch information
wengxt committed Dec 20, 2024
1 parent 5bb49ec commit 7650dee
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 49 deletions.
85 changes: 55 additions & 30 deletions src/libime/core/datrie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ struct NanValue<float> {
}
};

template <typename T>
constexpr inline T decodeImpl(int32_t raw) {
typename DATriePrivate<T>::decoder_type d;
d.result = raw;
return d.result_value;
}

} // namespace

#if 0
Expand Down Expand Up @@ -352,16 +359,6 @@ class DATriePrivate {
to = from;
}
}
value_type traverse(const char *key, npos_t &from, size_t &pos) const {
return traverse(key, from, pos, std::strlen(key));
}

value_type traverse(const char *key, npos_t &from, size_t &pos,
size_t len) const {
decoder_type d;
d.result = _find(key, from, pos, len);
return d.result_value;
}

template <typename U>
inline void update(const char *key, const U &callback) {
Expand Down Expand Up @@ -534,13 +531,13 @@ class DATriePrivate {
}

bool foreach(const callback_type &callback, npos_t root = npos_t()) const {
decoder_type b;
int32_t resultRaw;
size_t p(0);
npos_t from = root;
for (b.result = begin(from, p); b.result != CEDAR_NO_PATH;
b.result = next(from, p, root)) {
if (b.result != CEDAR_NO_VALUE &&
!callback(b.result_value, p, from.toInt())) {
for (resultRaw = begin(from, p); resultRaw != CEDAR_NO_PATH;
resultRaw = next(from, p, root)) {
if (resultRaw != CEDAR_NO_VALUE &&
!callback(decodeImpl<V>(resultRaw), p, from.toInt())) {
return false;
}
}
Expand Down Expand Up @@ -1111,14 +1108,18 @@ bool DATrie<T>::erase(position_type from) {
template <typename T>
typename DATrie<T>::value_type DATrie<T>::exactMatchSearch(const char *key,
size_t len) const {
return decodeImpl<T>(exactMatchSearchRaw(key, len));
}

template <typename T>
int32_t DATrie<T>::exactMatchSearchRaw(const char *key, size_t len) const {
size_t pos = 0;
typename DATriePrivate<value_type>::npos_t npos;
typename DATriePrivate<T>::decoder_type decoder;
decoder.result = d->_find(key, npos, pos, len);
if (decoder.result == DATriePrivate<value_type>::CEDAR_NO_PATH) {
decoder.result = DATriePrivate<value_type>::CEDAR_NO_VALUE;
auto resultRaw = d->_find(key, npos, pos, len);
if (resultRaw == DATriePrivate<value_type>::CEDAR_NO_PATH) {
resultRaw = DATriePrivate<value_type>::CEDAR_NO_VALUE;
}
return decoder.result_value;
return resultRaw;
}

template <typename T>
Expand All @@ -1129,9 +1130,15 @@ bool DATrie<T>::hasExactMatch(std::string_view key) const {
template <typename T>
typename DATrie<T>::value_type DATrie<T>::traverse(const char *key, size_t len,
position_type &from) const {
return decodeImpl<T>(traverseRaw(key, len, from));
}

template <typename T>
int32_t DATrie<T>::traverseRaw(const char *key, size_t len,
position_type &from) const {
size_t pos = 0;
typename DATriePrivate<T>::npos_t npos(from);
auto result = d->traverse(key, npos, pos, len);
auto result = d->_find(key, npos, pos, len);
from = npos.toInt();
return result;
}
Expand All @@ -1150,33 +1157,51 @@ template <typename T>
bool DATrie<T>::isNoPath(value_type v) {
typename DATriePrivate<T>::decoder_type d;
d.result_value = v;
return d.result == DATriePrivate<value_type>::CEDAR_NO_PATH;
return isNoPathRaw(d.result);
}

template <typename T>
bool DATrie<T>::isNoValue(value_type v) {
typename DATriePrivate<T>::decoder_type d;
d.result_value = v;
return d.result == DATriePrivate<value_type>::CEDAR_NO_VALUE;
return isNoValueRaw(d.result);
}

template <typename T>
bool DATrie<T>::isValid(value_type v) {
return !(isNoPath(v) || isNoValue(v));
typename DATriePrivate<T>::decoder_type d;
d.result_value = v;
return isValidRaw(d.result);
}

template <typename T>
bool DATrie<T>::isNoPathRaw(int32_t v) {
return v == DATriePrivate<value_type>::CEDAR_NO_PATH;
}

template <typename T>
bool DATrie<T>::isNoValueRaw(int32_t v) {
return v == DATriePrivate<value_type>::CEDAR_NO_VALUE;
}

template <typename T>
bool DATrie<T>::isValidRaw(int32_t v) {
return !(isNoPathRaw(v) || isNoValueRaw(v));
}

template <typename T>
T DATrie<T>::noPath() {
typename DATriePrivate<T>::decoder_type d;
d.result = DATriePrivate<value_type>::CEDAR_NO_PATH;
return d.result_value;
return decodeImpl<T>(DATriePrivate<value_type>::CEDAR_NO_PATH);
}

template <typename T>
T DATrie<T>::noValue() {
typename DATriePrivate<T>::decoder_type d;
d.result = DATriePrivate<value_type>::CEDAR_NO_VALUE;
return d.result_value;
return decodeImpl<T>(DATriePrivate<value_type>::CEDAR_NO_VALUE);
}

template <typename T>
T DATrie<T>::decode(int32_t raw) {
return decodeImpl<T>(raw);
}

template <typename T>
Expand Down
16 changes: 16 additions & 0 deletions src/libime/core/datrie.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ class DATrie {
return exactMatchSearch(key.data(), key.size());
}

int32_t exactMatchSearchRaw(const char *key, size_t len) const;
int32_t exactMatchSearchRaw(std::string_view key) const {
return exactMatchSearchRaw(key.data(), key.size());
}

bool hasExactMatch(std::string_view key) const;

DATrie<T>::value_type traverse(std::string_view key,
Expand All @@ -99,6 +104,11 @@ class DATrie {
DATrie<T>::value_type traverse(const char *key, size_t len,
position_type &from) const;

int32_t traverseRaw(std::string_view key, position_type &from) const {
return traverseRaw(key.data(), key.size(), from);
}
int32_t traverseRaw(const char *key, size_t len, position_type &from) const;

// set value
void set(std::string_view key, value_type val) {
return set(key.data(), key.size(), val);
Expand Down Expand Up @@ -138,9 +148,15 @@ class DATrie {
static bool isNoPath(value_type v);
static bool isNoValue(value_type v);

static bool isValidRaw(int32_t v);
static bool isNoPathRaw(int32_t v);
static bool isNoValueRaw(int32_t v);

static value_type noPath();
static value_type noValue();

static value_type decode(int32_t raw);

size_t mem_size() const;

private:
Expand Down
24 changes: 13 additions & 11 deletions src/libime/pinyin/pinyindictionary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "pinyinmatchstate_p.h"
#include <boost/algorithm/string.hpp>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iomanip>
#include <optional>
Expand Down Expand Up @@ -158,10 +160,10 @@ inline void searchOneStep(
auto iter = nodes.begin();
while (iter != nodes.end()) {
if (current != 0) {
PinyinTrie::value_type result;
result = iter->first->traverse(&current, 1, iter->second);
const auto resultRaw =
iter->first->traverseRaw(&current, 1, iter->second);

if (PinyinTrie::isNoPath(result)) {
if (PinyinTrie::isNoPathRaw(resultRaw)) {
nodes.erase(iter++);
} else {
iter++;
Expand All @@ -171,8 +173,8 @@ inline void searchOneStep(
for (char test = PinyinEncoder::firstFinal;
test <= PinyinEncoder::lastFinal; test++) {
decltype(extraNodes)::value_type p = *iter;
auto result = p.first->traverse(&test, 1, p.second);
if (!PinyinTrie::isNoPath(result)) {
const auto resultRaw = p.first->traverseRaw(&test, 1, p.second);
if (!PinyinTrie::isNoPathRaw(resultRaw)) {
extraNodes.push_back(p);
changed = true;
}
Expand Down Expand Up @@ -371,8 +373,8 @@ PinyinTriePositions traverseAlongPathOneStepBySyllables(
// make a copy
auto pos = _pos;
auto initial = static_cast<char>(syl.first);
auto result = path.trie()->traverse(&initial, 1, pos);
if (PinyinTrie::isNoPath(result)) {
const auto resultRaw = path.trie()->traverseRaw(&initial, 1, pos);
if (PinyinTrie::isNoPathRaw(resultRaw)) {
continue;
}
const auto &finals = syl.second;
Expand All @@ -381,9 +383,9 @@ PinyinTriePositions traverseAlongPathOneStepBySyllables(
size_t fuzzyFactor,
auto pos) {
auto final = static_cast<char>(pyFinal);
auto result = path.trie()->traverse(&final, 1, pos);
const auto resultRaw = path.trie()->traverseRaw(&final, 1, pos);

if (!PinyinTrie::isNoPath(result)) {
if (!PinyinTrie::isNoPathRaw(resultRaw)) {
size_t newFuzzies = fuzzies + fuzzyFactor;
positions.emplace_back(pos, newFuzzies);
}
Expand Down Expand Up @@ -446,8 +448,8 @@ void matchWordsOnTrie(const PinyinTrie *userDict, const MatchedPinyinPath &path,
pos);
} else {
const char sep = pinyinHanziSep;
auto result = path.trie()->traverse(&sep, 1, pos);
if (PinyinTrie::isNoPath(result)) {
const auto resultRaw = path.trie()->traverseRaw(&sep, 1, pos);
if (PinyinTrie::isNoPathRaw(resultRaw)) {
continue;
}

Expand Down
17 changes: 13 additions & 4 deletions src/libime/pinyin/pinyindictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,29 @@
#define _FCITX_LIBIME_PINYIN_PINYINDICTIONARY_H_

#include "libimepinyin_export.h"
#include <cstddef>
#include <fcitx-utils/flags.h>
#include <fcitx-utils/macros.h>
#include <functional>
#include <istream>
#include <libime/core/dictionary.h>
#include <libime/core/segmentgraph.h>
#include <libime/core/triedictionary.h>
#include <libime/pinyin/pinyinencoder.h>
#include <memory>
#include <optional>
#include <ostream>
#include <string_view>
#include <unordered_set>

namespace libime {

enum class PinyinDictFormat { Text, Binary };

class PinyinDictionaryPrivate;

typedef std::function<bool(std::string_view encodedPinyin,
std::string_view hanzi, float cost)>
PinyinMatchCallback;
using PinyinMatchCallback =
std::function<bool(std::string_view, std::string_view, float)>;

using PinyinTrie = typename TrieDictionary::TrieType;

Expand Down Expand Up @@ -64,7 +73,7 @@ class LIBIMEPINYIN_EXPORT PinyinDictionary : public TrieDictionary {
void save(size_t idx, std::ostream &out, PinyinDictFormat format);

void addWord(size_t idx, std::string_view fullPinyin,
std::string_view hanzi, float cost = 0.0f);
std::string_view hanzi, float cost = 0.0F);
bool removeWord(size_t idx, std::string_view fullPinyin,
std::string_view hanzi);
std::optional<float> lookupWord(size_t idx, std::string_view fullPinyin,
Expand Down
5 changes: 3 additions & 2 deletions src/libime/table/tablebaseddictionary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <fcitx-utils/stringutils.h>
#include <fcitx-utils/utf8.h>
#include <fstream>
#include <iterator>
#include <set>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -275,7 +276,7 @@ bool TableBasedDictionaryPrivate::matchTrie(
auto curPos = position;
auto strCode = fcitx::utf8::UCS4ToUTF8(code);
auto result = trie.traverse(strCode, curPos);
if (!trie.isNoPath(result)) {
if (!DATrie<unsigned int>::isNoPath(result)) {
newPositions.push_back(curPos);
}
}
Expand All @@ -286,7 +287,7 @@ bool TableBasedDictionaryPrivate::matchTrie(
std::distance(charRange.first, charRange.second));
auto curPos = position;
auto result = trie.traverse(chr, curPos);
if (!trie.isNoPath(result)) {
if (!DATrie<unsigned int>::isNoPath(result)) {
newPositions.push_back(curPos);
}
}
Expand Down
5 changes: 3 additions & 2 deletions test/testtrie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* SPDX-License-Identifier: LGPL-2.1-or-later
*/
#include "libime/core/datrie.h"
#include <cstdint>
#include <cstring>
#include <fcitx-utils/log.h>

Expand Down Expand Up @@ -39,8 +40,8 @@ int main() {
FCITX_ASSERT(trie.size() == 4);
DATrie<float>::position_type pos = 0;
auto result = trie.traverse("aaa", pos);
auto nan1 = trie.noValue();
auto nan2 = trie.noPath();
auto nan1 = DATrie<float>::noValue();
auto nan2 = DATrie<float>::noPath();
// NaN != NaN, we must use memcmp to do this.
FCITX_ASSERT(memcmp(&nan1, &result, sizeof(float)) == 0);
FCITX_ASSERT(trie.isNoValue(result));
Expand Down

0 comments on commit 7650dee

Please sign in to comment.