Skip to content

Commit

Permalink
Support pre-defined key-value vocabulary in Hash Layer
Browse files Browse the repository at this point in the history
New Feature: Support pre-defined key-value vocabulary in Hash Layer
  • Loading branch information
shenweichen committed Jul 18, 2021
2 parents 0df401c + 95ad62e commit 9f15559
Show file tree
Hide file tree
Showing 21 changed files with 271 additions and 53 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ If you find this code useful in your research, please cite it using the followin
<td>
<a href="https://github.com/zanshuxun"><img width="70" height="70" src="https://github.com/zanshuxun.png?s=40" alt="pic"></a><br>
<a href="https://github.com/zanshuxun">Zan Shuxun</a> ​
<p>Beijing University <br> of Posts and <br> Telecommunications </p>​
<p>Alibaba Group </p>​
</td>
<td>
​ <a href="https://github.com/pandeconscious"><img width="70" height="70" src="https://github.com/pandeconscious.png?s=40" alt="pic"></a><br>
Expand Down
2 changes: 1 addition & 1 deletion deepctr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .utils import check_version

__version__ = '0.8.6'
__version__ = '0.8.7'
check_version(__version__)
10 changes: 7 additions & 3 deletions deepctr/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@


class SparseFeat(namedtuple('SparseFeat',
['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'dtype', 'embeddings_initializer',
['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'vocabulary_path', 'dtype', 'embeddings_initializer',
'embedding_name',
'group_name', 'trainable'])):
__slots__ = ()

def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype="int32", embeddings_initializer=None,
def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, vocabulary_path=None, dtype="int32", embeddings_initializer=None,
embedding_name=None,
group_name=DEFAULT_GROUP_NAME, trainable=True):

Expand All @@ -32,7 +32,7 @@ def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype="
if embedding_name is None:
embedding_name = name

return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, dtype,
return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, vocabulary_path, dtype,
embeddings_initializer,
embedding_name, group_name, trainable)

Expand Down Expand Up @@ -64,6 +64,10 @@ def embedding_dim(self):
def use_hash(self):
return self.sparsefeat.use_hash

@property
def vocabulary_path(self):
return self.sparsefeat.vocabulary_path

@property
def dtype(self):
return self.sparsefeat.dtype
Expand Down
6 changes: 3 additions & 3 deletions deepctr/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_embedding_vec_list(embedding_dict, input_dict, sparse_feature_columns, r
feat_name = fg.name
if len(return_feat_list) == 0 or feat_name in return_feat_list:
if fg.use_hash:
lookup_idx = Hash(fg.vocabulary_size, mask_zero=(feat_name in mask_feat_list))(input_dict[feat_name])
lookup_idx = Hash(fg.vocabulary_size, mask_zero=(feat_name in mask_feat_list), vocabulary_path=fg.vocabulary_path)(input_dict[feat_name])
else:
lookup_idx = input_dict[feat_name]

Expand Down Expand Up @@ -80,7 +80,7 @@ def embedding_lookup(sparse_embedding_dict, sparse_input_dict, sparse_feature_co
embedding_name = fc.embedding_name
if (len(return_feat_list) == 0 or feature_name in return_feat_list):
if fc.use_hash:
lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list))(
lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list), vocabulary_path=fc.vocabulary_path)(
sparse_input_dict[feature_name])
else:
lookup_idx = sparse_input_dict[feature_name]
Expand All @@ -97,7 +97,7 @@ def varlen_embedding_lookup(embedding_dict, sequence_input_dict, varlen_sparse_f
feature_name = fc.name
embedding_name = fc.embedding_name
if fc.use_hash:
lookup_idx = Hash(fc.vocabulary_size, mask_zero=True)(sequence_input_dict[feature_name])
lookup_idx = Hash(fc.vocabulary_size, mask_zero=True, vocabulary_path=fc.vocabulary_path)(sequence_input_dict[feature_name])
else:
lookup_idx = sequence_input_dict[feature_name]
varlen_embedding_vec_dict[feature_name] = embedding_dict[embedding_name](lookup_idx)
Expand Down
9 changes: 3 additions & 6 deletions deepctr/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def build(self, input_shape):
'inputs of a two inputs with shape (None,1,embedding_size) and (None,T,embedding_size)'
'Got different shapes: %s,%s' % (input_shape[0], input_shape[1]))
size = 4 * \
int(input_shape[0][-1]
) if len(self.hidden_units) == 0 else self.hidden_units[-1]
int(input_shape[0][-1]
) if len(self.hidden_units) == 0 else self.hidden_units[-1]
self.kernel = self.add_weight(shape=(size, 1),
initializer=glorot_normal(
seed=self.seed),
Expand All @@ -78,9 +78,6 @@ def build(self, input_shape):
shape=(1,), initializer=Zeros(), name="bias")
self.dnn = DNN(self.hidden_units, self.activation, self.l2_reg, self.dropout_rate, self.use_bn, seed=self.seed)

