Skip to content

Commit

Permalink
implemented multi-label support
Browse files Browse the repository at this point in the history
Revert "minor fix"

This reverts commit 91e9847.

 On branch master
 Your branch is up to date with 'origin/master'.

 Changes to be committed:
	modified:   hanlp/common/component.py
	modified:   hanlp/layers/transformers/loader.py

(cherry picked from commit 7bae452)

minor fix

(cherry picked from commit 91e9847)

minor fix

(cherry picked from commit d4104d7)

multi-label support cherry picked to master

(cherry picked from commit 62f7b3d)

implemented multi-label support

(cherry picked from commit a844481)
  • Loading branch information
callzhang authored and hankcs committed Dec 23, 2020
1 parent 4ea9595 commit 56c44d3
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 18 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -284,4 +284,5 @@ fabric.properties
.idea/caches/build_file_checksums.ser
.idea
*.iml
data
data
.vscode/settings.json
5 changes: 4 additions & 1 deletion hanlp/common/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def update(self, tokens: Iterable[str]) -> None:
self.add(token)

def get_idx(self, token: str) -> int:
idx = self.token_to_idx.get(token, None)
if type(token) is list:
idx = [self.get_idx(t) for t in token]
else:
idx = self.token_to_idx.get(token, None)
if idx is None:
if self.mutable:
idx = len(self.token_to_idx)
Expand Down
47 changes: 38 additions & 9 deletions hanlp/components/classifiers/transformer_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
from hanlp.transform.table import TableTransform
from hanlp.utils.log_util import logger
from hanlp.utils.util import merge_locals_kwargs
import numpy as np


class TransformerTextTransform(TableTransform):

def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=None,
             y_column=-1, skip_header=True, delimiter='auto', multi_label=False, **kwargs) -> None:
    """Set up the transform and forward the table options to the parent class.

    Args:
        config: Forwarded unchanged to the parent constructor.
        map_x: Forwarded unchanged to the parent constructor.
        map_y: Forwarded unchanged to the parent constructor.
        x_columns: Forwarded to the parent as ``x_columns``.
        y_column: Forwarded to the parent as ``y_column``.
        skip_header: Forwarded to the parent as ``skip_header``.
        delimiter: Forwarded to the parent as ``delimiter``.
        multi_label: Forwarded to the parent as ``multi_label``; other methods
            of this class read it back as ``self.config.multi_label``.
        **kwargs: Any further options, forwarded unchanged.
    """
    # Forward by keyword: the parent gained ``multi_label`` in the middle of its
    # parameter list, so positional forwarding would silently shift arguments
    # if the two signatures ever drift apart again.
    super().__init__(config, map_x, map_y, x_columns=x_columns, y_column=y_column,
                     multi_label=multi_label, skip_header=skip_header, delimiter=delimiter, **kwargs)
    # The tokenizer is not created here; it starts out unset.
    self.tokenizer: FullTokenizer = None

def inputs_to_samples(self, inputs, gold=False):
Expand Down Expand Up @@ -61,26 +62,40 @@ def inputs_to_samples(self, inputs, gold=False):
segment_ids += [0] * diff

assert len(token_ids) == max_length, "Error with input length {} vs {}".format(len(token_ids), max_length)
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask),
max_length)
assert len(segment_ids) == max_length, "Error with input length {} vs {}".format(len(segment_ids),
max_length)
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
assert len(segment_ids) == max_length, "Error with input length {} vs {}".format(len(segment_ids), max_length)


label = Y
yield (token_ids, attention_mask, segment_ids), label

def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
    """Return the (types, shapes, padding values) triple for one sample.

    Each element is a pair: the first item covers the three feature
    sequences and the second covers the label.

    Returns:
        ``types``: three ``tf.int32`` features and a ``tf.string`` label.
        ``shapes``: three ``[max_length]`` features; the label shape is
        ``[None]`` when ``config.multi_label`` is set and ``[]`` otherwise.
        ``values``: ``0`` for every feature and the label vocabulary's
        ``safe_pad_token`` for the label.
    """
    seq_len = self.config.max_length
    # Three independent lists, one per feature sequence.
    feature_shapes = ([seq_len], [seq_len], [seq_len])
    if self.config.multi_label:
        # Multi-label samples carry a variable number of labels.
        label_shape = [None, ]
    else:
        label_shape = []
    types = (tf.int32, tf.int32, tf.int32), tf.string
    shapes = feature_shapes, label_shape
    values = (0, 0, 0), self.label_vocab.safe_pad_token
    return types, shapes, values

