Skip to content

Commit 5abb308

Browse files
authored
Merge pull request #781 from PyThaiNLP/add-thainer-v2
Add Thai NER 2.0
2 parents 5bfa831 + 84c40d9 commit 5abb308

File tree

6 files changed

+112
-16
lines changed

6 files changed

+112
-16
lines changed

docs/api/wangchanberta.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ Notebook:
3030

3131
Modules
3232
-------
33+
.. autoclass:: NamedEntityRecognition
34+
:members:
35+
.. autoclass:: ThaiNameTagger
36+
:members:
3337
.. autofunction:: segment
3438

3539
References

pythainlp/tag/named_entity.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class NER:
1616
**Options for engine**
1717
* *thainer* - Thai NER engine
1818
* *tltk* - wrapper for `TLTK <https://pypi.org/project/tltk/>`_.
19+
* *thainer-v2* - Thai NER engine v2.0 for Thai NER 2.0
1920
2021
**Options for corpus**
2122
* *thainer* - Thai NER corpus
@@ -33,6 +34,9 @@ def load_engine(self, engine: str, corpus: str) -> None:
3334
from pythainlp.tag.thainer import ThaiNameTagger
3435

3536
self.engine = ThaiNameTagger()
37+
elif engine == "thainer-v2" and corpus == "thainer":
38+
from pythainlp.wangchanberta import NamedEntityRecognition
39+
self.engine = NamedEntityRecognition(model="pythainlp/thainer-corpus-v2-base-model")
3640
elif engine == "tltk":
3741
from pythainlp.tag import tltk
3842

@@ -49,7 +53,7 @@ def load_engine(self, engine: str, corpus: str) -> None:
4953
)
5054

5155
def tag(
52-
self, text, pos=True, tag=False
56+
self, text, pos=False, tag=False
5357
) -> Union[List[Tuple[str, str]], List[Tuple[str, str, str]], str]:
5458
"""
5559
This function tags named entities from text in IOB format.
@@ -71,13 +75,13 @@ def tag(
7175
>>>
7276
>>> ner = NER("thainer")
7377
>>> ner.tag("ทดสอบนายวรรณพงษ์ ภัททิยไพบูลย์")
74-
[('ทดสอบ', 'VV', 'O'),
75-
('นาย', 'NN', 'B-PERSON'),
76-
('วรรณ', 'NN', 'I-PERSON'),
77-
('พงษ์', 'NN', 'I-PERSON'),
78-
(' ', 'PU', 'I-PERSON'),
79-
('ภัททิย', 'NN', 'I-PERSON'),
80-
('ไพบูลย์', 'NN', 'I-PERSON')]
78+
[('ทดสอบ', 'O'),
79+
('นาย', 'B-PERSON'),
80+
('วรรณ', 'I-PERSON'),
81+
('พงษ์', 'I-PERSON'),
82+
(' ', 'I-PERSON'),
83+
('ภัททิย', 'I-PERSON'),
84+
('ไพบูลย์', 'I-PERSON')]
8185
>>> ner.tag("ทดสอบนายวรรณพงษ์ ภัททิยไพบูลย์", tag=True)
8286
'ทดสอบ<PERSON>นายวรรณพงษ์ ภัททิยไพบูลย์</PERSON>'
8387
"""

pythainlp/tag/thainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,11 @@ def _doc2features(doc, i) -> Dict:
7373

