Source code for libserving.serialization.embed

import os

import numpy as np

from libreco.bases import EmbedBase

from .common import (
    check_path_exists,
    save_id_mapping,
    save_model_name,
    save_to_json,
    save_user_consumed,
)


def save_embed(path: str, model: EmbedBase):
    """Save Embed model to disk.

    Writes the model name, the id mapping and the user-consumed records
    taken from ``model.data_info``, followed by the user and item embedding
    vectors, into ``path``.

    Parameters
    ----------
    path : str
        Model saving path.
    model : EmbedBase
        Model to save.
    """
    # Path check runs before anything is written.
    check_path_exists(path)

    # Model metadata and the lookup tables held by ``data_info``.
    save_model_name(path, model)
    save_id_mapping(path, model.data_info)
    save_user_consumed(path, model.data_info)

    # Embedding tables: users first, then items, one JSON file each.
    save_vectors(
        path=path,
        embeds=model.user_embeds_np,
        num=model.n_users,
        name="user_embed.json",
    )
    save_vectors(
        path=path,
        embeds=model.item_embeds_np,
        num=model.n_items,
        name="item_embed.json",
    )
def save_vectors(path: str, embeds: np.ndarray, num: int, name: str):
    """Dump the first ``num`` rows of an embedding matrix to a JSON file.

    Parameters
    ----------
    path : str
        Directory the file is written into.
    embeds : np.ndarray
        Embedding matrix; row ``i`` is stored under key ``i``.
    num : int
        Number of leading rows to export.
    name : str
        File name, joined onto ``path``.
    """
    target = os.path.join(path, name)
    # Index explicitly over ``range(num)`` rather than iterating ``embeds``:
    # only the first ``num`` rows are exported (the matrix presumably may
    # hold extra rows -- TODO confirm), and a too-short matrix still raises.
    vectors = {idx: embeds[idx].tolist() for idx in range(num)}
    save_to_json(target, vectors)


def save_faiss_index(path: str, model: EmbedBase, nlist: int = 80, nprobe: int = 10):
    """Build a faiss IVF-Flat inner-product index over item embeddings and save it.

    The index is trained on, and populated with, the first ``model.n_items``
    rows of ``model.item_embeds_np`` cast to ``float32``, then written to
    ``faiss_index.bin`` inside ``path``.

    Parameters
    ----------
    path : str
        Directory the index file is written into.
    model : EmbedBase
        Model providing ``item_embeds_np`` and ``n_items``.
    nlist : int, default: 80
        Passed to ``faiss.IndexIVFFlat`` as its ``nlist`` argument.
    nprobe : int, default: 10
        Assigned to the index's ``nprobe`` attribute before it is written.
    """
    # Imported lazily so faiss is only required when an index is requested.
    import faiss

    check_path_exists(path)
    out_file = os.path.join(path, "faiss_index.bin")

    item_vectors = model.item_embeds_np[: model.n_items].astype(np.float32)
    dim = item_vectors.shape[1]

    # Flat inner-product quantizer backing the IVF index.
    coarse_quantizer = faiss.IndexFlatIP(dim)
    ivf_index = faiss.IndexIVFFlat(
        coarse_quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT
    )

    # The index is trained on the vectors before they are added.
    ivf_index.train(item_vectors)
    ivf_index.add(item_vectors)
    ivf_index.nprobe = nprobe

    faiss.write_index(ivf_index, out_file)