def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
    """Unsupported for this transform: log a fatal message and exit with status 1.

    Args:
        x: Ignored.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    # NOTE(review): the constructor defaults ``map_x`` to False, so this message
    # may be meant to read "set to False" - confirm before changing the text.
    logger.fatal('map_x should always be set to True')
    # ``exit()`` is a helper injected by the ``site`` module for interactive use
    # and is missing under ``python -S``; raise SystemExit directly instead.
    raise SystemExit(1)

def y_to_idx(self, y) -> tf.Tensor:
    """Convert labels to the tensor form used for training.

    In single-label mode this is a plain vocabulary lookup. In multi-label
    mode every element of ``y`` is looked up, one-hot encoded against the
    vocabulary size, and the one-hot vectors are summed over axis ``-2``.

    Args:
        y: Label tensor. Presumably (batch, labels) strings in multi-label
            mode - TODO confirm against the dataset pipeline.

    Returns:
        The looked-up ids (single-label) or the summed one-hot tensor
        (multi-label).
    """
    if not self.config.multi_label:
        return self.label_vocab.lookup(y)

    def lookup_ids(labels):
        # Look up one element of ``y`` and cast the ids to int32.
        return tf.cast(self.label_vocab.lookup(labels), tf.int32)

    label_ids = tf.map_fn(
        fn=lookup_ids,
        elems=y,
        fn_output_signature=tf.TensorSpec(dtype=tf.dtypes.int32, shape=[None, ]),
    )
    # Index -> indicator vector: one-hot each id, then collapse the label axis.
    indicators = tf.one_hot(label_ids, len(self.label_vocab))
    return tf.reduce_sum(indicators, -2)

def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
    """Turn a batch of model outputs into label strings.

    Args:
        Y: One row of scores per sample.
        gold: Unused here.
        inputs: Unused here.
        X: Unused here.
        batch: Unused here.

    Yields:
        Single-label mode: the label at the arg-max position of each row.
        Multi-label mode: a list with every label whose score is greater
        than zero (possibly empty).
    """
    idx_to_token = self.label_vocab.idx_to_token
    if self.config.multi_label:
        # Prediction is Y > 0. Indexing the vocabulary with the whole score
        # vector (as the previous code did) cannot work, so threshold per label.
        # NOTE(review): assumes rows are logits or 0/1 indicators - confirm.
        for scores in Y:
            yield [idx_to_token[i] for i, score in enumerate(scores) if score > 0]
    else:
        for y in tf.argmax(Y, axis=-1):
            yield idx_to_token[y]

Expand Down Expand Up @@ -126,7 +141,14 @@ def _y_id_to_str(self, Y_pred) -> str:
return self.transform.label_vocab.idx_to_token[Y_pred.numpy()]

def build_loss(self, loss, **kwargs):
    """Choose the training loss.

    Args:
        loss: Optional ``tf.keras.losses.Loss`` instance supplied by the
            caller. Any falsy value selects the default for the current mode.
        **kwargs: Ignored.

    Returns:
        ``loss`` itself when given; otherwise ``BinaryCrossentropy`` in
        multi-label mode or ``SparseCategoricalCrossentropy`` in single-label
        mode, both built with ``from_logits=True``.

    Raises:
        TypeError: If ``loss`` is truthy but not a ``tf.keras.losses.Loss``.
    """
    if loss:
        # The base class is ``Loss``; ``tf.keras.losses.loss`` does not exist
        # and raised AttributeError for every caller-supplied loss.
        if not isinstance(loss, tf.keras.losses.Loss):
            raise TypeError('Must specify loss as an instance in tf.keras.losses')
        return loss
    if self.config.multi_label:
        # Loss to be BinaryCrossentropy for multi-label.
        return tf.keras.losses.BinaryCrossentropy(from_logits=True)
    return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# noinspection PyMethodOverriding
Expand Down Expand Up @@ -158,3 +180,10 @@ def build_vocab(self, trn_data, logger):
warmup_steps_per_epoch = math.ceil(train_examples * self.config.warmup_steps_ratio / self.config.batch_size)
self.config.warmup_steps = warmup_steps_per_epoch * self.config.epochs
return train_examples

def build_metrics(self, metrics, logger, **kwargs):
    """Return the single accuracy metric that matches the labelling mode.

    Args:
        metrics: Not consulted by this implementation.
        logger: Not consulted by this implementation.
        **kwargs: Ignored.

    Returns:
        A one-element list: ``BinaryAccuracy`` named ``'binary_accuracy'``
        when ``config.multi_label`` is set, otherwise
        ``SparseCategoricalAccuracy`` named ``'accuracy'``.
    """
    if self.config.multi_label:
        return [tf.keras.metrics.BinaryAccuracy('binary_accuracy')]
    return [tf.keras.metrics.SparseCategoricalAccuracy('accuracy')]
24 changes: 20 additions & 4 deletions hanlp/transform/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Date: 2019-11-10 21:00
from abc import ABC
from typing import Tuple, Union

import numpy as np
import tensorflow as tf

from hanlp.common.structure import SerializableDict
Expand All @@ -16,9 +16,9 @@

class TableTransform(Transform, ABC):
def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=None,
y_column=-1,
y_column=-1, multi_label=False,
skip_header=True, delimiter='auto', **kwargs) -> None:
super().__init__(config, map_x, map_y, x_columns=x_columns, y_column=y_column,
super().__init__(config, map_x, map_y, x_columns=x_columns, y_column=y_column, multi_label=multi_label,
skip_header=skip_header,
delimiter=delimiter, **kwargs)
self.label_vocab = create_label_vocab()
Expand All @@ -28,6 +28,9 @@ def file_to_inputs(self, filepath: str, gold=True):
y_column = self.config.y_column
num_features = self.config.get('num_features', None)
for cells in read_cells(filepath, skip_header=self.config.skip_header, delimiter=self.config.delimiter):
# Multi-label: the dataset is a .tsv file. x_columns selects at most two columns (a sentence pair), although
# in most cases it is a single column holding the document content. y_column holds the label cell, which in
# multi-label mode is parsed into a list of labels below.
if x_columns:
inputs = tuple(c for i, c in enumerate(cells) if i in x_columns), cells[y_column]
else:
Expand All @@ -37,6 +40,15 @@ def file_to_inputs(self, filepath: str, gold=True):
if num_features is None:
num_features = len(inputs[0])
self.config.num_features = num_features
# multi-label support
if self.config.multi_label:
assert type(inputs[1]) is str, 'Y value has to be string'
if inputs[1][0] == '[':
# multi-label is in literal form of a list
labels = eval(inputs[1])
else:
labels = inputs[1].strip().split(',')
inputs = inputs[0], labels
else:
assert num_features == len(inputs[0]), f'Numbers of columns {num_features} ' \
f'inconsistent with current {len(inputs[0])}'
Expand All @@ -56,7 +68,11 @@ def y_to_idx(self, y) -> tf.Tensor:
def fit(self, trn_path: str, **kwargs):
    """Populate the label vocabulary from the training file and count its samples.

    Args:
        trn_path: Path handed to ``self.file_to_samples``.
        **kwargs: Ignored.

    Returns:
        The number of samples read.
    """
    num_samples = 0
    for sample in self.file_to_samples(trn_path, gold=True):
        labels = sample[1]  # the second one regardless of sample being a pair or triple
        if self.config.multi_label:
            # Each sample carries several labels; register them one by one
            # rather than adding the whole list as a single vocabulary entry.
            for label in labels:
                self.label_vocab.add(label)
        else:
            self.label_vocab.add(labels)
        num_samples += 1
    return num_samples

Expand Down
4 changes: 1 addition & 3 deletions hanlp/utils/tf_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@


def size_of_dataset(dataset: tf.data.Dataset) -> int:
    """Count the individual elements of a batched dataset.

    Args:
        dataset: A batched dataset; it is un-batched before counting.

    Returns:
        The number of elements produced by ``dataset.unbatch()``.
    """
    # Stream once and count. Collecting every element into a list just to take
    # its length keeps the whole dataset in memory.
    return sum(1 for _ in dataset.unbatch())


Expand Down

1 comment on commit 56c44d3

@hanlpbot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit has been mentioned on Butterfly Effect. There might be relevant details there:

https://bbs.hankcs.com/t/topic/3011/5

Please sign in to comment.