"""Implementation of TwoTower model."""
import numpy as np
from ..bases import DynEmbedBase, ModelMeta
from ..feature.ssl import get_mutual_info
from ..layers import dense_nn, normalize_embeds
from ..tfops import dropout_config, reg_config, tf
from ..torchops import hidden_units_config
from ..utils.misc import count_params
from ..utils.validate import dense_field_size, sparse_feat_size

class TwoTower(DynEmbedBase, metaclass=ModelMeta, backend="tensorflow"):
"""*TwoTower* algorithm. See :ref:`TwoTower` for more details.

.. CAUTION::
    TwoTower can only be used in the ``ranking`` task.

.. versionadded:: 1.2.0

Parameters
----------
task : {'ranking'}
Recommendation task. See :ref:`Task`.
data_info : :class:`~libreco.data.DataInfo` object
Object that contains useful information for training and inference.
loss_type : {'cross_entropy', 'max_margin', 'softmax'}, default: 'softmax'
Loss for model training.
embed_size : int, default: 16
Vector size of embeddings.
norm_embed : bool, default: False
Whether to L2-normalize output embeddings.
It is generally recommended to normalize embeddings in the ``TwoTower`` model.
n_epochs : int, default: 20
Number of epochs for training.
lr : float, default: 0.001
Learning rate for training.
lr_decay : bool, default: False
Whether to use learning rate decay.
epsilon : float, default: 1e-5
A small constant added to the denominator to improve numerical stability in the Adam optimizer.
According to the `official comment <https://github.com/tensorflow/tensorflow/blob/v1.15.0/tensorflow/python/training/adam.py#L64>`_,
the default value of ``1e-8`` for ``epsilon`` might not be a good default in general, so we choose ``1e-5`` here.
Users can try tuning this hyperparameter if training is unstable.
reg : float or None, default: None
Regularization parameter, must be non-negative or None.
batch_size : int, default: 256
Batch size for training.
sampler : {'random', 'unconsumed', 'popular'}, default: 'random'
Negative sampling strategy, only used with ``cross_entropy`` and ``max_margin`` losses.
For ``softmax`` loss, in-batch sampling is used, following Reference[1].

- ``'random'`` means random sampling.
- ``'unconsumed'`` samples items that the target user did not consume before.
- ``'popular'`` samples popular items as negatives with higher probability.

num_neg : int, default: 1
Number of negative samples for each positive sample, only used in the ``ranking`` task.
use_bn : bool, default: True
Whether to use batch normalization.
dropout_rate : float or None, default: None
Probability of an element to be zeroed. If it is None, dropout is not used.
hidden_units : int, list of int or tuple of (int,), default: (128, 64, 32)
Number of layers and corresponding layer sizes in the MLP.
margin : float, default: 1.0
Margin used in ``max_margin`` loss.
use_correction : bool, default: True
Whether to use sampling bias correction in softmax loss described in Reference[1].
temperature : float, default: 1.0
Temperature by which the logits are divided when computing the softmax. A typical value would be in the range [0.05, 0.5].
If one sets ``temperature <= 0``, it will be treated as a variable and learned during training.
remove_accidental_hits : bool, default: False
Whether to remove accidental hits among examples used as in-batch negatives. An accidental hit is
a candidate that is used as an in-batch negative but has the same id as the positive candidate.
Note this could make training slower.
ssl_pattern : {'rfm', 'rfm-complementary', 'cfm'} or None, default: None
Whether to use the self-supervised learning technique described in Reference[2].
Note that self-supervised learning can only be used with ``softmax`` loss.

- ``'rfm'`` stands for *Random Feature Masking*.
- ``'rfm-complementary'`` stands for *Random Feature Masking* with complementary masking.
- ``'cfm'`` stands for *Correlated Feature Masking*, where mutual information between features is used, according to the paper.

alpha : float, default: 0.2
Parameter controlling the weight of the self-supervised loss in the total loss during multi-task training.
seed : int, default: 42
Random seed.
tf_sess_config : dict or None, default: None
Optional TensorFlow session config, see `ConfigProto options
<https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/core/protobuf/config.proto#L431>`_.

Raises
------
ValueError
If ``ssl_pattern`` is not None and data doesn't have item sparse features.
ValueError
If ``ssl_pattern`` is not None and ``loss_type`` is not ``softmax``.

References
----------
[1] *Xinyang Yi et al.* `Sampling-Bias-Corrected Neural Modeling for Large Corpus Item Recommendations
<https://storage.googleapis.com/pub-tools-public-publication-data/pdf/6c8a86c981a62b0126a11896b7f6ae0dae4c3566.pdf>`_.
[2] *Tiansheng Yao et al.* `Self-supervised Learning for Large-scale Item Recommendations
<https://arxiv.org/pdf/2007.12865.pdf>`_.
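
Examples
--------
A minimal usage sketch, assuming ``train_data`` and ``data_info`` have been built
beforehand with :class:`~libreco.data.DatasetFeat` (settings here are illustrative):

.. code-block:: python

    model = TwoTower(task="ranking", data_info=data_info, loss_type="softmax",
                     embed_size=16, norm_embed=True, temperature=0.1)
    # with softmax loss, negatives come from in-batch sampling
    model.fit(train_data, neg_sampling=True)
    model.recommend_user(user=1, n_rec=10)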
"""
user_variables = ("embedding/user_embeds_var",)
item_variables = ("embedding/item_embeds_var",)
sparse_variables = ("embedding/sparse_embeds_var",)
dense_variables = ("embedding/dense_embeds_var",)

def __init__(
self,
task,
data_info=None,
loss_type="softmax",
embed_size=16,
norm_embed=False,
n_epochs=20,
lr=0.001,
lr_decay=False,
epsilon=1e-5,
reg=None,
batch_size=256,
sampler="random",
num_neg=1,
use_bn=True,
dropout_rate=None,
hidden_units=(128, 64, 32),
margin=1.0,
use_correction=True,
temperature=1.0,
remove_accidental_hits=False,
ssl_pattern=None,
alpha=0.2,
seed=42,
tf_sess_config=None,
):
super().__init__(task, data_info, embed_size, norm_embed)
self.all_args = locals()
self.loss_type = loss_type
self.n_epochs = n_epochs
self.lr = lr
self.lr_decay = lr_decay
self.epsilon = epsilon
self.reg = reg_config(reg)
self.batch_size = batch_size
self.sampler = sampler
self.num_neg = num_neg
self.use_bn = use_bn
self.dropout_rate = dropout_config(dropout_rate)
self.hidden_units = hidden_units_config(hidden_units)
self.margin = margin
self.use_correction = use_correction
self.temperature = temperature
self.remove_accidental_hits = remove_accidental_hits
self.ssl_pattern = ssl_pattern
self.alpha = alpha
self.seed = seed
self.user_sparse = bool(data_info.user_sparse_col.name)
self.item_sparse = bool(data_info.item_sparse_col.name)
self.user_dense = bool(data_info.user_dense_col.name)
self.item_dense = bool(data_info.item_dense_col.name)
self._check_params()

def _check_params(self):
if self.task != "ranking":
raise ValueError("`TwoTower` is only suitable for ranking")
if self.loss_type not in ("cross_entropy", "max_margin", "softmax"):
raise ValueError(f"Unsupported `loss_type`: {self.loss_type}")
if self.ssl_pattern is not None:
if self.ssl_pattern not in ("rfm", "rfm-complementary", "cfm"):
raise ValueError(
f"`ssl` pattern supports `rfm`, `rfm-complementary` and `cfm`, "
f"got {self.ssl_pattern}."
)
if not self.item_sparse:
raise ValueError(
"`ssl` (self-supervised learning) relies on item sparse features, "
"which are not available in the training data."
)
if self.loss_type != "softmax":
raise ValueError(
"`ssl` (self-supervised learning) can only be used with `softmax` loss."
)

def build_model(self):
tf.set_random_seed(self.seed)
self._build_placeholders()
self._build_variables()
self.user_embeds = self.compute_user_embeddings("user")
self.item_embeds = self.compute_item_embeddings("item")
self.serving_topk = self.build_topk()
if self.loss_type == "cross_entropy":
self.output = tf.reduce_sum(self.user_embeds * self.item_embeds, axis=1)
if self.loss_type == "max_margin":
self.item_embeds_neg = self.compute_item_embeddings("item_neg")
if self.ssl_pattern is not None:
self.ssl_left_embeds = self.compute_ssl_embeddings("ssl_left")
self.ssl_right_embeds = self.compute_ssl_embeddings("ssl_right")
count_params()
# print([x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"])

def _build_placeholders(self):
self.user_indices = tf.placeholder(tf.int32, shape=[None])
self.item_indices = tf.placeholder(tf.int32, shape=[None])
if self.loss_type == "cross_entropy":
self.labels = tf.placeholder(tf.float32, shape=[None])
if self.loss_type == "max_margin":
self.item_indices_neg = tf.placeholder(tf.int32, shape=[None])
if self.loss_type == "softmax" and self.use_correction:
self.correction = tf.placeholder(tf.float32, shape=[None])
self.is_training = tf.placeholder_with_default(False, shape=[])
if self.user_sparse:
self.user_sparse_indices = tf.placeholder(
tf.int32, shape=[None, len(self.data_info.user_sparse_col.name)]
)
if self.user_dense:
self.user_dense_values = tf.placeholder(
tf.float32, shape=[None, len(self.data_info.user_dense_col.name)]
)
if self.item_sparse:
self.item_sparse_indices = tf.placeholder(
tf.int32, shape=[None, len(self.data_info.item_sparse_col.name)]
)
if self.loss_type == "max_margin":
self.item_sparse_indices_neg = tf.placeholder(
tf.int32, shape=[None, len(self.data_info.item_sparse_col.name)]
)
if self.item_dense:
self.item_dense_values = tf.placeholder(
tf.float32, shape=[None, len(self.data_info.item_dense_col.name)]
)
if self.loss_type == "max_margin":
self.item_dense_values_neg = tf.placeholder(
tf.float32, shape=[None, len(self.data_info.item_dense_col.name)]
)
if self.ssl_pattern is not None:
# each row concatenates the item index with its sparse feature indices, hence the extra column
self.ssl_left_sparse_indices = tf.placeholder(
tf.int32, shape=[None, len(self.data_info.item_sparse_col.name) + 1]
)
self.ssl_right_sparse_indices = tf.placeholder(
tf.int32, shape=[None, len(self.data_info.item_sparse_col.name) + 1]
)
if self.item_dense:
self.ssl_left_dense_values = tf.placeholder(
tf.float32, shape=[None, len(self.data_info.item_dense_col.name)]
)
self.ssl_right_dense_values = tf.placeholder(
tf.float32, shape=[None, len(self.data_info.item_dense_col.name)]
)

def _build_variables(self):
with tf.variable_scope("embedding"):
self.user_embeds_var = tf.get_variable(
name="user_embeds_var",
shape=(self.n_users + 1, self.embed_size),
initializer=tf.glorot_uniform_initializer(),
regularizer=self.reg,
)
self.item_embeds_var = tf.get_variable(
name="item_embeds_var",
shape=(self.n_items, self.embed_size),
initializer=tf.glorot_uniform_initializer(),
regularizer=self.reg,
)
if self.user_sparse or self.item_sparse:
self.sparse_embeds_var = tf.get_variable(
name="sparse_embeds_var",
shape=(sparse_feat_size(self.data_info), self.embed_size),
initializer=tf.glorot_uniform_initializer(),
regularizer=self.reg,
)
if self.user_dense or self.item_dense:
self.dense_embeds_var = tf.get_variable(
name="dense_embeds_var",
shape=(dense_field_size(self.data_info), self.embed_size),
initializer=tf.glorot_uniform_initializer(),
regularizer=self.reg,
)
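# a non-positive `temperature` is treated as a trainable variable, initialized to 1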
if self.temperature <= 0.0:
self.temperature_var = tf.get_variable(
name="temperature_var",
shape=(),
initializer=tf.ones_initializer(),
trainable=True,
)
if self.ssl_pattern is not None:
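# zero row at index 0 acts as the default embedding, e.g., for masked feature slots in ssl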
default_var = tf.get_variable(
name="default_var",
shape=[1, self.embed_size],
initializer=tf.zeros_initializer(),
trainable=False,
)
self.ssl_embeds_var = tf.concat(
[default_var, self.item_embeds_var, self.sparse_embeds_var], axis=0
)

def compute_user_embeddings(self, category):
user_embed = tf.nn.embedding_lookup(self.user_embeds_var, self.user_indices)
concat_embeds = [user_embed]
if self.user_sparse:
user_sparse_embed = self._compute_sparse_feats(category)
concat_embeds.append(user_sparse_embed)
if self.user_dense:
user_dense_embed = self._compute_dense_feats(category)
concat_embeds.append(user_dense_embed)
user_features = (
tf.concat(concat_embeds, axis=1)
if len(concat_embeds) > 1
else concat_embeds[0]
)
return self._shared_layers(user_features, "user_tower")

def compute_item_embeddings(self, category):
if category == "item":
item_embed = tf.nn.embedding_lookup(self.item_embeds_var, self.item_indices)
elif category == "item_neg":
item_embed = tf.nn.embedding_lookup(
self.item_embeds_var, self.item_indices_neg
)
else:
raise ValueError("Unknown item category")
concat_embeds = [item_embed]
if self.item_sparse:
item_sparse_embed = self._compute_sparse_feats(category)
concat_embeds.append(item_sparse_embed)
if self.item_dense:
item_dense_embed = self._compute_dense_feats(category)
concat_embeds.append(item_dense_embed)
item_features = (
tf.concat(concat_embeds, axis=1)
if len(concat_embeds) > 1
else concat_embeds[0]
)
return self._shared_layers(item_features, "item_tower")

def compute_ssl_embeddings(self, category):
ssl_embed = self._compute_sparse_feats(category)
if self.item_dense:
ssl_dense = self._compute_dense_feats(category)
ssl_embed = tf.concat([ssl_embed, ssl_dense], axis=1)
return self._shared_layers(ssl_embed, "item_tower")

def _compute_sparse_feats(self, category):
if category == "user":
sparse_indices = self.user_sparse_indices
elif category == "item":
sparse_indices = self.item_sparse_indices
elif category == "item_neg":
sparse_indices = self.item_sparse_indices_neg
elif category == "ssl_left":
sparse_indices = self.ssl_left_sparse_indices
elif category == "ssl_right":
sparse_indices = self.ssl_right_sparse_indices
else:
raise ValueError("Unknown sparse indices category.")
if category.startswith("ssl"):
sparse_embed = tf.nn.embedding_lookup(self.ssl_embeds_var, sparse_indices)
else:
sparse_embed = tf.nn.embedding_lookup(
self.sparse_embeds_var, sparse_indices
)
return tf.keras.layers.Flatten()(sparse_embed)

def _compute_dense_feats(self, category):
if category == "user":
dense_col_indices = self.data_info.user_dense_col.index
dense_values = self.user_dense_values
else:
dense_col_indices = self.data_info.item_dense_col.index
if category == "item":
dense_values = self.item_dense_values
elif category == "item_neg":
dense_values = self.item_dense_values_neg
elif category == "ssl_left":
dense_values = self.ssl_left_dense_values
elif category == "ssl_right":
dense_values = self.ssl_right_dense_values
else:
raise ValueError("Unknown dense values category.")
batch_size = tf.shape(dense_values)[0]
# (num_dense_fields, embed_size) -> (batch_size, num_dense_fields, embed_size)
dense_embed = tf.gather(self.dense_embeds_var, dense_col_indices, axis=0)
dense_embed = tf.expand_dims(dense_embed, axis=0)
dense_embed = tf.tile(dense_embed, [batch_size, 1, 1])
# broadcast element-wise multiplication, then flatten to (batch_size, num_dense_fields * embed_size)
return tf.keras.layers.Flatten()(dense_values[:, :, tf.newaxis] * dense_embed)

def _shared_layers(self, inputs, name):
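# `reuse_layer=True` shares MLP weights across calls with the same name,
# e.g., between positive and negative items, or the item tower and the ssl towers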
embeds = dense_nn(
inputs,
self.hidden_units,
use_bn=self.use_bn,
dropout_rate=self.dropout_rate,
is_training=self.is_training,
reuse_layer=True,
name=name,
)
return normalize_embeds(embeds, backend="tf") if self.norm_embed else embeds

def fit(
self,
train_data,
neg_sampling,
verbose=1,
shuffle=True,
eval_data=None,
metrics=None,
k=10,
eval_batch_size=8192,
eval_user_num=None,
num_workers=0,
):
if self.loss_type == "softmax" and self.use_correction:
_, item_counts = np.unique(train_data.item_indices, return_counts=True)
assert len(item_counts) == self.n_items
self.item_corrections = item_counts / len(train_data)
if self.ssl_pattern == "cfm":
self.sparse_feat_mutual_info = get_mutual_info(train_data, self.data_info)
super().fit(
train_data,
neg_sampling,
verbose,
shuffle,
eval_data,
metrics,
k,
eval_batch_size,
eval_user_num,
)

def set_embeddings(self):
super().set_embeddings()
if hasattr(self, "temperature_var"):
learned_temperature = self.sess.run(self.temperature_var)
print(f"Learned temperature variable: {learned_temperature}")
def adjust_logits(self, logits, all_adjust=True):
temperature = (
self.temperature_var
if hasattr(self, "temperature_var")
else self.temperature
)
# temperature scaling; `divide_no_nan` guards against a zero learned temperature
logits = tf.math.divide_no_nan(logits, temperature)
if self.use_correction and all_adjust:
# sampling bias correction from Reference[1]: subtract log(sampling probability)
correction = tf.clip_by_value(self.correction, 1e-8, 1.0)
logQ = tf.reshape(tf.math.log(correction), (1, -1))
logits -= logQ
if self.remove_accidental_hits and all_adjust:
# mask in-batch negatives that share an item id with the positive candidate,
# excluding the diagonal, which holds the true positives
row_items = tf.reshape(self.item_indices, (1, -1))
col_items = tf.reshape(self.item_indices, (-1, 1))
equal_items = tf.cast(tf.equal(row_items, col_items), tf.float32)
label_diag = tf.eye(tf.shape(logits)[0])
mask = tf.cast(equal_items - label_diag, tf.bool)
paddings = tf.fill(tf.shape(logits), tf.float32.min)
return tf.where(mask, paddings, logits)
else:
return logits