7474
class ThaiNameTagger:
7575
"""
76-
Thai named-entity recognizer.
76+
Thai named-entity recognizer or Thai NER.
77+
This class supports Thai NER 1.4 and 1.5 only.
7778
:param str version: Thai NER version.
7879
It supports Thai NER 1.4 & 1.5.
79-
The defualt value is `1.4`
80+
The default value is `1.4`
8081
8182
:Example:
8283
::

pythainlp/wangchanberta/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
__all__ = [
33
"ThaiNameTagger",
44
"segment",
5+
"NamedEntityRecognition",
56
]
67

7-
from pythainlp.wangchanberta.core import ThaiNameTagger, segment
8+
from pythainlp.wangchanberta.core import ThaiNameTagger, segment, NamedEntityRecognition

pythainlp/wangchanberta/core.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
CamembertTokenizer,
66
pipeline,
77
)
8+
import warnings
9+
from pythainlp.tokenize import word_tokenize
810

911
_model_name = "wangchanberta-base-att-spm-uncased"
1012
_tokenizer = CamembertTokenizer.from_pretrained(
@@ -48,7 +50,7 @@ def _clear_tag(self, tag):
4850
return tag.replace("B-", "").replace("I-", "")
4951

5052
def get_ner(
51-
self, text: str, tag: bool = False
53+
self, text: str, pos: bool= False,tag: bool = False
5254
) -> Union[List[Tuple[str, str]], str]:
5355
"""
5456
This function tags named entities from text in IOB format.
@@ -64,6 +66,8 @@ def get_ner(
6466
word and NER tag
6567
:rtype: Union[list[tuple[str, str]]], str
6668
"""
69+
if pos:
70+
warnings.warn("This model doesn't support output postag and It doesn't output the postag.")
6771
text = re.sub(" ", "<_>", text)
6872
self.json_ner = self.classify_tokens(text)
6973
self.output = ""
@@ -121,6 +125,86 @@ def get_ner(
121125
return self.sent_ner
122126

123127

128+
class NamedEntityRecognition:
    """
    Named-entity tagger backed by a token-classification model.

    Powered by wangchanberta from VISTEC-depa
    AI Research Institute of Thailand.

    The tokenizer and the model are both loaded with ``from_pretrained``
    using the name passed to the constructor.
    """

    def __init__(self, model: str = "pythainlp/thainer-corpus-v2-base-model") -> None:
        """
        Load the tokenizer and the token-classification model.

        :param str model: model name passed to ``from_pretrained`` for both
            the tokenizer and the model.
        """
        from transformers import AutoTokenizer
        from transformers import AutoModelForTokenClassification

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForTokenClassification.from_pretrained(model)

    def _fix_span_error(self, words, ner):
        """
        Pair every decoded token with its tag and clean up the sequence.

        :param words: token ids; each one is decoded with
            ``self.tokenizer.decode``
        :param ner: predicted tag for each id, in the same order
        :return: list of ``(token, tag)`` tuples
        """
        cleaned = []
        for token_id, label in zip(words, ner):
            token = self.tokenizer.decode(token_id)
            # A whitespace-only token must not open an entity span.
            if token.isspace() and label.startswith("B-"):
                label = "O"
            # Drop empty pieces and the sentence-boundary markers.
            if token in ("", "<s>", "</s>"):
                continue
            # "<_>" is the placeholder get_ner substitutes for spaces.
            if token == "<_>":
                token = " "
            cleaned.append((token, label))
        return cleaned

    @staticmethod
    def _to_tagged_text(ner_tag) -> str:
        """
        Render ``(word, tag)`` pairs as text with HTML-like entity tags.

        :param ner_tag: iterable of ``(word, tag)`` pairs in IOB format
        :return: the words joined together, with ``<TYPE>`` inserted before
            each ``B-TYPE`` word and ``</TYPE>`` where the span ends
        """
        pieces = []
        open_tag = ""
        for word, label in ner_tag:
            if label.startswith("B-"):
                # A new span begins; close the previous one first if needed.
                if open_tag:
                    pieces.append("</" + open_tag + ">")
                open_tag = label[2:]
                pieces.append("<" + open_tag + ">")
            elif label == "O" and open_tag:
                pieces.append("</" + open_tag + ">")
                open_tag = ""
            pieces.append(word)
        # Close a span that is still open after the last word.
        if open_tag:
            pieces.append("</" + open_tag + ">")
        return "".join(pieces)

    def get_ner(
        self, text: str, pos: bool = False, tag: bool = False
    ) -> Union[List[Tuple[str, str]], str]:
        """
        Tag named entities in ``text`` in IOB format.

        :param str text: text in Thai to be tagged
        :param bool pos: not supported by this model; when true a warning
            is issued and the output is unchanged
        :param bool tag: output like html tag.
        :return: a string with HTML-like entity tags if ``tag`` is true,
            otherwise a list of ``(word, NER tag)`` tuples
        :rtype: Union[List[Tuple[str, str]], str]
        """
        import torch

        if pos:
            warnings.warn(
                "This model doesn't support output postag "
                "and It doesn't output the postag."
            )
        # Spaces are replaced by a placeholder before word tokenization;
        # _fix_span_error turns the placeholder back into a space.
        words_token = word_tokenize(text.replace(" ", "<_>"))
        inputs = self.tokenizer(
            words_token, is_split_into_words=True, return_tensors="pt"
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            outputs = self.model(ids, attention_mask=mask)
        logits = outputs[0]
        predictions = torch.argmax(logits, dim=2)
        predicted_token_class = [
            self.model.config.id2label[t.item()] for t in predictions[0]
        ]
        ner_tag = self._fix_span_error(ids[0], predicted_token_class)
        if tag:
            return self._to_tagged_text(ner_tag)
        return ner_tag
206+
207+
124208
def segment(text: str) -> List[str]:
125209
"""
126210
Subword tokenize. SentencePiece from wangchanberta model.

tests/test_tag.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,8 @@ def test_ner(self):
213213
)
214214

215215
# argument `tag` is True
216-
self.assertEqual(
217-
ner.get_ner("วันที่ 15 ก.ย. 61 ทดสอบระบบเวลา 14:49 น.", tag=True),
218-
"วันที่ <DATE>15 ก.ย. 61</DATE> "
219-
"ทดสอบระบบเวลา <TIME>14:49 น.</TIME>",
216+
self.assertIsNotNone(
217+
ner.get_ner("วันที่ 15 ก.ย. 61 ทดสอบระบบเวลา 14:49 น.", tag=True)
220218
)
221219

222220
ner = ThaiNameTagger(version="1.4")
@@ -352,6 +350,10 @@ def test_NER_class(self):
352350
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า"))
353351
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", pos=False))
354352
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", tag=True))
353+
ner = NER(engine="thainer-v2")
354+
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า"))
355+
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", pos=False))
356+
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", tag=True))
355357
ner = NER(engine="tltk")
356358
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า"))
357359
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", pos=False))

0 commit comments

Comments
 (0)