"""Implementation of Swing."""
import pathlib
from ..bases import Base
from ..evaluation import print_metrics
from ..prediction.preprocess import convert_id
from ..recommendation import construct_rec, popular_recommendations
from ..utils.misc import time_block
from ..utils.save_load import load_params, save_params
from ..utils.sparse import build_sparse
from ..utils.validate import check_fitting, check_unknown, check_unknown_user
class Swing(Base):
    """*Swing* algorithm.

    .. CAUTION::
        + Swing can only be used in ``ranking`` task.

    Parameters
    ----------
    task : {'ranking'}
        Recommendation task. See :ref:`Task`.
    data_info : :class:`~libreco.data.DataInfo` object
        Object that contains useful information for training and inference.
    top_k : int, default: 20
        Number of items to consider during recommendation.
    alpha : float, default: 1.0
        Smoothing coefficient.
    max_cache_num : int, default: 100,000,000
        Maximum cached item number during swing score computing.
    num_threads : int, default: 1
        Number of threads to use.
    seed : int, default: 42
        Random seed.

    References
    ----------
    *Xiaoyong Yang et al.* `Large Scale Product Graph Construction for Recommendation in E-commerce
    <https://arxiv.org/pdf/2010.05525>`_.
    """

    def __init__(
        self,
        task,
        data_info,
        top_k=20,
        alpha=1.0,
        max_cache_num=100_000_000,
        num_threads=1,
        seed=42,
    ):
        super().__init__(task, data_info, lower_upper_bound=None)
        assert task == "ranking", "`Swing` is only suitable for ranking task."
        # `locals()` snapshots every local bound so far, i.e. the constructor
        # arguments. Keep this line ahead of any new local variable so the
        # snapshot contains only arguments.
        # NOTE(review): presumably consumed by `save_params` — confirm.
        self.all_args = locals()
        self.top_k = top_k
        self.alpha = alpha
        self.max_cache_num = max_cache_num
        self.num_threads = num_threads
        self.seed = seed
        # Native `recfarm.Swing` model; created in `fit`, or restored by
        # `load` / `rebuild_model`.
        self.rs_model = None
        # When True, `fit` updates the existing native model instead of
        # building a new one (set by `rebuild_model`).
        self.incremental = False

    def fit(
        self,
        train_data,
        neg_sampling,
        verbose=1,
        eval_data=None,
        metrics=None,
        k=10,
        eval_batch_size=8192,
        eval_user_num=None,
    ):
        """Compute swing scores from ``train_data``, or update them incrementally.

        If :meth:`rebuild_model` was called beforehand, the existing native model
        is updated with the new interactions. Otherwise a new ``recfarm.Swing``
        model is built and its swing scores are computed from scratch.

        Parameters
        ----------
        train_data
            Training data. Its ``sparse_interaction`` attribute is used to build
            both the user-side and the transposed (item-side) interaction
            structures.
        neg_sampling
            Forwarded to ``check_fitting`` and ``print_metrics``.
        verbose : int, default: 1
            Output level, forwarded to ``time_block``. The swing summary line is
            printed only when ``verbose > 0``. Values above 1 also run
            ``print_metrics`` on ``eval_data``.
        eval_data : default: None
            Evaluation data forwarded to ``check_fitting`` and ``print_metrics``.
        metrics : default: None
            Metrics forwarded to ``print_metrics``.
        k : int, default: 10
            Cutoff forwarded to ``check_fitting`` and ``print_metrics``.
        eval_batch_size : int, default: 8192
            Forwarded to ``print_metrics``.
        eval_user_num : default: None
            Forwarded to ``print_metrics`` as ``sample_user_num``.
        """
        import recfarm

        check_fitting(self, train_data, eval_data, neg_sampling, k)
        self.show_start_time()
        user_interacts = build_sparse(train_data.sparse_interaction)
        item_interacts = build_sparse(train_data.sparse_interaction, transpose=True)
        if self.incremental:
            assert self.rs_model is not None
            # Previously this passed a hard-coded verbose=1, which ignored the
            # caller's `verbose` argument.
            with time_block("update swing", verbose=verbose):
                self.rs_model.update_swing(
                    self.num_threads, user_interacts, item_interacts
                )
        else:
            self.rs_model = recfarm.Swing(
                self.top_k,
                self.alpha,
                self.max_cache_num,
                self.n_users,
                self.n_items,
                user_interacts,
                item_interacts,
                self.user_consumed,
                self.default_pred,
            )
            with time_block("swing computing", verbose=verbose):
                self.rs_model.compute_swing(self.num_threads)

        if verbose > 0:
            num = self.rs_model.num_swing_elements()
            # Percentage of all n_items * n_items item pairs.
            density_ratio = 100 * num / (self.n_items * self.n_items)
            print(f"swing num_elements: {num}, density: {density_ratio:5.4f} %")
        if verbose > 1:
            print_metrics(
                model=self,
                neg_sampling=neg_sampling,
                eval_data=eval_data,
                metrics=metrics,
                eval_batch_size=eval_batch_size,
                k=k,
                sample_user_num=eval_user_num,
                seed=self.seed,
            )
            print("=" * 30)

    def predict(self, user, item, cold_start="popular", inner_id=False):
        """Predict scores for the given user/item pairs with the native model.

        Parameters
        ----------
        user, item
            User and item ids, passed to ``convert_id`` together with ``inner_id``.
        cold_start : str, default: "popular"
            Only ``"popular"`` is accepted when unknown users or items are present.
        inner_id : bool, default: False
            Forwarded to ``convert_id``.

        Returns
        -------
        A single prediction when exactly one pair is given. Otherwise, the
        sequence returned by the native ``predict`` call.

        Raises
        ------
        ValueError
            If unknown users/items are present and ``cold_start`` is not
            ``"popular"``.
        """
        user_arr, item_arr = convert_id(self, user, item, inner_id)
        unknown_num, _, user_arr, item_arr = check_unknown(self, user_arr, item_arr)
        if unknown_num > 0 and cold_start != "popular":
            raise ValueError(f"{self.model_name} only supports popular strategy")
        preds = self.rs_model.predict(user_arr.tolist(), item_arr.tolist())
        return preds[0] if len(user_arr) == 1 else preds

    def recommend_user(
        self,
        user,
        n_rec,
        cold_start="popular",
        inner_id=False,
        filter_consumed=True,
        random_rec=False,
    ):
        """Recommend ``n_rec`` items for each of the given users.

        Users reported as unknown by ``check_unknown_user`` receive
        ``popular_recommendations``. The other users receive recommendations
        from the native model. When the model returns a positive additional
        count for a user, that many popular items (as inner ids) are appended
        before ``construct_rec`` builds the final result.

        Parameters
        ----------
        user
            One or more user ids, checked with ``check_unknown_user``.
        n_rec : int
            Number of items to recommend per user.
        cold_start : str, default: "popular"
            Only ``"popular"`` is supported for unknown users.
        inner_id : bool, default: False
            Forwarded to the id checking / result construction helpers.
        filter_consumed : bool, default: True
            Forwarded to the native ``recommend`` call.
        random_rec : bool, default: False
            Forwarded to the native ``recommend`` call.

        Returns
        -------
        dict
            Recommendations keyed by user.

        Raises
        ------
        ValueError
            If unknown users are present and ``cold_start`` is not ``"popular"``.
        """
        result_recs = {}
        user_ids, unknown_users = check_unknown_user(self.data_info, user, inner_id)
        if unknown_users:
            if cold_start != "popular":
                raise ValueError(
                    f"{self.model_name} only supports `popular` cold start strategy"
                )
            for u in unknown_users:
                result_recs[u] = popular_recommendations(
                    self.data_info, inner_id, n_rec
                )
        if user_ids:
            computed_recs, additional_rec_counts = self.rs_model.recommend(
                user_ids,
                n_rec,
                filter_consumed,
                random_rec,
            )
            # Top up each list with popular items when the native model asks for it.
            for rec, n_extra in zip(computed_recs, additional_rec_counts):
                if n_extra > 0:
                    additional_recs = popular_recommendations(
                        self.data_info, inner_id=True, n_rec=n_extra
                    )
                    rec.extend(additional_recs)
            user_recs = construct_rec(self.data_info, user_ids, computed_recs, inner_id)
            result_recs.update(user_recs)
        return result_recs

    def save(self, path, model_name, **kwargs):
        """Save the hyperparameters and the native swing model under ``path``.

        The folder (including missing parents) is created if it is not already
        a directory. Extra keyword arguments are accepted but unused here.

        Parameters
        ----------
        path : str
            File folder path to save into.
        model_name : str
            Name used for the saved files.
        """
        import recfarm

        path_obj = pathlib.Path(path)
        if not path_obj.is_dir():
            print(f"file folder {path} doesn't exists, creating a new one...")
            # exist_ok=True tolerates the folder being created between the
            # is_dir() check and mkdir(). A regular file at `path` still
            # raises FileExistsError.
            path_obj.mkdir(parents=True, exist_ok=True)
        save_params(self, path, model_name)
        recfarm.save_swing(self.rs_model, path, model_name)

    @classmethod
    def load(cls, path, model_name, data_info, **kwargs):
        """Restore a model saved with :meth:`save`.

        A new instance is built from the saved hyperparameters, and the native
        model is loaded into ``rs_model``. Extra keyword arguments are accepted
        but unused here.

        Parameters
        ----------
        path : str
            File folder path of the saved model.
        model_name : str
            Name of the saved model files.
        data_info : :class:`~libreco.data.DataInfo` object
            Passed to ``load_params``.

        Returns
        -------
        Swing
            The restored model.
        """
        import recfarm

        hparams = load_params(path, data_info, model_name)
        model = cls(**hparams)
        model.rs_model = recfarm.load_swing(path, model_name)
        return model

    def rebuild_model(self, path, model_name):
        """Assign the saved model variables to the newly initialized model.

        This method is used before retraining the new model, in order to avoid training
        from scratch every time we get some new data.

        Parameters
        ----------
        path : str
            File folder path for the saved model variables.
        model_name : str
            Name of the saved model file.
        """
        import recfarm

        self.rs_model = recfarm.load_swing(path, model_name)
        # Sync the loaded native model with this instance's (possibly larger)
        # user/item counts and consumption records.
        self.rs_model.n_users = self.n_users
        self.rs_model.n_items = self.n_items
        self.rs_model.user_consumed = self.user_consumed
        # Makes the next `fit` call `update_swing` instead of recomputing.
        self.incremental = True