Source code for libserving.serialization.knn

import heapq
import os
from operator import itemgetter

from scipy import sparse

from libreco.bases import CfBase

from .common import (
    check_path_exists,
    save_id_mapping,
    save_model_name,
    save_to_json,
    save_user_consumed,
)


def save_knn(path: str, model: CfBase, k: int):
    """Serialize a KNN-style ``CfBase`` model to ``path``.

    ``check_path_exists`` is called on ``path`` first. After that, the
    following are written in order: the model name, the id mapping and
    user-consumed data taken from ``model.data_info``, and
    ``model.sim_matrix`` truncated to ``k`` entries per row.

    Parameters
    ----------
    path : str
        Model saving path.
    model : CfBase
        Model to save.
    k : int
        Number of similar users/items to save.
    """
    check_path_exists(path)
    save_model_name(path, model)
    data_info = model.data_info
    # Both helpers are called with the same (path, data_info) arguments.
    for save_data_part in (save_id_mapping, save_user_consumed):
        save_data_part(path, data_info)
    save_sim_matrix(path, model.sim_matrix, k)
def save_sim_matrix(path: str, sim_matrix: sparse.csr_matrix, k: int):
    """Save the top-``k`` stored entries of every row of a similarity matrix.

    For each row ``i`` of ``sim_matrix``, the stored ``(column_index,
    similarity)`` pairs are ranked by descending similarity. Ties keep
    their original order within the row. At most ``k`` pairs are kept.
    The resulting mapping, keyed by row index, is written to ``sim.json``
    inside ``path`` via ``save_to_json``.

    Parameters
    ----------
    path : str
        Directory the ``sim.json`` file is written to.
    sim_matrix : scipy.sparse.csr_matrix
        Sparse similarity matrix. Only its explicitly stored entries are
        considered.
    k : int
        Maximum number of neighbors kept per row. Must be non-negative.

    Raises
    ------
    ValueError
        If ``k`` is negative.
    """
    if k < 0:
        # A negative slice bound used to silently drop the *last* |k| entries.
        raise ValueError(f"`k` must be non-negative, got {k}")

    # Convert once to plain Python lists: the values become native
    # ints/floats, and per-row slicing avoids numpy overhead.
    indices = sim_matrix.indices.tolist()
    indptr = sim_matrix.indptr.tolist()
    data = sim_matrix.data.tolist()

    # heapq.nlargest is O(m log k) per row rather than a full O(m log m)
    # sort, and is documented as equivalent to a stable descending sort
    # truncated to k, so the output order is unchanged.
    by_sim = itemgetter(1)
    k_sims = {
        row: heapq.nlargest(
            k, zip(indices[start:end], data[start:end]), key=by_sim
        )
        for row, (start, end) in enumerate(zip(indptr, indptr[1:]))
    }
    sim_path = os.path.join(path, "sim.json")
    save_to_json(sim_path, k_sims)