Commit f1c1c71
Fix edge case where input str is removed by Hugging Face tokenizers
hankcs committed May 21, 2021
1 parent 40e59ed commit f1c1c71
Showing 3 changed files with 13 additions and 6 deletions.
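
The underlying issue: a Hugging Face fast tokenizer can normalize an input string away entirely (e.g. a lone combining mark), leaving no subtokens and no offsets for it. A minimal repro sketch, assuming transformers is installed and a checkpoint whose normalizer strips combining marks, as uncased BERT does:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)
# U+0361 is a combining mark; the normalizer removes it, so nothing is left to tokenize
print(tokenizer.tokenize('͡'))  # expected: []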
10 changes: 9 additions & 1 deletion hanlp/transform/transformer_tokenizer.py
@@ -195,7 +195,8 @@ def __init__(self,
         cls_token_at_end = xlnet
         pad_on_left = xlnet
         if isinstance(tokenizer, str):
-            tokenizer = AutoTokenizer_.from_pretrained(tokenizer, use_fast=use_fast, do_basic_tokenize=do_basic_tokenize)
+            tokenizer = AutoTokenizer_.from_pretrained(tokenizer, use_fast=use_fast,
+                                                       do_basic_tokenize=do_basic_tokenize)
         if use_fast:
             # Dirty fix upstream bug: https://github.com/hankcs/HanLP/issues/1602
             if hasattr(tokenizer, '_tokenizer') and hasattr(tokenizer._tokenizer, 'no_truncation'):
@@ -277,6 +278,13 @@ def tokenize_str(input_str, add_special_tokens=True):
             if add_special_tokens:
                 subtoken_offsets = subtoken_offsets[1 if self.has_cls else 0:-1]

+            # Edge case where the input_str is swallowed whole by the tokenizer
+            if not subtoken_offsets and not input_str.isspace():
+                __index = 1 if add_special_tokens and self.has_cls else 0
+                input_tokens.insert(__index, input_str)
+                input_ids.insert(__index, tokenizer.unk_token_id)
+                subtoken_offsets.append((0, len(input_str)))
+
             if not self.has_cls:
                 input_tokens = [self.cls_token] + input_tokens
                 input_ids = [self.cls_token_id] + input_ids
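Read as a standalone function, the fallback amounts to the sketch below (hypothetical names, not part of the HanLP API): when the tokenizer yields no subtoken offsets for a non-whitespace string, re-insert the raw string as a single UNK subtoken covering the whole input, after [CLS] if special tokens were added.

def recover_swallowed_input(input_str, input_tokens, input_ids, subtoken_offsets,
                            unk_token_id, has_cls=True, add_special_tokens=True):
    # Nothing survived tokenization, yet the input was not just whitespace
    if not subtoken_offsets and not input_str.isspace():
        # Insert after [CLS] when special tokens are present
        index = 1 if add_special_tokens and has_cls else 0
        input_tokens.insert(index, input_str)
        input_ids.insert(index, unk_token_id)
        subtoken_offsets.append((0, len(input_str)))
    return input_tokens, input_ids, subtoken_offsets
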
2 changes: 1 addition & 1 deletion hanlp/version.py
@@ -2,5 +2,5 @@
 # Author: hankcs
 # Date: 2019-12-28 19:26

-__version__ = '2.1.0-alpha.43'
+__version__ = '2.1.0-alpha.44'
 """HanLP version"""
7 changes: 3 additions & 4 deletions tests/test_mtl.py
@@ -11,10 +11,6 @@ def tokenize(mtl, text):


 class TestMultiTaskLearning(unittest.TestCase):
-
-    def setUp(self) -> None:
-        super().setUp()
-
     def test_mtl_single_sent(self):
         doc: Document = mtl('商品和服务')
         self.assertSequenceEqual(doc['tok/fine'], ["商品", "和", "服务"])
@@ -47,6 +43,9 @@ def test_emoji(self):
         self.assertSequenceEqual(mtl('( ͡° ͜ʖ ͡ °)你好', tasks='tok/fine')['tok/fine'],
                                  ["( ͡° ͜ʖ ͡ °)", "你", "好"])

+    def test_unicode_removed_by_hf(self):
+        self.assertSequenceEqual(mtl('͡', tasks='tok/fine')['tok/fine'], ['͡'])
+
     def test_space(self):
         doc: Document = mtl('商品 和服务')
         self.assertSequenceEqual(doc['tok/fine'], ["商品", "和", "服务"])
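With the fix, the new test's expectation can also be checked interactively; a sketch assuming a pretrained Chinese MTL pipeline (the model identifier below is an assumption, any tok-capable MTL checkpoint should do):

import hanlp

# Hypothetical choice of model; substitute whichever MTL checkpoint you use
mtl = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH)
print(mtl('͡', tasks='tok/fine')['tok/fine'])  # ['͡'] rather than an empty token list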
