diff --git a/docs/api/spell.rst b/docs/api/spell.rst
index fa36e5779..ce2dd035d 100644
--- a/docs/api/spell.rst
+++ b/docs/api/spell.rst
@@ -19,6 +19,12 @@ correct_sent
 The `correct_sent` function is an extension of the `correct` function and is used to correct an entire sentence. It tokenizes the input sentence, corrects each word, and returns the corrected sentence. This is beneficial for proofreading and improving the readability of Thai text.
 
+get_words_spell_suggestion
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: get_words_spell_suggestion
+
+The `get_words_spell_suggestion` function retrieves spelling suggestions for one or more input Thai words. It requires the optional `numpy` and `onnxruntime` packages.
+
 spell
 ~~~~~
 .. autofunction:: spell
diff --git a/pythainlp/spell/__init__.py b/pythainlp/spell/__init__.py
index 65c9979f8..625740305 100644
--- a/pythainlp/spell/__init__.py
+++ b/pythainlp/spell/__init__.py
@@ -13,6 +13,7 @@
     "correct_sent",
     "spell",
     "spell_sent",
+    "get_words_spell_suggestion",
 ]
 
 from pythainlp.spell.pn import NorvigSpellChecker
@@ -21,3 +22,4 @@
 
 # these imports are placed here to avoid circular imports
 from pythainlp.spell.core import correct, correct_sent, spell, spell_sent
+from pythainlp.spell.words_spelling_correction import get_words_spell_suggestion
diff --git a/pythainlp/spell/words_spelling_correction.py b/pythainlp/spell/words_spelling_correction.py
new file mode 100644
index 000000000..e2029d9f4
--- /dev/null
+++ b/pythainlp/spell/words_spelling_correction.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+# SPDX-FileCopyrightText: 2016-2025 PyThaiNLP Project
+# SPDX-FileType: SOURCE
+# SPDX-License-Identifier: Apache-2.0
+import os
+from typing import List, Union
+
+from pythainlp.corpus import get_hf_hub
+
+
+class FastTextEncoder:
+    """
+    Loads pre-trained FastText-like word embeddings, computes word and
+    sentence vectors, and queries an ONNX model for nearest-neighbor
+    suggestions.
+    """
+
+    # --- Initialization and Data Loading ---
+
+    def __init__(self, model_dir, nn_model_path, words_list, bucket=2000000, nb_words=2000000, minn=5, maxn=5):
+        """
+        Initializes the FastTextEncoder, loading embeddings, vocabulary,
+        the nearest-neighbor model, and the suggestion word list.
+
+        Args:
+            model_dir (str): Directory containing 'embeddings.npy' and 'vocabulary.txt'.
+            nn_model_path (str): Path to the ONNX nearest-neighbors model.
+            words_list (List[str]): The list of words used for suggestions.
+            bucket (int): The size of the hash bucket for subword hashing.
+            nb_words (int): The number of words in the vocabulary (used as an offset for subword indices).
+            minn (int): Minimum character length for subwords.
+            maxn (int): Maximum character length for subwords.
+        """
+        try:
+            # Imported lazily so these optional dependencies are only
+            # required when this feature is actually used.
+            import numpy as np
+            import onnxruntime  # noqa: F401 -- fail early if missing
+            self.np = np
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "Please install the required packages via 'pip install numpy onnxruntime'."
+            )
+        self.model_dir = model_dir
+        self.nn_model_path = nn_model_path
+        self.bucket = bucket
+        self.nb_words = nb_words
+        self.minn = minn
+        self.maxn = maxn
+
+        # Load data and models
+        self.vocabulary, self.embeddings = self._load_embeddings()
+        self.words_for_suggestion = self._load_suggestion_words(words_list)
+        self.nn_session = self._load_onnx_session(nn_model_path)
+        self.embedding_dim = self.embeddings.shape[1]
+
+    def _load_embeddings(self):
+        """Loads the embeddings matrix and the vocabulary list."""
+        input_matrix = self.np.load(os.path.join(self.model_dir, "embeddings.npy"))
+        words = []
+        vocab_path = os.path.join(self.model_dir, "vocabulary.txt")
+        with open(vocab_path, "r", encoding="utf-8") as f:
+            for line in f:
+                words.append(line.rstrip())
+        return words, input_matrix
+
+    def _load_suggestion_words(self, words_list):
+        """Loads the list of words used for suggestions."""
+        return self.np.array(words_list)
+
+    def _load_onnx_session(self, onnx_path):
+        """Loads the ONNX inference session."""
+        import onnxruntime as rt
+
+        # Note: CPUExecutionProvider is used for platform independence.
+        return rt.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
+
+    # --- Helper Methods for Encoding ---
+
+    def _get_hash(self, subword):
+        """Computes the FastText-like FNV-1a hash for a subword."""
+        h = 2166136261  # FNV-1a offset basis
+        for c in subword:
+            c_ord = ord(c) % 2**8
+            h = (h ^ c_ord) % 2**32
+            h = (h * 16777619) % 2**32  # FNV-1a prime
+        return h % self.bucket + self.nb_words
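+
+    # Worked sketch of the index layout implied by _get_hash (illustrative,
+    # using this class's defaults): with nb_words=2000000 and bucket=2000000,
+    # rows 0..1999999 of `embeddings` hold full-word vectors, while every
+    # hashed subword lands in rows 2000000..3999999, so full-word rows and
+    # hashed subword rows never collide.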
+ """) + except Exception as e: + raise Exception(f"An unexpected error occurred: {e}") + self.model_dir = model_dir + self.nn_model_path = nn_model_path + self.bucket = bucket + self.nb_words = nb_words + self.minn = minn + self.maxn = maxn + + # Load data and models + self.vocabulary, self.embeddings = self._load_embeddings() + self.words_for_suggestion = self._load_suggestion_words(words_list) + self.nn_session = self._load_onnx_session(nn_model_path) + self.embedding_dim = self.embeddings.shape[1] + + def _load_embeddings(self): + """Loads embeddings matrix and vocabulary list.""" + input_matrix = self.np.load(os.path.join(self.model_dir, "embeddings.npy")) + words = [] + vocab_path = os.path.join(self.model_dir, "vocabulary.txt") + with open(vocab_path, "r", encoding='utf-8') as f: + for line in f.readlines(): + words.append(line.rstrip()) + return words, input_matrix + + def _load_suggestion_words(self, words_list): + """Loads the list of words used for suggestions.""" + words = self.np.array(words_list) + return words + + def _load_onnx_session(self, onnx_path): + """Loads the ONNX inference session.""" + # Note: Using providers=["CPUExecutionProvider"] for platform independence + import onnxruntime as rt + sess = rt.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) + return sess + + # --- Helper Methods for Encoding --- + + def _get_hash(self, subword): + """Computes the FastText-like hash for a subword.""" + h = 2166136261 # FNV-1a basis + for c in subword: + c_ord = ord(c) % 2**8 + h = (h ^ c_ord) % 2**32 + h = (h * 16777619) % 2**32 # FNV-1a prime + return h % self.bucket + self.nb_words + + def _get_subwords(self, word): + """Extracts subwords and their corresponding indices for a given word.""" + _word = "<" + word + ">" + _subwords = [] + _subword_ids = [] + + # 1. Check for the word in vocabulary (full word is the first subword) + if word in self.vocabulary: + _subwords.append(word) + _subword_ids.append(self.vocabulary.index(word)) + if word == "": + return _subwords, self.np.array(_subword_ids) + + # 2. Extract n-grams (subwords) and get their hash indices + for ngram_start in range(0, len(_word)): + for ngram_length in range(self.minn, self.maxn + 1): + if ngram_start + ngram_length <= len(_word): + _candidate_subword = _word[ngram_start:ngram_start + ngram_length] + # Only append if not already included (e.g., as the full word) + if _candidate_subword not in _subwords: + _subwords.append(_candidate_subword) + _subword_ids.append(self._get_hash(_candidate_subword)) + + return _subwords, self.np.array(_subword_ids) + + def get_word_vector(self, word): + """Computes the normalized vector for a single word.""" + # subword_ids[1] contains the array of indices for the word and its subwords + subword_ids = self._get_subwords(word)[1] + + # Check if the array of subword indices is empty + if subword_ids.size == 0: + # Return a 300-dimensional zero vector if no word/subword is found. 
+
+    def _tokenize(self, sentence):
+        """Tokenizes a sentence on whitespace."""
+        tokens = []
+        word = ""
+        for c in sentence:
+            if c in [' ', '\n', '\r', '\t', '\v', '\f', '\0']:
+                if word:
+                    tokens.append(word)
+                    word = ""
+                if c == '\n':
+                    tokens.append("")
+            else:
+                word += c
+        if word:
+            tokens.append(word)
+        return tokens
+
+    def get_sentence_vector(self, line):
+        """Computes the mean vector of a sentence."""
+        tokens = self._tokenize(line)
+        vectors = []
+        for t in tokens:
+            # get_word_vector already normalizes, so no need to do it again here
+            vectors.append(self.get_word_vector(t))
+
+        # If the sentence was empty and produced no vectors, return a zero vector
+        if not vectors:
+            return self.np.zeros(self.embedding_dim)
+
+        return self.np.mean(vectors, axis=0)
+
+    # --- Nearest Neighbor Method ---
+
+    def get_word_suggestion(self, list_word):
+        """
+        Queries the ONNX model to find the nearest-neighbor word(s)
+        for the given word or list of words.
+
+        Args:
+            list_word (str or list of str): A single word or a list of words
+                to get suggestions for.
+
+        Returns:
+            list of str or list of list of str: The nearest-neighbor words
+            from the pre-loaded suggestion list.
+        """
+        if isinstance(list_word, str):
+            input_words = [list_word]
+            return_single = True
+        else:
+            input_words = list_word
+            return_single = False
+
+        # Encode each input word as a "sentence" of single characters
+        # (char2vec style): ' '.join(list(word)) turns "คมดี" into
+        # "ค ม ด ี", so each character is treated as one token.
+        word_input_vecs = [self.get_sentence_vector(' '.join(list(word))) for word in input_words]
+
+        # Convert to a numpy array for ONNX input (must be float32)
+        input_data = self.np.array(word_input_vecs, dtype=self.np.float32)
+
+        # Run ONNX inference
+        indices = self.nn_session.run(None, {"X": input_data})[0]
+
+        # Look up suggestions
+        suggestions = [self.words_for_suggestion[i].tolist() for i in indices]
+
+        return suggestions[0] if return_single else suggestions
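+
+
+# Assumed ONNX model contract, inferred from the calls in get_word_suggestion
+# rather than from a published model card: input "X" is float32 of shape
+# [n_words, embedding_dim]; the first output holds neighbor indices of shape
+# [n_words, k], where k appears to be 5 (matching the examples below).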
+
+
+class Words_Spelling_Correction(FastTextEncoder):
+    def __init__(self):
+        self.model_name = "pythainlp/word-spelling-correction-char2vec"
+        self.model_path = get_hf_hub(self.model_name)
+        self.model_onnx = get_hf_hub(self.model_name, "nearest_neighbors.onnx")
+        words_path = get_hf_hub(self.model_name, "list_word-spelling-correction-char2vec.txt")
+        with open(words_path, encoding="utf-8") as f:
+            self.list_word = [i.strip() for i in f]
+        super().__init__(self.model_path, self.model_onnx, self.list_word)
+
+
+_WSC = None
+
+
+def get_words_spell_suggestion(list_words: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
+    """
+    Get spelling suggestions for words.
+
+    This function retrieves spelling suggestions \
+    for one or more input Thai words.
+
+    Requirements: numpy and onnxruntime (install them before using this function).
+
+    :param Union[str, List[str]] list_words: a word or a list of words
+    :return: spelling suggestions (at most 5 items per word)
+    :rtype: Union[List[str], List[List[str]]]
+
+    :Example:
+    ::
+
+        from pythainlp.spell import get_words_spell_suggestion
+
+        print(get_words_spell_suggestion("คมดี"))
+        # output: ['คนดีผีคุ้ม', 'มีดคอม้า', 'คดี', 'มีดสองคม', 'มูลคดี']
+
+        print(get_words_spell_suggestion(["คมดี", "กระเพาะ"]))
+        # output: [['คนดีผีคุ้ม', 'มีดคอม้า', 'คดี', 'มีดสองคม', 'มูลคดี'],
+        #          ['กระเพาะ', 'กระพา', 'กะเพรา', 'กระเพาะปลา', 'พระประธาน']]
+    """
+    global _WSC
+    if _WSC is None:
+        _WSC = Words_Spelling_Correction()
+    return _WSC.get_word_suggestion(list_words)
diff --git a/tests/extra/testx_spell.py b/tests/extra/testx_spell.py
index 14c90ae32..f15460486 100644
--- a/tests/extra/testx_spell.py
+++ b/tests/extra/testx_spell.py
@@ -11,6 +11,7 @@
     spell,
     spell_sent,
     symspellpy,
+    get_words_spell_suggestion,
 )
 
 from ..core.test_spell import SENT_TOKS
@@ -66,3 +67,7 @@ def test_correct_sent(self):
             correct_sent(SENT_TOKS, engine="wanchanberta_thai_grammarly")
         )
         self.assertIsNotNone(symspellpy.correct_sent(SENT_TOKS))
+
+    def test_get_words_spell_suggestion(self):
+        self.assertIsNotNone(get_words_spell_suggestion("คมดี"))
+        self.assertIsNotNone(get_words_spell_suggestion(["คมดี", "มะนา"]))
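+        # Additional sanity checks (assumption: the docstring's "at most 5
+        # items per word" holds; this is not otherwise documented):
+        self.assertLessEqual(len(get_words_spell_suggestion("คมดี")), 5)
+        self.assertEqual(len(get_words_spell_suggestion(["คมดี", "มะนา"])), 2)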