Source code for libreco.bases.sage_base

from dataclasses import astuple
from typing import Union

import numpy as np
import torch
from tqdm import tqdm

from ..bases import EmbedBase
from ..graph import NeighborWalker
from ..graph.inference import full_neighbor_embeddings
from ..graph.message import ItemMessage, ItemMessageDGL, UserMessage
from ..torchops import device_config


class SageBase(EmbedBase):
    """Base class for GraphSage and PinSage.

    Graph neural network algorithms using neighbor sampling and node features.

    See Also
    --------
    ~libreco.algorithms.GraphSage
    ~libreco.algorithms.PinSage
    ~libreco.algorithms.GraphSageDGL
    ~libreco.algorithms.PinSageDGL
    """

    def __init__(
        self,
        task,
        data_info,
        loss_type="cross_entropy",
        paradigm="i2i",
        embed_size=16,
        n_epochs=20,
        lr=0.001,
        lr_decay=False,
        epsilon=1e-8,
        amsgrad=False,
        reg=None,
        batch_size=256,
        num_neg=1,
        dropout_rate=0.0,
        remove_edges=False,
        num_layers=2,
        num_neighbors=3,
        num_walks=10,
        sample_walk_len=5,
        margin=1.0,
        sampler="random",
        start_node="random",
        focus_start=False,
        seed=42,
        device="cuda",
        lower_upper_bound=None,
        full_inference=False,
    ):
        super().__init__(task, data_info, embed_size, lower_upper_bound)

        self.loss_type = loss_type
        self.paradigm = paradigm
        self.n_epochs = n_epochs
        self.lr = lr
        self.lr_decay = lr_decay
        self.epsilon = epsilon
        self.amsgrad = amsgrad
        self.reg = reg
        self.batch_size = batch_size
        self.num_neg = num_neg
        self.dropout_rate = dropout_rate
        self.remove_edges = remove_edges
        self.num_layers = num_layers
        self.num_neighbors = num_neighbors
        self.num_walks = num_walks
        self.sample_walk_len = sample_walk_len
        self.margin = margin
        self.sampler = sampler
        self.start_node = start_node
        self.focus_start = focus_start
        self.neighbor_walker = None
        self.full_inference = full_inference
        self.seed = seed
        self.device = device_config(device)
        self.use_dgl = "DGL" in self.model_name
        self.torch_model = None
        self._check_params()

    def _check_params(self):
        if self.task != "ranking":
            raise ValueError(f"`{self.model_name}` is only suitable for ranking")
        if self.paradigm not in ("u2i", "i2i"):
            raise ValueError("`paradigm` must either be `u2i` or `i2i`")
        if self.loss_type not in ("cross_entropy", "focal", "bpr", "max_margin"):
            raise ValueError(f"unsupported `loss_type`: {self.loss_type}")
        if self.paradigm == "i2i" and self.start_node not in ("random", "unpopular"):
            raise ValueError("`start_node` must either be `random` or `unpopular`")
        if not self.sampler:
            raise ValueError(
                f"`{self.model_name}` must use negative sampling; make sure the data "
                "only contains positive samples when using negative sampling."
            )

    def build_model(self):
        raise NotImplementedError

    def get_user_repr(self, user_data: UserMessage):
        users, sparse_indices, dense_values = astuple(user_data)
        return self.torch_model.user_repr(users, sparse_indices, dense_values)

    def get_item_repr(self, item_data: Union[ItemMessage, ItemMessageDGL]):
        if isinstance(item_data, ItemMessage):
            (
                item,
                sparse_indices,
                dense_values,
                neighbors,
                neighbor_sparse,
                neighbor_dense,
                offsets,
                weights,
            ) = astuple(item_data)
            return self.torch_model(
                item,
                sparse_indices,
                dense_values,
                neighbors,
                neighbor_sparse,
                neighbor_dense,
                offsets,
                weights,
            )
        else:
            blocks, start_nodes, sparse_indices, dense_values = astuple(item_data)
            return self.torch_model(blocks, start_nodes, sparse_indices, dense_values)

    @torch.inference_mode()
    def set_embeddings(self):
        assert isinstance(self.neighbor_walker, NeighborWalker)
        self.torch_model.eval()
        if self.full_inference and self.use_dgl:
            self.item_embeds_np = full_neighbor_embeddings(self)
        else:
            item_embed = []
            all_items = list(range(self.n_items))
            for i in tqdm(range(0, self.n_items, self.batch_size), desc="item embeds"):
                batch_items = all_items[i : i + self.batch_size]
                if self.use_dgl:
                    batch_items = torch.tensor(batch_items, dtype=torch.long)
                item_data = self.neighbor_walker(batch_items)
                item_data = item_data.to_device(self.device)
                item_reprs = self.get_item_repr(item_data)
                item_embed.append(item_reprs.detach().cpu().numpy())
            self.item_embeds_np = np.concatenate(item_embed, axis=0)
        self.user_embeds_np = self._compute_user_embeddings()

    @torch.inference_mode()
    def _compute_user_embeddings(self):
        self.torch_model.eval()
        user_embed = []
        if self.paradigm == "u2i":
            for i in range(0, self.n_users, self.batch_size):
                users = np.arange(i, min(i + self.batch_size, self.n_users))
                user_data = self.neighbor_walker.get_user_feats(users)
                user_data = user_data.to_device(self.device)
                user_reprs = self.get_user_repr(user_data)
                user_embed.append(user_reprs.detach().cpu().numpy())
            return np.concatenate(user_embed, axis=0)
        else:
            for u in range(self.n_users):
                items = self.user_consumed[u]
                user_embed.append(np.mean(self.item_embeds_np[items], axis=0))
                # user_embed.append(self.item_embed[items[-1]])
            return np.array(user_embed)
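

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the library source. It mirrors, in plain
# NumPy, what the i2i branch of `_compute_user_embeddings` does above: a
# user's vector is the mean of the vectors of the items that user consumed,
# and (assuming the usual `EmbedBase` convention of inner-product scoring,
# which is not shown in this file) recommendations are ranked by the dot
# product between user and item embeddings. The array shapes and the
# `user_consumed` mapping below are made up for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    rng = np.random.default_rng(42)
    n_items, embed_size = 6, 4
    item_embeds_np = rng.normal(size=(n_items, embed_size))  # stand-in for self.item_embeds_np
    user_consumed = {0: [1, 3], 1: [0, 2, 5]}                # stand-in for self.user_consumed

    # Mean-pool the consumed item vectors per user, as in the i2i branch above.
    user_embeds_np = np.array(
        [item_embeds_np[items].mean(axis=0) for items in user_consumed.values()]
    )

    # Rank all items for user 0 by inner product (higher score = stronger match).
    scores = item_embeds_np @ user_embeds_np[0]
    top_k = np.argsort(scores)[::-1][:3]
    print("top-3 items for user 0:", top_k)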