self.dense = tf.keras.layers.Lambda(lambda x: tf.nn.bias_add(tf.tensordot(
x[0], x[1], axes=(-1, 0)), x[2]))

super(LocalActivationUnit, self).build(
input_shape) # Be sure to call this somewhere!

Expand All @@ -96,7 +93,7 @@ def call(self, inputs, training=None, **kwargs):

att_out = self.dnn(att_input, training=training)

attention_score = self.dense([att_out, self.kernel, self.bias])
attention_score = tf.nn.bias_add(tf.tensordot(att_out, self.kernel, axes=(-1, 0)), self.bias)

return attention_score

Expand Down
18 changes: 6 additions & 12 deletions deepctr/layers/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,10 +560,10 @@ def call(self, inputs, mask=None, training=None, **kwargs):
if self.blinding:
try:
outputs = tf.matrix_set_diag(outputs, tf.ones_like(outputs)[
:, :, 0] * (-2 ** 32 + 1))
except:
:, :, 0] * (-2 ** 32 + 1))
except AttributeError:
outputs = tf.compat.v1.matrix_set_diag(outputs, tf.ones_like(outputs)[
:, :, 0] * (-2 ** 32 + 1))
:, :, 0] * (-2 ** 32 + 1))

outputs -= reduce_max(outputs, axis=-1, keep_dims=True)
outputs = softmax(outputs)
Expand Down Expand Up @@ -633,14 +633,14 @@ def build(self, input_shape):
_, T, num_units = input_shape.as_list() # inputs.get_shape().as_list()
# First part of the PE function: sin and cos argument
position_enc = np.array([
[pos / np.power(10000, 2. * i / num_units)
for i in range(num_units)]
[pos / np.power(10000, 2. * (i//2) / num_units) for i in range(num_units)]
for pos in range(T)])

# Second part, apply the cosine to even columns and sin to odds.
position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i
position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1

if self.zero_pad:
position_enc[0, :] = np.zeros(num_units)
self.lookup_table = self.add_weight("lookup_table", (T, num_units),
initializer=tf.initializers.identity(position_enc),
trainable=self.pos_embedding_trainable)
Expand All @@ -651,13 +651,7 @@ def build(self, input_shape):
def call(self, inputs, mask=None):
_, T, num_units = inputs.get_shape().as_list()
position_ind = tf.expand_dims(tf.range(T), 0)

if self.zero_pad:
self.lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
self.lookup_table[1:, :]), 0)

outputs = tf.nn.embedding_lookup(self.lookup_table, position_ind)

if self.scale:
outputs = outputs * num_units ** 0.5
return outputs + inputs
Expand Down
57 changes: 51 additions & 6 deletions deepctr/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
"""
import tensorflow as tf
from tensorflow.python.keras.layers import Flatten
from tensorflow.python.ops.lookup_ops import TextFileInitializer
try:
from tensorflow.python.ops.lookup_ops import StaticHashTable
except ImportError as e:
from tensorflow.python.ops.lookup_ops import HashTable as StaticHashTable


class NoMask(tf.keras.layers.Layer):
Expand All @@ -25,14 +30,47 @@ def compute_mask(self, inputs, mask):


class Hash(tf.keras.layers.Layer):
"""
hash the input to [0,num_buckets)
if mask_zero = True,0 or 0.0 will be set to 0,other value will be set in range[1,num_buckets)
"""Looks up keys in a table when setup `vocabulary_path`, which outputs the corresponding values.
If `vocabulary_path` is not set, `Hash` will hash the input to [0,num_buckets). When `mask_zero` = True,
input value `0` or `0.0` will be set to `0`, and other value will be set in range [1,num_buckets).
The following snippet initializes a `Hash` with `vocabulary_path` file with the first column as keys and
second column as values:
* `1,emerson`
* `2,lake`
* `3,palmer`
>>> hash = Hash(
... num_buckets=3+1,
... vocabulary_path=filename,
... default_value=0)
>>> hash(tf.constant('lake')).numpy()
2
>>> hash(tf.constant('lakeemerson')).numpy()
0
Args:
num_buckets: An `int` that is >= 1. The number of buckets or the vocabulary size + 1
when `vocabulary_path` is setup.
mask_zero: default is False. The `Hash` value will hash input `0` or `0.0` to value `0` when
the `mask_zero` is `True`. `mask_zero` is not used when `vocabulary_path` is setup.
vocabulary_path: default `None`. The `CSV` text file path of the vocabulary hash, which contains
two columns seperated by delimiter `comma`, the first column is the value and the second is
the key. The key data type is `string`, the value data type is `int`. The path must
be accessible from wherever `Hash` is initialized.
default_value: default '0'. The default value if a key is missing in the table.
**kwargs: Additional keyword arguments.
"""

def __init__(self, num_buckets, mask_zero=False, **kwargs):
def __init__(self, num_buckets, mask_zero=False, vocabulary_path=None, default_value=0, **kwargs):
self.num_buckets = num_buckets
self.mask_zero = mask_zero
self.vocabulary_path = vocabulary_path
self.default_value = default_value
if self.vocabulary_path:
initializer = TextFileInitializer(vocabulary_path, 'string', 1, 'int64', 0, delimiter=',')
self.hash_table = StaticHashTable(initializer, default_value=self.default_value)
super(Hash, self).__init__(**kwargs)

def build(self, input_shape):
Expand All @@ -41,13 +79,16 @@ def build(self, input_shape):

def call(self, x, mask=None, **kwargs):


if x.dtype != tf.string:
zero = tf.as_string(tf.zeros([1], dtype=x.dtype))
x = tf.as_string(x, )
else:
zero = tf.as_string(tf.zeros([1], dtype='int32'))

if self.vocabulary_path:
hash_x = self.hash_table.lookup(x)
return hash_x

num_buckets = self.num_buckets if not self.mask_zero else self.num_buckets - 1
try:
hash_x = tf.string_to_hash_bucket_fast(x, num_buckets,
Expand All @@ -60,8 +101,12 @@ def call(self, x, mask=None, **kwargs):
hash_x = (hash_x + 1) * mask

return hash_x

def compute_output_shape(self, input_shape):
return input_shape

def get_config(self, ):
config = {'num_buckets': self.num_buckets, 'mask_zero': self.mask_zero, }
config = {'num_buckets': self.num_buckets, 'mask_zero': self.mask_zero, 'vocabulary_path': self.vocabulary_path, 'default_value': self.default_value}
base_config = super(Hash, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

Expand Down
66 changes: 66 additions & 0 deletions docs/source/Examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,72 @@ if __name__ == "__main__":
history = model.fit(model_input, data[target].values,
batch_size=256, epochs=10, verbose=2, validation_split=0.2, )
```
## Hash Layer with pre-defined key-value vocabulary

This examples how to use pre-defined key-value vocabulary in `Hash` Layer.`movielens_age_vocabulary.csv` stores the key-value mapping for `age` feature.

```python
from deepctr.models import DeepFM
from deepctr.feature_column import SparseFeat, VarLenSparseFeat, get_feature_names
import numpy as np
import pandas as pd
from tensorflow.python.keras.preprocessing.sequence import pad_sequences

try:
import tensorflow.compat.v1 as tf
except ImportError as e:
import tensorflow as tf

if __name__ == "__main__":
data = pd.read_csv("./movielens_sample.txt")
sparse_features = ["movie_id", "user_id",
"gender", "age", "occupation", "zip", ]

data[sparse_features] = data[sparse_features].astype(str)
target = ['rating']

# 1.Use hashing encoding on the fly for sparse features,and process sequence features

genres_list = list(map(lambda x: x.split('|'), data['genres'].values))
genres_length = np.array(list(map(len, genres_list)))
max_len = max(genres_length)

# Notice : padding=`post`
genres_list = pad_sequences(genres_list, maxlen=max_len, padding='post', dtype=str, value=0)

# 2.set hashing space for each sparse field and generate feature config for sequence feature

fixlen_feature_columns = [SparseFeat(feat, data[feat].nunique() * 5, embedding_dim=4, use_hash=True,
vocabulary_path='./movielens_age_vocabulary.csv' if feat == 'age' else None,
dtype='string')
for feat in sparse_features]
varlen_feature_columns = [
VarLenSparseFeat(SparseFeat('genres', vocabulary_size=100, embedding_dim=4,
use_hash=True, dtype="string"),
maxlen=max_len, combiner='mean',
)] # Notice : value 0 is for padding for sequence input feature
linear_feature_columns = fixlen_feature_columns + varlen_feature_columns
dnn_feature_columns = fixlen_feature_columns + varlen_feature_columns
feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns)

# 3.generate input data for model
model_input = {name: data[name] for name in feature_names}
model_input['genres'] = genres_list

# 4.Define Model,compile and train
model = DeepFM(linear_feature_columns, dnn_feature_columns, task='regression')
model.compile("adam", "mse", metrics=['mse'], )
if not hasattr(tf, 'version') or tf.version.VERSION < '2.0.0':
with tf.Session() as sess:
sess.run(tf.tables_initializer())
history = model.fit(model_input, data[target].values,
batch_size=256, epochs=10, verbose=2, validation_split=0.2, )
else:
history = model.fit(model_input, data[target].values,
batch_size=256, epochs=10, verbose=2, validation_split=0.2, )

```


## Estimator with TFRecord: Classification Criteo

Expand Down
7 changes: 4 additions & 3 deletions docs/source/Features.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ DNN based CTR prediction models usually have following 4 modules:
## Feature Columns
### SparseFeat
``SparseFeat`` is a namedtuple with signature ``SparseFeat(name, vocabulary_size, embedding_dim, use_hash, dtype, embeddings_initializer, embedding_name, group_name, trainable)``
``SparseFeat`` is a namedtuple with signature ``SparseFeat(name, vocabulary_size, embedding_dim, use_hash, vocabulary_path, dtype, embeddings_initializer, embedding_name, group_name, trainable)``

- name : feature name
- vocabulary_size : number of unique feature values for sprase feature or hashing space when `use_hash=True`
- vocabulary_size : number of unique feature values for sparse feature or hashing space when `use_hash=True`
- embedding_dim : embedding dimension
- use_hash : defualt `False`.If `True` the input will be hashed to space of size `vocabulary_size`.
- use_hash : default `False`.If `True` the input will be hashed to space of size `vocabulary_size`.
- vocabulary_path : default `None`. The `CSV` text file path of the vocabulary table used by `tf.lookup.TextFileInitializer`, which assigns one entry in the table for each line in the file. One entry contains two columns separated by comma, the first is the value column, the second is the key column. The `0` value is reserved to use if a key is missing in the table, so hash value need start from `1`.
- dtype : default `int32`.dtype of input tensor.
- embeddings_initializer : initializer for the `embeddings` matrix.
- embedding_name : default `None`. If None, the embedding_name will be same as `name`.
Expand Down
1 change: 1 addition & 0 deletions docs/source/History.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# History
- 07/18/2021 : [v0.8.7](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.7) released.Support pre-defined key-value vocabulary in `Hash` Layer. [example](./Examples.html#hash-layer-with-pre-defined-key-value-vocabulary)
- 06/14/2021 : [v0.8.6](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.6) released.Add [IFM](./Features.html#ifm-input-aware-factorization-machine) [DIFM](./Features.html#difm-dual-input-aware-factorization-machine), [FEFM and DeepFEFM](./Features.html#deepfefm-deep-field-embedded-factorization-machine) model.
- 03/13/2021 : [v0.8.5](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.5) released.Add [BST](./Features.html#bst-behavior-sequence-transformer) model.
- 02/12/2021 : [v0.8.4](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.4) released.Fix bug in DCN-Mix.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/Quick-Start.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ fixlen_feature_columns = [SparseFeat(feat, vocabulary_size=data[feat].max() + 1,
```
- Feature Hashing on the fly
```python
fixlen_feature_columns = [SparseFeat(feat, vocabulary_size=1e6,embedding_dim=4, use_hash=True, dtype='string') # since the input is string
fixlen_feature_columns = [SparseFeat(feat, vocabulary_size=1e6,embedding_dim=4, use_hash=True, dtype='string') # the input is string
for feat in sparse_features] + [DenseFeat(feat, 1, )
for feat in dense_features]
```
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '0.8.6'
release = '0.8.7'


# -- General configuration ---------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ You can read the latest code and related projects

News
-----
07/18/2021 : Support pre-defined key-value vocabulary in `Hash` Layer. `example <./Examples.html#hash-layer-with-pre-defined-key-value-vocabulary>`_ `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.7>`_

06/14/2021 : Add `IFM <./Features.html#ifm-input-aware-factorization-machine>`_ , `DIFM <./Features.html#difm-dual-input-aware-factorization-machine>`_ and `DeepFEFM <./Features.html#deepfefm-deep-field-embedded-factorization-machine>`_ . `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.6>`_

03/13/2021 : Add `BST <./Features.html#bst-behavior-sequence-transformer>`_ . `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.5>`_

02/12/2021 : Fix bug in DCN-Mix. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.4>`_

DisscussionGroup
-----------------------

Expand Down
7 changes: 7 additions & 0 deletions examples/movielens_age_vocabulary.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
1,1
2,18
3,25
4,35
5,45
6,50
7,56
Loading

0 comments on commit 9f15559

Please sign in to comment.