diff --git a/README.md b/README.md index 31820f38..09a9da27 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ If you find this code useful in your research, please cite it using the following pic
Zan Shuxun
-Beijing University of Posts and Telecommunications
+Alibaba Group
pic
diff --git a/deepctr/__init__.py b/deepctr/__init__.py index 7c97d7aa..ce72047b 100644 --- a/deepctr/__init__.py +++ b/deepctr/__init__.py @@ -1,4 +1,4 @@ from .utils import check_version -__version__ = '0.8.6' +__version__ = '0.8.7' check_version(__version__) diff --git a/deepctr/feature_column.py b/deepctr/feature_column.py index cb04ce1d..6f277ba1 100644 --- a/deepctr/feature_column.py +++ b/deepctr/feature_column.py @@ -15,12 +15,12 @@ class SparseFeat(namedtuple('SparseFeat', - ['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'dtype', 'embeddings_initializer', + ['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'vocabulary_path', 'dtype', 'embeddings_initializer', 'embedding_name', 'group_name', 'trainable'])): __slots__ = () - def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype="int32", embeddings_initializer=None, + def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, vocabulary_path=None, dtype="int32", embeddings_initializer=None, embedding_name=None, group_name=DEFAULT_GROUP_NAME, trainable=True): @@ -32,7 +32,7 @@ def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype=" if embedding_name is None: embedding_name = name - return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, dtype, + return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, vocabulary_path, dtype, embeddings_initializer, embedding_name, group_name, trainable) @@ -64,6 +64,10 @@ def embedding_dim(self): def use_hash(self): return self.sparsefeat.use_hash + @property + def vocabulary_path(self): + return self.sparsefeat.vocabulary_path + @property def dtype(self): return self.sparsefeat.dtype diff --git a/deepctr/inputs.py b/deepctr/inputs.py index a36e4e9b..d567f846 100644 --- a/deepctr/inputs.py +++ b/deepctr/inputs.py @@ -51,7 +51,7 @@ def get_embedding_vec_list(embedding_dict, input_dict, sparse_feature_columns, r feat_name = fg.name if len(return_feat_list) == 0 or feat_name in return_feat_list: if fg.use_hash: - lookup_idx = Hash(fg.vocabulary_size, mask_zero=(feat_name in mask_feat_list))(input_dict[feat_name]) + lookup_idx = Hash(fg.vocabulary_size, mask_zero=(feat_name in mask_feat_list), vocabulary_path=fg.vocabulary_path)(input_dict[feat_name]) else: lookup_idx = input_dict[feat_name] @@ -80,7 +80,7 @@ def embedding_lookup(sparse_embedding_dict, sparse_input_dict, sparse_feature_co embedding_name = fc.embedding_name if (len(return_feat_list) == 0 or feature_name in return_feat_list): if fc.use_hash: - lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list))( + lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list), vocabulary_path=fc.vocabulary_path)( sparse_input_dict[feature_name]) else: lookup_idx = sparse_input_dict[feature_name] @@ -97,7 +97,7 @@ def varlen_embedding_lookup(embedding_dict, sequence_input_dict, varlen_sparse_f feature_name = fc.name embedding_name = fc.embedding_name if fc.use_hash: - lookup_idx = Hash(fc.vocabulary_size, mask_zero=True)(sequence_input_dict[feature_name]) + lookup_idx = Hash(fc.vocabulary_size, mask_zero=True, vocabulary_path=fc.vocabulary_path)(sequence_input_dict[feature_name]) else: lookup_idx = sequence_input_dict[feature_name] varlen_embedding_vec_dict[feature_name] = embedding_dict[embedding_name](lookup_idx) diff --git a/deepctr/layers/core.py b/deepctr/layers/core.py index 9ee5e248..2b9188b5 100644 --- a/deepctr/layers/core.py +++ b/deepctr/layers/core.py 
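The feature_column.py and inputs.py hunks above thread the new `vocabulary_path` argument from `SparseFeat` down to the `Hash` layer. A minimal sketch of the extended signature in use (not taken from this diff; the feature name and CSV path are hypothetical):

```python
from deepctr.feature_column import SparseFeat

# Hypothetical vocabulary file './age_vocab.csv' with two columns per line: value,key (e.g. "1,18").
age_column = SparseFeat('age',
                        vocabulary_size=7 + 1,  # vocabulary size + 1; value 0 is returned for missing keys
                        embedding_dim=4,
                        use_hash=True,          # lookups only go through the Hash layer when hashing is enabled
                        vocabulary_path='./age_vocab.csv',
                        dtype='string')
```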
@@ -68,8 +68,8 @@ def build(self, input_shape): 'inputs of a two inputs with shape (None,1,embedding_size) and (None,T,embedding_size)' 'Got different shapes: %s,%s' % (input_shape[0], input_shape[1])) size = 4 * \ - int(input_shape[0][-1] - ) if len(self.hidden_units) == 0 else self.hidden_units[-1] + int(input_shape[0][-1] + ) if len(self.hidden_units) == 0 else self.hidden_units[-1] self.kernel = self.add_weight(shape=(size, 1), initializer=glorot_normal( seed=self.seed), @@ -78,9 +78,6 @@ def build(self, input_shape): shape=(1,), initializer=Zeros(), name="bias") self.dnn = DNN(self.hidden_units, self.activation, self.l2_reg, self.dropout_rate, self.use_bn, seed=self.seed) - self.dense = tf.keras.layers.Lambda(lambda x: tf.nn.bias_add(tf.tensordot( - x[0], x[1], axes=(-1, 0)), x[2])) - super(LocalActivationUnit, self).build( input_shape) # Be sure to call this somewhere! @@ -96,7 +93,7 @@ def call(self, inputs, training=None, **kwargs): att_out = self.dnn(att_input, training=training) - attention_score = self.dense([att_out, self.kernel, self.bias]) + attention_score = tf.nn.bias_add(tf.tensordot(att_out, self.kernel, axes=(-1, 0)), self.bias) return attention_score diff --git a/deepctr/layers/sequence.py b/deepctr/layers/sequence.py index 5c4b5b50..ce1bd64b 100644 --- a/deepctr/layers/sequence.py +++ b/deepctr/layers/sequence.py @@ -560,10 +560,10 @@ def call(self, inputs, mask=None, training=None, **kwargs): if self.blinding: try: outputs = tf.matrix_set_diag(outputs, tf.ones_like(outputs)[ - :, :, 0] * (-2 ** 32 + 1)) - except: + :, :, 0] * (-2 ** 32 + 1)) + except AttributeError: outputs = tf.compat.v1.matrix_set_diag(outputs, tf.ones_like(outputs)[ - :, :, 0] * (-2 ** 32 + 1)) + :, :, 0] * (-2 ** 32 + 1)) outputs -= reduce_max(outputs, axis=-1, keep_dims=True) outputs = softmax(outputs) @@ -633,14 +633,14 @@ def build(self, input_shape): _, T, num_units = input_shape.as_list() # inputs.get_shape().as_list() # First part of the PE function: sin and cos argument position_enc = np.array([ - [pos / np.power(10000, 2. * i / num_units) - for i in range(num_units)] + [pos / np.power(10000, 2. * (i//2) / num_units) for i in range(num_units)] for pos in range(T)]) # Second part, apply the cosine to even columns and sin to odds. 
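# Note on the (i // 2) fix above: the standard sinusoidal position encoding is
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)),
# so each sin/cos pair must share the same frequency exponent, which is what
# 2. * (i // 2) / num_units now computes for a flat column index i.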
position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1 - + if self.zero_pad: + position_enc[0, :] = np.zeros(num_units) self.lookup_table = self.add_weight("lookup_table", (T, num_units), initializer=tf.initializers.identity(position_enc), trainable=self.pos_embedding_trainable) @@ -651,13 +651,7 @@ def call(self, inputs, mask=None): _, T, num_units = inputs.get_shape().as_list() position_ind = tf.expand_dims(tf.range(T), 0) - - if self.zero_pad: - self.lookup_table = tf.concat((tf.zeros(shape=[1, num_units]), - self.lookup_table[1:, :]), 0) - outputs = tf.nn.embedding_lookup(self.lookup_table, position_ind) - if self.scale: outputs = outputs * num_units ** 0.5 return outputs + inputs diff --git a/deepctr/layers/utils.py b/deepctr/layers/utils.py index ca73d6a3..0e219132 100644 --- a/deepctr/layers/utils.py +++ b/deepctr/layers/utils.py @@ -7,6 +7,11 @@ """ import tensorflow as tf from tensorflow.python.keras.layers import Flatten +from tensorflow.python.ops.lookup_ops import TextFileInitializer +try: + from tensorflow.python.ops.lookup_ops import StaticHashTable +except ImportError as e: + from tensorflow.python.ops.lookup_ops import HashTable as StaticHashTable class NoMask(tf.keras.layers.Layer): @@ -25,14 +30,47 @@ def compute_mask(self, inputs, mask): class Hash(tf.keras.layers.Layer): - """ - hash the input to [0,num_buckets) - if mask_zero = True,0 or 0.0 will be set to 0,other value will be set in range[1,num_buckets) + """Looks up keys in a table when `vocabulary_path` is set, and outputs the corresponding values. + If `vocabulary_path` is not set, `Hash` will hash the input to [0,num_buckets). When `mask_zero` = True, + input value `0` or `0.0` will be set to `0`, and other values will be mapped into the range [1,num_buckets). + + The following snippet initializes a `Hash` with a `vocabulary_path` file whose first column holds the values and + whose second column holds the keys: + + * `1,emerson` + * `2,lake` + * `3,palmer` + + >>> hash = Hash( + ... num_buckets=3+1, + ... vocabulary_path=filename, + ... default_value=0) + >>> hash(tf.constant('lake')).numpy() + 2 + >>> hash(tf.constant('lakeemerson')).numpy() + 0 + + Args: + num_buckets: An `int` that is >= 1. The number of buckets, or the vocabulary size + 1 + when `vocabulary_path` is set. + mask_zero: default `False`. When `True`, input value `0` or `0.0` will be hashed to `0` and + other values to [1,num_buckets). `mask_zero` is ignored when `vocabulary_path` is set. + vocabulary_path: default `None`. The `CSV` text file path of the vocabulary table, which contains + two columns separated by a comma delimiter; the first column is the value and the second is + the key. The key data type is `string`, the value data type is `int`. The path must + be accessible from wherever `Hash` is initialized. + default_value: default `0`. The default value returned if a key is missing in the table. + **kwargs: Additional keyword arguments.
""" - def __init__(self, num_buckets, mask_zero=False, **kwargs): + def __init__(self, num_buckets, mask_zero=False, vocabulary_path=None, default_value=0, **kwargs): self.num_buckets = num_buckets self.mask_zero = mask_zero + self.vocabulary_path = vocabulary_path + self.default_value = default_value + if self.vocabulary_path: + initializer = TextFileInitializer(vocabulary_path, 'string', 1, 'int64', 0, delimiter=',') + self.hash_table = StaticHashTable(initializer, default_value=self.default_value) super(Hash, self).__init__(**kwargs) def build(self, input_shape): @@ -41,13 +79,16 @@ def build(self, input_shape): def call(self, x, mask=None, **kwargs): - if x.dtype != tf.string: zero = tf.as_string(tf.zeros([1], dtype=x.dtype)) x = tf.as_string(x, ) else: zero = tf.as_string(tf.zeros([1], dtype='int32')) + if self.vocabulary_path: + hash_x = self.hash_table.lookup(x) + return hash_x + num_buckets = self.num_buckets if not self.mask_zero else self.num_buckets - 1 try: hash_x = tf.string_to_hash_bucket_fast(x, num_buckets, @@ -60,8 +101,12 @@ def call(self, x, mask=None, **kwargs): hash_x = (hash_x + 1) * mask return hash_x + + def compute_output_shape(self, input_shape): + return input_shape + def get_config(self, ): - config = {'num_buckets': self.num_buckets, 'mask_zero': self.mask_zero, } + config = {'num_buckets': self.num_buckets, 'mask_zero': self.mask_zero, 'vocabulary_path': self.vocabulary_path, 'default_value': self.default_value} base_config = super(Hash, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/docs/source/Examples.md b/docs/source/Examples.md index de6b33c1..35c9de18 100644 --- a/docs/source/Examples.md +++ b/docs/source/Examples.md @@ -322,6 +322,72 @@ if __name__ == "__main__": history = model.fit(model_input, data[target].values, batch_size=256, epochs=10, verbose=2, validation_split=0.2, ) ``` +## Hash Layer with pre-defined key-value vocabulary + +This examples how to use pre-defined key-value vocabulary in `Hash` Layer.`movielens_age_vocabulary.csv` stores the key-value mapping for `age` feature. 
+ +```python +from deepctr.models import DeepFM +from deepctr.feature_column import SparseFeat, VarLenSparseFeat, get_feature_names +import numpy as np +import pandas as pd +from tensorflow.python.keras.preprocessing.sequence import pad_sequences + +try: + import tensorflow.compat.v1 as tf +except ImportError as e: + import tensorflow as tf + +if __name__ == "__main__": + data = pd.read_csv("./movielens_sample.txt") + sparse_features = ["movie_id", "user_id", + "gender", "age", "occupation", "zip", ] + + data[sparse_features] = data[sparse_features].astype(str) + target = ['rating'] + + # 1.Use hashing encoding on the fly for sparse features,and process sequence features + + genres_list = list(map(lambda x: x.split('|'), data['genres'].values)) + genres_length = np.array(list(map(len, genres_list))) + max_len = max(genres_length) + + # Notice : padding=`post` + genres_list = pad_sequences(genres_list, maxlen=max_len, padding='post', dtype=str, value=0) + + # 2.set hashing space for each sparse field and generate feature config for sequence feature + + fixlen_feature_columns = [SparseFeat(feat, data[feat].nunique() * 5, embedding_dim=4, use_hash=True, + vocabulary_path='./movielens_age_vocabulary.csv' if feat == 'age' else None, + dtype='string') + for feat in sparse_features] + varlen_feature_columns = [ + VarLenSparseFeat(SparseFeat('genres', vocabulary_size=100, embedding_dim=4, + use_hash=True, dtype="string"), + maxlen=max_len, combiner='mean', + )] # Notice : value 0 is for padding for sequence input feature + linear_feature_columns = fixlen_feature_columns + varlen_feature_columns + dnn_feature_columns = fixlen_feature_columns + varlen_feature_columns + feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns) + + # 3.generate input data for model + model_input = {name: data[name] for name in feature_names} + model_input['genres'] = genres_list + + # 4.Define Model,compile and train + model = DeepFM(linear_feature_columns, dnn_feature_columns, task='regression') + model.compile("adam", "mse", metrics=['mse'], ) + if not hasattr(tf, 'version') or tf.version.VERSION < '2.0.0': + with tf.Session() as sess: + sess.run(tf.tables_initializer()) + history = model.fit(model_input, data[target].values, + batch_size=256, epochs=10, verbose=2, validation_split=0.2, ) + else: + history = model.fit(model_input, data[target].values, + batch_size=256, epochs=10, verbose=2, validation_split=0.2, ) + +``` + ## Estimator with TFRecord: Classification Criteo diff --git a/docs/source/Features.md b/docs/source/Features.md index 2dcdd21d..13db0903 100644 --- a/docs/source/Features.md +++ b/docs/source/Features.md @@ -23,12 +23,13 @@ DNN based CTR prediction models usually have following 4 modules: ## Feature Columns ### SparseFeat -``SparseFeat`` is a namedtuple with signature ``SparseFeat(name, vocabulary_size, embedding_dim, use_hash, dtype, embeddings_initializer, embedding_name, group_name, trainable)`` +``SparseFeat`` is a namedtuple with signature ``SparseFeat(name, vocabulary_size, embedding_dim, use_hash, vocabulary_path, dtype, embeddings_initializer, embedding_name, group_name, trainable)`` - name : feature name -- vocabulary_size : number of unique feature values for sprase feature or hashing space when `use_hash=True` +- vocabulary_size : number of unique feature values for sparse feature or hashing space when `use_hash=True` - embedding_dim : embedding dimension -- use_hash : defualt `False`.If `True` the input will be hashed to space of size `vocabulary_size`. 
+- use_hash : default `False`. If `True` the input will be hashed to space of size `vocabulary_size`. +- vocabulary_path : default `None`. The `CSV` text file path of the vocabulary table used by `tf.lookup.TextFileInitializer`, which assigns one entry in the table for each line in the file. One entry contains two columns separated by comma, the first is the value column, the second is the key column. The value `0` is reserved for keys missing from the table, so hash values need to start from `1`. - dtype : default `int32`.dtype of input tensor. - embeddings_initializer : initializer for the `embeddings` matrix. - embedding_name : default `None`. If None, the embedding_name will be same as `name`. diff --git a/docs/source/History.md b/docs/source/History.md index b0304655..64066d16 100644 --- a/docs/source/History.md +++ b/docs/source/History.md @@ -1,4 +1,5 @@ # History +- 07/18/2021 : [v0.8.7](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.7) released.Support pre-defined key-value vocabulary in `Hash` Layer. [example](./Examples.html#hash-layer-with-pre-defined-key-value-vocabulary) - 06/14/2021 : [v0.8.6](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.6) released.Add [IFM](./Features.html#ifm-input-aware-factorization-machine) [DIFM](./Features.html#difm-dual-input-aware-factorization-machine), [FEFM and DeepFEFM](./Features.html#deepfefm-deep-field-embedded-factorization-machine) model. - 03/13/2021 : [v0.8.5](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.5) released.Add [BST](./Features.html#bst-behavior-sequence-transformer) model. - 02/12/2021 : [v0.8.4](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.4) released.Fix bug in DCN-Mix. diff --git a/docs/source/Quick-Start.md b/docs/source/Quick-Start.md index e587757f..a8b0ab38 100644 --- a/docs/source/Quick-Start.md +++ b/docs/source/Quick-Start.md @@ -86,7 +86,7 @@ fixlen_feature_columns = [SparseFeat(feat, vocabulary_size=data[feat].max() + 1, ``` - Feature Hashing on the fly ```python -fixlen_feature_columns = [SparseFeat(feat, vocabulary_size=1e6,embedding_dim=4, use_hash=True, dtype='string') # since the input is string +fixlen_feature_columns = [SparseFeat(feat, vocabulary_size=1e6,embedding_dim=4, use_hash=True, dtype='string') # the input is string for feat in sparse_features] + [DenseFeat(feat, 1, ) for feat in dense_features] ``` diff --git a/docs/source/conf.py b/docs/source/conf.py index 536ff116..d1ff9206 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ # The short X.Y version version = '' # The full version, including alpha/beta/rc tags -release = '0.8.6' +release = '0.8.7' # -- General configuration --------------------------------------------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index a8904b99..eecc8a8a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -42,12 +42,12 @@ You can read the latest code and related projects News ----- +07/18/2021 : Support pre-defined key-value vocabulary in `Hash` Layer. `example <./Examples.html#hash-layer-with-pre-defined-key-value-vocabulary>`_ `Changelog `_ + 06/14/2021 : Add `IFM <./Features.html#ifm-input-aware-factorization-machine>`_ , `DIFM <./Features.html#difm-dual-input-aware-factorization-machine>`_ and `DeepFEFM <./Features.html#deepfefm-deep-field-embedded-factorization-machine>`_ . `Changelog `_ 03/13/2021 : Add `BST <./Features.html#bst-behavior-sequence-transformer>`_ . `Changelog `_ -02/12/2021 : Fix bug in DCN-Mix.
`Changelog `_ - DisscussionGroup ----------------------- diff --git a/examples/movielens_age_vocabulary.csv b/examples/movielens_age_vocabulary.csv new file mode 100644 index 00000000..ce07b01d --- /dev/null +++ b/examples/movielens_age_vocabulary.csv @@ -0,0 +1,7 @@ +1,1 +2,18 +3,25 +4,35 +5,45 +6,50 +7,56 diff --git a/examples/run_multivalue_movielens_vocab_hash.py b/examples/run_multivalue_movielens_vocab_hash.py new file mode 100644 index 00000000..2376369f --- /dev/null +++ b/examples/run_multivalue_movielens_vocab_hash.py @@ -0,0 +1,58 @@ +from deepctr.models import DeepFM +from deepctr.feature_column import SparseFeat, VarLenSparseFeat, get_feature_names +import numpy as np +import pandas as pd +from tensorflow.python.keras.preprocessing.sequence import pad_sequences + +try: + import tensorflow.compat.v1 as tf +except ImportError as e: + import tensorflow as tf + +if __name__ == "__main__": + data = pd.read_csv("./movielens_sample.txt") + sparse_features = ["movie_id", "user_id", + "gender", "age", "occupation", "zip", ] + + data[sparse_features] = data[sparse_features].astype(str) + target = ['rating'] + + # 1.Use hashing encoding on the fly for sparse features,and process sequence features + + genres_list = list(map(lambda x: x.split('|'), data['genres'].values)) + genres_length = np.array(list(map(len, genres_list))) + max_len = max(genres_length) + + # Notice : padding=`post` + genres_list = pad_sequences(genres_list, maxlen=max_len, padding='post', dtype=str, value=0) + + # 2.set hashing space for each sparse field and generate feature config for sequence feature + + fixlen_feature_columns = [SparseFeat(feat, data[feat].nunique() * 5, embedding_dim=4, use_hash=True, + vocabulary_path='./movielens_age_vocabulary.csv' if feat == 'age' else None, + dtype='string') + for feat in sparse_features] + varlen_feature_columns = [ + VarLenSparseFeat(SparseFeat('genres', vocabulary_size=100, embedding_dim=4, + use_hash=True, dtype="string"), + maxlen=max_len, combiner='mean', + )] # Notice : value 0 is for padding for sequence input feature + linear_feature_columns = fixlen_feature_columns + varlen_feature_columns + dnn_feature_columns = fixlen_feature_columns + varlen_feature_columns + feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns) + + # 3.generate input data for model + model_input = {name: data[name] for name in feature_names} + model_input['genres'] = genres_list + + # 4.Define Model,compile and train + model = DeepFM(linear_feature_columns, dnn_feature_columns, task='regression') + model.compile("adam", "mse", metrics=['mse'], ) + if not hasattr(tf, 'version') or tf.version.VERSION < '2.0.0': + with tf.Session() as sess: + sess.run(tf.tables_initializer()) + history = model.fit(model_input, data[target].values, + batch_size=256, epochs=10, verbose=2, validation_split=0.2, ) + else: + history = model.fit(model_input, data[target].values, + batch_size=256, epochs=10, verbose=2, validation_split=0.2, ) diff --git a/setup.py b/setup.py index a5b5219d..5e3652fd 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setuptools.setup( name="deepctr", - version="0.8.6", + version="0.8.7", author="Weichen Shen", author_email="weichenswc@163.com", description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with tensorflow 1.x and 2.x .", diff --git a/tests/feature_test.py b/tests/feature_test.py index 7f15208f..35005fb7 100644 --- a/tests/feature_test.py +++ b/tests/feature_test.py @@ -1,8 +1,9 @@ from 
deepctr.models import DeepFM -from deepctr.feature_column import SparseFeat, DenseFeat,get_feature_names +from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat, get_feature_names import numpy as np -def test_long_dense_vector(): + +def test_long_dense_vector(): feature_columns = [SparseFeat('user_id', 4, ), SparseFeat('item_id', 5, ), DenseFeat("pic_vec", 5)] fixlen_feature_names = get_feature_names(feature_columns) @@ -16,4 +17,14 @@ def test_long_dense_vector(): model = DeepFM(feature_columns, feature_columns[:-1]) model.compile('adagrad', 'binary_crossentropy') - model.fit(model_input, label) \ No newline at end of file + model.fit(model_input, label) + + +def test_feature_column_sparsefeat_vocabulary_path(): + vocab_path = "./dummy_test.csv" + sf = SparseFeat('user_id', 4, vocabulary_path=vocab_path) + if sf.vocabulary_path != vocab_path: + raise ValueError("sf.vocabulary_path is invalid") + vlsf = VarLenSparseFeat(sf, 6) + if vlsf.vocabulary_path != vocab_path: + raise ValueError("vlsf.vocabulary_path is invalid") diff --git a/tests/layers/sequence_test.py b/tests/layers/sequence_test.py index dd030d74..ccbc013b 100644 --- a/tests/layers/sequence_test.py +++ b/tests/layers/sequence_test.py @@ -79,8 +79,6 @@ def test_BiLSTM(merge_mode): def test_Transformer(): - if tf.__version__ >= '2.0.0': - tf.compat.v1.disable_eager_execution() # todo with CustomObjectScope({'Transformer': sequence.Transformer}): layer_test(sequence.Transformer, kwargs={'att_embedding_size': 1, 'head_num': 8, 'use_layer_norm': True, 'supports_masking': False, @@ -102,7 +100,7 @@ def test_KMaxPooling(): ] ) def test_PositionEncoding(pos_embedding_trainable, zero_pad): - with CustomObjectScope({'PositionEncoding': sequence.PositionEncoding}): + with CustomObjectScope({'PositionEncoding': sequence.PositionEncoding, "tf": tf}): layer_test(sequence.PositionEncoding, kwargs={'pos_embedding_trainable': pos_embedding_trainable, 'zero_pad': zero_pad}, input_shape=(BATCH_SIZE, SEQ_LENGTH, EMBEDDING_SIZE)) diff --git a/tests/layers/utils_test.py b/tests/layers/utils_test.py new file mode 100644 index 00000000..a651dd07 --- /dev/null +++ b/tests/layers/utils_test.py @@ -0,0 +1,30 @@ +import pytest +import numpy as np +import tensorflow as tf +from deepctr.layers.utils import Hash +from tests.utils import layer_test + +try: + from tensorflow.python.keras.utils import CustomObjectScope +except ImportError: + from tensorflow.keras.utils import CustomObjectScope + + +@pytest.mark.parametrize( + 'num_buckets,mask_zero,vocabulary_path,input_data,expected_output', + [ + (3 + 1, False, None, ['lakemerson'], None), + (3 + 1, True, None, ['lakemerson'], None), + ( + 3 + 1, False, "./tests/layers/vocabulary_example.csv", [['lake'], ['johnson'], ['lakemerson']], [[1], [3], [0]]) + ] +) +def test_Hash(num_buckets, mask_zero, vocabulary_path, input_data, expected_output): + if not hasattr(tf, 'version') or tf.version.VERSION < '2.0.0': + return + + with CustomObjectScope({'Hash': Hash}): + layer_test(Hash, + kwargs={'num_buckets': num_buckets, 'mask_zero': mask_zero, 'vocabulary_path': vocabulary_path}, + input_dtype=tf.string, input_data=np.array(input_data, dtype='str'), + expected_output_dtype=tf.int64, expected_output=expected_output) diff --git a/tests/layers/vocabulary_example.csv b/tests/layers/vocabulary_example.csv new file mode 100644 index 00000000..4bce734c --- /dev/null +++ b/tests/layers/vocabulary_example.csv @@ -0,0 +1,3 @@ +1,lake +2,merson +3,johnson \ No newline at end of file diff --git 
a/tests/utils.py b/tests/utils.py index db570297..c190991d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -18,6 +18,7 @@ VOCABULARY_SIZE = 4 Estimator_TEST_TF1 = True + def gen_sequence(dim, max_len, sample_size): return np.array([np.random.randint(0, dim, max_len) for _ in range(sample_size)]), np.random.randint(1, max_len + 1, sample_size) @@ -44,15 +45,15 @@ def get_test_data(sample_size=1000, embedding_size=4, sparse_feature_num=1, dens for i in range(sparse_feature_num): if use_group: - group_name = str(i%3) + group_name = str(i % 3) else: group_name = DEFAULT_GROUP_NAME dim = np.random.randint(1, 10) feature_columns.append( - SparseFeat(prefix + 'sparse_feature_' + str(i), dim, embedding_size, use_hash=hash_flag, dtype=tf.int32,group_name=group_name)) + SparseFeat(prefix + 'sparse_feature_' + str(i), dim, embedding_size, use_hash=hash_flag, dtype=tf.int32, group_name=group_name)) for i in range(dense_feature_num): - transform_fn = lambda x: (x - 0.0)/ 1.0 + def transform_fn(x): return (x - 0.0) / 1.0 feature_columns.append( DenseFeat( prefix + 'dense_feature_' + str(i), @@ -363,6 +364,7 @@ def check_model(model, model_name, x, y, check_model_io=True): print(model_name + " test pass!") + def get_test_data_estimator(sample_size=1000, embedding_size=4, sparse_feature_num=1, dense_feature_num=1, classification=True): x = {} @@ -372,7 +374,7 @@ def get_test_data_estimator(sample_size=1000, embedding_size=4, sparse_feature_n for i in range(sparse_feature_num): name = 's_'+str(i) x[name] = np.random.randint(0, voc_size, sample_size) - dnn_feature_columns.append(tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity(name,voc_size),embedding_size)) + dnn_feature_columns.append(tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity(name, voc_size), embedding_size)) linear_feature_columns.append(tf.feature_column.categorical_column_with_identity(name, voc_size)) for i in range(dense_feature_num): @@ -390,8 +392,9 @@ def get_test_data_estimator(sample_size=1000, embedding_size=4, sparse_feature_n else: input_fn = tf.estimator.inputs.numpy_input_fn(x, y, shuffle=False) - return linear_feature_columns,dnn_feature_columns,input_fn + return linear_feature_columns, dnn_feature_columns, input_fn + -def check_estimator(model,input_fn): +def check_estimator(model, input_fn): model.train(input_fn) - model.evaluate(input_fn) \ No newline at end of file + model.evaluate(input_fn)