diff --git a/.gitignore b/.gitignore
index 76b3c5739..3c7ce29ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -284,4 +284,5 @@ fabric.properties
 .idea/caches/build_file_checksums.ser
 .idea
 *.iml
-data
\ No newline at end of file
+data
+.vscode/settings.json
diff --git a/hanlp/common/vocab.py b/hanlp/common/vocab.py
index 7dec92ed0..74a0fc11f 100644
--- a/hanlp/common/vocab.py
+++ b/hanlp/common/vocab.py
@@ -79,7 +79,10 @@ def update(self, tokens: Iterable[str]) -> None:
             self.add(token)

     def get_idx(self, token: str) -> int:
-        idx = self.token_to_idx.get(token, None)
+        if type(token) is list:
+            idx = [self.get_idx(t) for t in token]
+        else:
+            idx = self.token_to_idx.get(token, None)
         if idx is None:
             if self.mutable:
                 idx = len(self.token_to_idx)
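Note (not part of the diff): a quick sketch of what the new list branch in `Vocab.get_idx` does. `TinyVocab` and its contents are invented stand-ins for the real class:

```python
# Minimal stand-in for hanlp.common.vocab.Vocab, just enough to show the
# recursive lookup added above; the tokens and indices are made up.
class TinyVocab:
    def __init__(self):
        self.token_to_idx = {'<pad>': 0, '<unk>': 1, 'sports': 2, 'finance': 3}

    def get_idx(self, token):
        if type(token) is list:
            return [self.get_idx(t) for t in token]  # list in, list of ids out
        return self.token_to_idx.get(token, 1)       # unknown tokens -> <unk>

vocab = TinyVocab()
assert vocab.get_idx('sports') == 2
assert vocab.get_idx(['sports', 'finance', '???']) == [2, 3, 1]
```

This is what lets a multi-label sample, a list of label strings, flow through the same lookup path as a single label.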
diff --git a/hanlp/components/classifiers/transformer_classifier.py b/hanlp/components/classifiers/transformer_classifier.py
index 85270eae5..a70d9b65a 100644
--- a/hanlp/components/classifiers/transformer_classifier.py
+++ b/hanlp/components/classifiers/transformer_classifier.py
@@ -15,13 +15,14 @@
 from hanlp.transform.table import TableTransform
 from hanlp.utils.log_util import logger
 from hanlp.utils.util import merge_locals_kwargs
+import numpy as np


 class TransformerTextTransform(TableTransform):

     def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=None,
-                 y_column=-1, skip_header=True, delimiter='auto', **kwargs) -> None:
-        super().__init__(config, map_x, map_y, x_columns, y_column, skip_header, delimiter, **kwargs)
+                 y_column=-1, skip_header=True, delimiter='auto', multi_label=False, **kwargs) -> None:
+        super().__init__(config, map_x, map_y, x_columns, y_column, multi_label, skip_header, delimiter, **kwargs)
         self.tokenizer: FullTokenizer = None

     def inputs_to_samples(self, inputs, gold=False):
@@ -61,17 +62,17 @@ def inputs_to_samples(self, inputs, gold=False):
             segment_ids += [0] * diff

             assert len(token_ids) == max_length, "Error with input length {} vs {}".format(len(token_ids), max_length)
-            assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask),
-                                                                                                 max_length)
-            assert len(segment_ids) == max_length, "Error with input length {} vs {}".format(len(segment_ids),
-                                                                                             max_length)
+            assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
+            assert len(segment_ids) == max_length, "Error with input length {} vs {}".format(len(segment_ids), max_length)
+
+            label = Y
             yield (token_ids, attention_mask, segment_ids), label

     def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
         max_length = self.config.max_length
         types = (tf.int32, tf.int32, tf.int32), tf.string
-        shapes = ([max_length], [max_length], [max_length]), []
+        shapes = ([max_length], [max_length], [max_length]), [None] if self.config.multi_label else []
         values = (0, 0, 0), self.label_vocab.safe_pad_token
         return types, shapes, values

@@ -79,8 +80,24 @@ def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
         logger.fatal('map_x should always be set to True')
         exit(1)

+    def y_to_idx(self, y) -> tf.Tensor:
+        if self.config.multi_label:
+            # Turn each sample's list of label indices into a binary (multi-hot) vector
+            mapped = tf.map_fn(fn=lambda x: tf.cast(self.label_vocab.lookup(x), tf.int32), elems=y, fn_output_signature=tf.TensorSpec(dtype=tf.dtypes.int32, shape=[None]))
+            one_hots = tf.one_hot(mapped, len(self.label_vocab))
+            idx = tf.reduce_sum(one_hots, -2)
+        else:
+            idx = self.label_vocab.lookup(y)
+        return idx
+
     def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False,
                      inputs=None, X=None, batch=None) -> Iterable:
-        preds = tf.argmax(Y, axis=-1)
-        for y in preds:
-            yield self.label_vocab.idx_to_token[y]
+        if self.config.multi_label:
+            # With from_logits=True, a positive logit means probability > 0.5,
+            # so predict every label whose logit Y > 0:
+            preds = Y > 0
+            for y in preds:
+                yield [self.label_vocab.idx_to_token[i] for i, positive in enumerate(y) if positive]
+        else:
+            preds = tf.argmax(Y, axis=-1)
+            for y in preds:
+                yield self.label_vocab.idx_to_token[y]
@@ -126,7 +143,14 @@ def _y_id_to_str(self, Y_pred) -> str:
         return self.transform.label_vocab.idx_to_token[Y_pred.numpy()]

     def build_loss(self, loss, **kwargs):
-        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+        if loss:
+            assert isinstance(loss, tf.keras.losses.Loss), 'loss must be an instance of tf.keras.losses.Loss'
+            return loss
+        elif self.config.multi_label:
+            # One independent binary decision per label, so use BinaryCrossentropy:
+            loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
+        else:
+            loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
         return loss

     # noinspection PyMethodOverriding
@@ -158,3 +182,10 @@ def build_vocab(self, trn_data, logger):
         warmup_steps_per_epoch = math.ceil(train_examples * self.config.warmup_steps_ratio / self.config.batch_size)
         self.config.warmup_steps = warmup_steps_per_epoch * self.config.epochs
         return train_examples
+
+    def build_metrics(self, metrics, logger, **kwargs):
+        if self.config.multi_label:
+            metric = tf.keras.metrics.BinaryAccuracy('binary_accuracy')
+        else:
+            metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
+        return [metric]
\ No newline at end of file
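Note (not part of the diff): a standalone sketch of the multi-hot conversion `y_to_idx` performs for one sample, with a `tf.lookup.StaticHashTable` standing in for `label_vocab.lookup`. The four labels are invented:

```python
import tensorflow as tf

# Invented 4-label vocabulary standing in for label_vocab.
keys = tf.constant(['sports', 'finance', 'tech', 'health'])
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, tf.range(4)), default_value=-1)

labels = tf.constant(['sports', 'tech'])  # one sample's label list
mapped = table.lookup(labels)             # [0, 2]
one_hots = tf.one_hot(mapped, 4)          # [[1,0,0,0], [0,0,1,0]]
multi_hot = tf.reduce_sum(one_hots, -2)   # [1., 0., 1., 0.]
print(multi_hot.numpy())
```

Each of the four outputs then becomes an independent binary decision, which is why the loss switches to `BinaryCrossentropy` and the metric to `BinaryAccuracy` above.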
diff --git a/hanlp/transform/table.py b/hanlp/transform/table.py
index 046be98f0..ad95fd8f7 100644
--- a/hanlp/transform/table.py
+++ b/hanlp/transform/table.py
@@ -3,7 +3,7 @@
 # Date: 2019-11-10 21:00
 from abc import ABC
 from typing import Tuple, Union
-
+import numpy as np
 import tensorflow as tf

 from hanlp.common.structure import SerializableDict
@@ -16,9 +16,9 @@ class TableTransform(Transform, ABC):

     def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=None,
-                 y_column=-1,
+                 y_column=-1, multi_label=False,
                  skip_header=True, delimiter='auto', **kwargs) -> None:
-        super().__init__(config, map_x, map_y, x_columns=x_columns, y_column=y_column,
+        super().__init__(config, map_x, map_y, x_columns=x_columns, y_column=y_column, multi_label=multi_label,
                          skip_header=skip_header, delimiter=delimiter, **kwargs)
         self.label_vocab = create_label_vocab()
@@ -28,6 +28,9 @@ def file_to_inputs(self, filepath: str, gold=True):
         y_column = self.config.y_column
         num_features = self.config.get('num_features', None)
         for cells in read_cells(filepath, skip_header=self.config.skip_header, delimiter=self.config.delimiter):
+            # multi-label: the dataset is in .tsv format. x_columns: at most two columns holding a sentence pair,
+            # though in most cases a single column holding the document content. y_column: originally a single
+            # label, now modified to load a list of labels.
             if x_columns:
                 inputs = tuple(c for i, c in enumerate(cells) if i in x_columns), cells[y_column]
             else:
@@ -37,6 +40,15 @@ def file_to_inputs(self, filepath: str, gold=True):
             if num_features is None:
                 num_features = len(inputs[0])
                 self.config.num_features = num_features
             else:
                 assert num_features == len(inputs[0]), f'Numbers of columns {num_features} ' \
                                                        f'inconsistent with current {len(inputs[0])}'
+            # multi-label support
+            if self.config.multi_label:
+                assert isinstance(inputs[1], str), 'Y value has to be a string'
+                if inputs[1][0] == '[':
+                    # multi-label given in the literal form of a Python list
+                    labels = eval(inputs[1])
+                else:
+                    labels = inputs[1].strip().split(',')
+                inputs = inputs[0], labels
@@ -56,7 +68,11 @@ def y_to_idx(self, y) -> tf.Tensor:

     def fit(self, trn_path: str, **kwargs):
         samples = 0
         for t in self.file_to_samples(trn_path, gold=True):
-            self.label_vocab.add(t[1])  # the second one regardless of t is pair or triple
+            if self.config.multi_label:
+                for l in t[1]:
+                    self.label_vocab.add(l)
+            else:
+                self.label_vocab.add(t[1])  # the second element, whether t is a pair or a triple
             samples += 1
         return samples
diff --git a/hanlp/utils/tf_util.py b/hanlp/utils/tf_util.py
index 465856cc7..1ea040ea9 100644
--- a/hanlp/utils/tf_util.py
+++ b/hanlp/utils/tf_util.py
@@ -11,9 +11,7 @@
 def size_of_dataset(dataset: tf.data.Dataset) -> int:
-    count = 0
-    for element in dataset.unbatch().batch(1):
-        count += 1
+    count = len(list(dataset.unbatch().as_numpy_iterator()))
     return count
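Note (not part of the diff): the two y-column formats `file_to_inputs` now accepts, distilled into a tiny parser. The sample rows are invented:

```python
def parse_labels(cell: str) -> list:
    # Mirrors the branch added to file_to_inputs: either a Python-literal
    # list such as "['sports', 'tech']", or a comma-separated string.
    if cell[0] == '[':
        return eval(cell)  # as in the diff; ast.literal_eval is a safer drop-in
    return cell.strip().split(',')

assert parse_labels("['sports', 'tech']") == ['sports', 'tech']
assert parse_labels('sports,tech') == ['sports', 'tech']
```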
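Note (not part of the diff): a quick check of the rewritten `size_of_dataset`, on an invented toy dataset with a ragged final batch:

```python
import tensorflow as tf

def size_of_dataset(dataset: tf.data.Dataset) -> int:
    # Single pass over the unbatched elements, instead of re-batching
    # every element into size-1 batches and counting them in a loop.
    return len(list(dataset.unbatch().as_numpy_iterator()))

ds = tf.data.Dataset.range(10).batch(3)  # batches of 3, 3, 3, 1
assert size_of_dataset(ds) == 10
```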