# Source code for libreco.algorithms.swing

"""Implementation of Swing."""
import pathlib

from ..bases import Base
from ..evaluation import print_metrics
from ..prediction.preprocess import convert_id
from ..recommendation import construct_rec, popular_recommendations
from ..utils.misc import time_block
from ..utils.save_load import load_params, save_params
from ..utils.sparse import build_sparse
from ..utils.validate import check_fitting, check_unknown, check_unknown_user


class Swing(Base):
    """*Swing* algorithm.

    .. CAUTION::
        + Swing can only be used in ``ranking`` task.

    Parameters
    ----------
    task : {'ranking'}
        Recommendation task. See :ref:`Task`.
    data_info : :class:`~libreco.data.DataInfo` object
        Object that contains useful information for training and inference.
    top_k : int, default: 20
        Number of items to consider during recommendation.
    alpha : float, default: 1.0
        Smoothing coefficient.
    max_cache_num : int, default: 100,000,000
        Maximum cached item number during swing score computing.
    num_threads : int, default: 1
        Number of threads to use.
    seed : int, default: 42
        Random seed.

    References
    ----------
    *Xiaoyong Yang et al.* `Large Scale Product Graph Construction for Recommendation
    in E-commerce <https://arxiv.org/pdf/2010.05525>`_.
    """

    def __init__(
        self,
        task,
        data_info,
        top_k=20,
        alpha=1.0,
        max_cache_num=100_000_000,
        num_threads=1,
        seed=42,
    ):
        super().__init__(task, data_info, lower_upper_bound=None)
        assert task == "ranking", "`Swing` is only suitable for ranking task."
        # Snapshot of the constructor arguments. No other local may be created
        # before this line, otherwise it would leak into ``all_args``.
        self.all_args = locals()
        self.top_k = top_k
        self.alpha = alpha
        self.max_cache_num = max_cache_num
        self.num_threads = num_threads
        self.seed = seed
        # Native model object from the ``recfarm`` extension; created in ``fit``
        # or restored by ``load`` / ``rebuild_model``.
        self.rs_model = None
        # Forwarded to ``compute_swing``; set to True by ``rebuild_model``.
        self.incremental = False

    def fit(
        self,
        train_data,
        neg_sampling,
        verbose=1,
        eval_data=None,
        metrics=None,
        k=10,
        eval_batch_size=8192,
        eval_user_num=None,
    ):
        """Build the native Swing model from ``train_data`` and compute swing scores.

        Parameters
        ----------
        train_data : object
            Training data; its ``sparse_interaction`` attribute is turned into
            user->item and item->user sparse structures via ``build_sparse``.
        neg_sampling : bool
            Forwarded to ``check_fitting`` and ``print_metrics``.
        verbose : int, default: 1
            ``0`` suppresses the timing and density report. ``> 1`` also prints
            evaluation metrics on ``eval_data``.
        eval_data : object or None, default: None
            Data used by ``print_metrics`` when ``verbose > 1``.
        metrics : list or None, default: None
            Metrics forwarded to ``print_metrics``.
        k : int, default: 10
            Cut-off forwarded to ``check_fitting`` and ``print_metrics``.
        eval_batch_size : int, default: 8192
            Batch size forwarded to ``print_metrics``.
        eval_user_num : int or None, default: None
            Forwarded to ``print_metrics`` as ``sample_user_num``.
        """
        import recfarm

        check_fitting(self, train_data, eval_data, neg_sampling, k)
        self.show_start_time()
        user_interacts = build_sparse(train_data.sparse_interaction)
        item_interacts = build_sparse(train_data.sparse_interaction, transpose=True)
        # NOTE(review): a fresh ``recfarm.Swing`` is always created here, so a
        # model restored by ``rebuild_model`` is replaced and only the
        # ``incremental`` flag reaches ``compute_swing``. Confirm this is the
        # intended incremental-training behavior.
        self.rs_model = recfarm.Swing(
            self.task,
            self.top_k,
            self.alpha,
            self.max_cache_num,
            self.n_users,
            self.n_items,
            user_interacts,
            item_interacts,
            self.user_consumed,
            self.default_pred,
        )
        # Honor the caller's verbosity instead of a hard-coded ``verbose=1``.
        with time_block("swing computing", verbose=verbose):
            self.rs_model.compute_swing(self.num_threads, self.incremental)

        if verbose > 0:
            num = self.rs_model.num_swing_elements()
            # Percentage of the dense n_items x n_items matrix that is populated.
            density_ratio = 100 * num / (self.n_items * self.n_items)
            print(f"swing num_elements: {num}, density: {density_ratio:5.4f} %")
        if verbose > 1:
            print_metrics(
                model=self,
                neg_sampling=neg_sampling,
                eval_data=eval_data,
                metrics=metrics,
                eval_batch_size=eval_batch_size,
                k=k,
                sample_user_num=eval_user_num,
                seed=self.seed,
            )
            print("=" * 30)

    def predict(self, user, item, cold_start="popular", inner_id=False):
        """Score (user, item) pairs with the native Swing model.

        Parameters
        ----------
        user : int or array_like
            User id(s); treated as inner ids when ``inner_id`` is True.
        item : int or array_like
            Item id(s) paired element-wise with ``user``.
        cold_start : str, default: "popular"
            Only ``"popular"`` is accepted when unknown ids are present.
        inner_id : bool, default: False
            Whether ``user`` / ``item`` are already inner ids.

        Returns
        -------
        The single prediction when exactly one pair is given, otherwise the
        sequence returned by ``rs_model.predict``.

        Raises
        ------
        ValueError
            If ``check_unknown`` reports unknown ids and ``cold_start`` is not
            ``"popular"``.
        """
        user_arr, item_arr = convert_id(self, user, item, inner_id)
        unknown_num, _, user_arr, item_arr = check_unknown(self, user_arr, item_arr)
        if unknown_num > 0 and cold_start != "popular":
            raise ValueError(f"{self.model_name} only supports popular strategy")
        preds = self.rs_model.predict(user_arr.tolist(), item_arr.tolist())
        return preds[0] if len(user_arr) == 1 else preds

    def recommend_user(
        self,
        user,
        n_rec,
        cold_start="popular",
        inner_id=False,
        filter_consumed=True,
        random_rec=False,
    ):
        """Recommend ``n_rec`` items for each requested user.

        Unknown users, and known users for whom the native model yields no
        recommendations, fall back to popular items.

        Parameters
        ----------
        user : int or array_like
            User id(s) to recommend for.
        n_rec : int
            Number of recommendations per user.
        cold_start : str, default: "popular"
            Only ``"popular"`` is accepted when unknown users are present.
        inner_id : bool, default: False
            Whether ``user`` ids (and the returned items) use inner ids.
        filter_consumed : bool, default: True
            Forwarded to ``rs_model.recommend``.
        random_rec : bool, default: False
            Forwarded to ``rs_model.recommend``.

        Returns
        -------
        dict
            Mapping from user to its recommended items.

        Raises
        ------
        ValueError
            If unknown users are present and ``cold_start`` is not ``"popular"``.
        """
        result_recs = {}
        user_ids, unknown_users = check_unknown_user(self.data_info, user, inner_id)
        if unknown_users:
            if cold_start != "popular":
                raise ValueError(
                    f"{self.model_name} only supports `popular` cold start strategy"
                )
            for u in unknown_users:
                result_recs[u] = popular_recommendations(
                    self.data_info, inner_id, n_rec
                )
        if user_ids:
            computed_recs, no_rec_indices = self.rs_model.recommend(
                user_ids,
                n_rec,
                filter_consumed,
                random_rec,
            )
            # ``computed_recs`` is converted by ``construct_rec`` below, so the
            # popular fallback is requested with inner ids here.
            for i in no_rec_indices:
                computed_recs[i] = popular_recommendations(
                    self.data_info, inner_id=True, n_rec=n_rec
                )
            user_recs = construct_rec(self.data_info, user_ids, computed_recs, inner_id)
            result_recs.update(user_recs)
        return result_recs

    def save(self, path, model_name, **kwargs):
        """Save hyperparameters and the native Swing model under ``path``.

        The directory is created when missing. Extra ``kwargs`` are accepted
        for interface compatibility and ignored.
        """
        import recfarm

        path_obj = pathlib.Path(path)
        if not path_obj.is_dir():
            print(f"file folder {path} doesn't exists, creating a new one...")
            # ``exist_ok=True`` avoids a spurious FileExistsError if the
            # directory appears between the check above and this call; a
            # pre-existing non-directory at ``path`` still raises.
            path_obj.mkdir(parents=True, exist_ok=True)
        save_params(self, path, model_name)
        recfarm.save_swing(self.rs_model, path, model_name)

    @classmethod
    def load(cls, path, model_name, data_info, **kwargs):
        """Rebuild a model from hyperparameters and a native model saved by ``save``.

        Extra ``kwargs`` are accepted for interface compatibility and ignored.
        """
        import recfarm

        hparams = load_params(path, data_info, model_name)
        model = cls(**hparams)
        model.rs_model = recfarm.load_swing(path, model_name)
        return model

    def rebuild_model(self, path, model_name):
        """Restore a saved native model and attach this instance's current data.

        Copies ``n_users``, ``n_items`` and ``user_consumed`` from this
        instance onto the loaded model and sets ``incremental`` to True.
        """
        import recfarm

        self.rs_model = recfarm.load_swing(path, model_name)
        self.rs_model.n_users = self.n_users
        self.rs_model.n_items = self.n_items
        self.rs_model.user_consumed = self.user_consumed
        self.incremental = True