DeepAniGNet: Privacy-Preserving Recommendation via Graph Neural Networks
Overview
Published as arXiv:2508.14905 (“Privacy Preserving Inference of Personalized Content for Out of Matrix Users”), DeepAniGNet is a three-module hybrid recommendation system that addresses the cold-start problem while preserving user privacy through aggregate content representations. The system combines:
- a BERT-based content encoder (a custom bert-anime2vec model pretrained on anime synopses and reviews) that produces 774-dimensional item embeddings;
- a GINEConv graph neural network over a 10,000-node item-item recommendation graph with edge features;
- a DropoutGNet scoring network that stochastically masks (p=0.5) Weighted Matrix Factorization (WMF) latent factors, simulating cold-start conditions during training so that the model learns to score from content features alone.
At inference, a cold-start user is represented by a content basket: the masked average of the BERT encodings of up to 10 watched items. Only this vector is transmitted, not the raw preference history, so privacy comes from aggregation.
The system is evaluated on the custom AnimeULike dataset (10,000 anime and 13,000 users scraped from MyAnimeList) and on the standard CiteULike benchmark, and is reported to achieve state-of-the-art cold-start Recall@100.
Three-Module Architecture
Module 1: TransformerNet (Content Encoder)
The content encoder is a custom bert-anime2vec model fine-tuned on anime text. For each anime, it concatenates the synopsis and reviews, tokenizes the result with the BERT tokenizer (max_length=512, padding='max_length', truncation=True), and extracts the [CLS] token embedding from last_hidden_state[:, 0, :] with per-sample z-score normalization. This 768-dimensional vector is concatenated with 6 MinMaxScaler-normalized scalar features (including rating, rank, popularity, members, and favorites), giving a 768 + 6 = 774-dimensional item embedding.
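The embedding assembly described above can be sketched in NumPy. This is a minimal illustration, not the project's code: the random array stands in for real encoder output, the scalar values are placeholders, and the helper names `zscore` and `minmax_scale` are hypothetical.

```python
import numpy as np

def zscore(vec, eps=1e-8):
    """Per-sample z-score: normalize one embedding by its own mean and std."""
    return (vec - vec.mean()) / (vec.std() + eps)

def minmax_scale(columns):
    """Column-wise scaling to [0, 1], mirroring MinMaxScaler's default range."""
    lo, hi = columns.min(axis=0), columns.max(axis=0)
    return (columns - lo) / np.where(hi > lo, hi - lo, 1.0)

rng = np.random.default_rng(0)
num_items, seq_len, hidden = 4, 512, 768

# Stand-in for the encoder output; real values would come from bert-anime2vec.
last_hidden_state = rng.normal(size=(num_items, seq_len, hidden))

# Take the [CLS] position (token 0) and normalize each item's vector separately.
cls = np.stack([zscore(row) for row in last_hidden_state[:, 0, :]])

# Six scalar features per item (placeholder values, not real MAL statistics).
scalars = minmax_scale(rng.uniform(0, 1e5, size=(num_items, 6)))

item_embedding = np.concatenate([cls, scalars], axis=1)
print(item_embedding.shape)  # (4, 774)
```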
User encoding: Randomly samples up to max_shows_per_user=10 items from the user’s watch history, BERT-encodes each, and computes masked averaging across the subset → single 774-dim user content vector. This averaged representation does not reveal individual items — providing privacy through aggregation.
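A minimal sketch of the masked averaging step, assuming item embeddings are already computed and that shorter histories are zero-padded to a fixed length of 10. The function name `content_basket` and the padding scheme are assumptions made for illustration.

```python
import numpy as np

MAX_SHOWS_PER_USER = 10
EMB_DIM = 774

def content_basket(item_embeddings, watched_ids, rng):
    """Average the embeddings of up to 10 randomly sampled watched items."""
    k = min(len(watched_ids), MAX_SHOWS_PER_USER)
    sampled = rng.choice(watched_ids, size=k, replace=False)
    # Pad to a fixed number of slots; the mask marks which slots hold real items.
    padded = np.zeros((MAX_SHOWS_PER_USER, EMB_DIM))
    mask = np.zeros(MAX_SHOWS_PER_USER)
    padded[:k] = item_embeddings[sampled]
    mask[:k] = 1.0
    # Masked mean: padded slots contribute nothing to the sum or the count.
    return (padded * mask[:, None]).sum(axis=0) / max(mask.sum(), 1.0)

rng = np.random.default_rng(0)
item_embeddings = rng.normal(size=(50, EMB_DIM))
basket = content_basket(item_embeddings, list(range(25)), rng)
print(basket.shape)  # (774,)
```

Dividing by the mask sum instead of the fixed slot count keeps the average unbiased for users with fewer than 10 watched items.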
Selective fine-tuning: Only the last transformer layer (index 5 of DistilBERT's six layers) is unfrozen; all earlier layers remain frozen to prevent catastrophic forgetting.
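The freezing pattern can be illustrated with a stand-in encoder. The sketch below uses a plain `nn.ModuleList` of six linear layers instead of a real DistilBERT, so it runs without downloading weights. With Hugging Face's `DistilBertModel`, the corresponding layer list is expected under `model.transformer.layer`, but the project's exact code is not shown here.

```python
import torch.nn as nn

# Stand-in for a 6-layer transformer encoder (hypothetical, not DistilBERT itself).
encoder = nn.ModuleList([nn.Linear(8, 8) for _ in range(6)])

# Freeze everything first.
for param in encoder.parameters():
    param.requires_grad = False

# Unfreeze only the last layer (index 5).
for param in encoder[-1].parameters():
    param.requires_grad = True

trainable = [name for name, p in encoder.named_parameters() if p.requires_grad]
print(trainable)  # ['5.weight', '5.bias']
```

Only parameters with `requires_grad=True` receive gradients, so passing `filter(lambda p: p.requires_grad, encoder.parameters())` to the optimizer keeps the frozen layers untouched.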
Module 2: GINEConv Graph Network
Builds an item-item graph from MyAnimeList user-generated recommendations (“if you liked X, try Y” pairs with helpfulness counts).
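As a rough sketch, recommendation pairs can be turned into an edge list as follows. The record values are made up, and storing each pair in both directions is an assumption, not something the source specifies.

```python
import numpy as np

# Hypothetical records: (show1_id, show2_id, helpfulness count).
recs = [(20, 1735, 12), (20, 269, 3), (1735, 269, 7)]

# Map raw MyAnimeList ids to contiguous node indices.
ids = sorted({show for a, b, _ in recs for show in (a, b)})
node_of = {mal_id: idx for idx, mal_id in enumerate(ids)}

src = [node_of[a] for a, _, _ in recs]
dst = [node_of[b] for _, b, _ in recs]

# Assumption: add the reverse of every edge so messages flow both ways.
edge_index = np.array([src + dst, dst + src])            # shape [2, num_edges]
edge_weight = np.array([h for _, _, h in recs] * 2, dtype=float)

print(edge_index.shape)  # (2, 6)
```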
GINEConv variant (primary — when edge attributes available):
Edge features [num_edges, feat_dim] → Linear(feat_dim → node_dim) → Edge embeddings
Node features [10000, feat_dim] + Edge embeddings → GINEConv(
MLP: Linear(feat_dim→96) → Tanh → Linear(96→64) → Tanh → Linear(64→32) → Tanh
) → node_reprs [10000, 32] → BatchNorm (per-feature z-score)
GCNConv fallback (3-layer): GCNConv(feat→96) → Tanh → GCNConv(96→64) → Tanh → GCNConv(64→32) → Tanh.
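The primary GINEConv variant above can be approximated in plain NumPy to show the data flow. This sketch follows the update rule documented for PyTorch Geometric's GINEConv, x_i' = MLP((1 + eps) * x_i + sum over neighbors j of ReLU(x_j + e_ji)), but omits biases and learned parameters; the names `gine_layer` and `make_mlp` and the toy graph are assumptions.

```python
import numpy as np

def gine_layer(x, edge_index, edge_attr, w_edge, mlp, eps=0.0):
    """One GINE-style update with edge features projected to the node dimension."""
    src, dst = edge_index
    edge_emb = edge_attr @ w_edge                    # Linear(edge_dim -> node_dim), no bias
    messages = np.maximum(x[src] + edge_emb, 0.0)    # ReLU(x_j + e_ji)
    agg = np.zeros_like(x)
    np.add.at(agg, dst, messages)                    # sum incoming messages per target node
    return mlp((1.0 + eps) * x + agg)

def make_mlp(rng, dims):
    """Tanh MLP with random weights, e.g. dims = [feat, 96, 64, 32]."""
    weights = [rng.normal(0, 0.1, size=(a, b)) for a, b in zip(dims[:-1], dims[1:])]
    def mlp(h):
        for w in weights:
            h = np.tanh(h @ w)
        return h
    return mlp

rng = np.random.default_rng(0)
num_nodes, feat_dim, edge_dim = 5, 16, 8
x = rng.normal(size=(num_nodes, feat_dim))
edge_index = np.array([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_attr = rng.normal(size=(4, edge_dim))
w_edge = rng.normal(0, 0.1, size=(edge_dim, feat_dim))

node_reprs = gine_layer(x, edge_index, edge_attr, w_edge,
                        make_mlp(rng, [feat_dim, 96, 64, 32]))
# Per-feature z-score across nodes, standing in for BatchNorm in training mode.
node_reprs = (node_reprs - node_reprs.mean(axis=0)) / (node_reprs.std(axis=0) + 1e-5)
print(node_reprs.shape)  # (5, 32)
```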
Batch-local subgraph extraction: mask_edge_inds() selects only edges where both endpoints are in the current batch of item indices, creating efficient batch-local subgraphs via boolean mask intersection.
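A minimal sketch of the boolean mask intersection. The function name `mask_edge_inds` comes from the text above, but its signature and return values here are assumptions.

```python
import numpy as np

def mask_edge_inds(edge_index, batch_items):
    """Keep only edges whose source AND target are both in the batch."""
    src_in = np.isin(edge_index[0], batch_items)
    dst_in = np.isin(edge_index[1], batch_items)
    keep = src_in & dst_in
    # Returning the mask as well lets the caller filter edge features the same way.
    return edge_index[:, keep], keep

edge_index = np.array([[0, 1, 2, 3],
                       [1, 2, 3, 0]])
sub_edges, keep = mask_edge_inds(edge_index, batch_items=[0, 1, 2])
print(sub_edges.tolist())  # [[0, 1], [1, 2]]
```

Edges 2→3 and 3→0 are dropped because node 3 is outside the batch.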
Node features support multiple modalities: TF-IDF (2000-dim), BERT encodings (768-dim), or one-hot (10,000-dim). Edge features similarly support TF-IDF or BERT representations of recommendation text.
Module 3: DropoutGNet (Scoring Network)
Inputs: u [B, 200] (WMF user), v [B, 200] (WMF item),
u_phi [B, 774] (content), v_phi [B, 774] (content), gnn_feats [B, 32] (GNN)
Stochastic masking (training): u = Bernoulli(0.5) · u, v = Bernoulli(0.5) · v
Concatenate: phi_u = [u; u_phi] → [B, 974]
phi_v = [v; v_phi; gnn_feats] → [B, 1006]
Hidden: Linear(974→500) → BatchNorm → Tanh (user branch)
Linear(1006→500) → BatchNorm → Tanh (item branch)
Project: f_u = Linear(500→10), f_v = Linear(500→10)
Score: R̂ = f_u^T · f_v (batch matrix multiply)
Loss: MSE(R̂, u^T v) (WMF ground truth)
All linear layers initialized with N(0, 0.01) weights and zero biases. The stochastic dropout (p=0.5 for both user and item factors) is the key training mechanism: by randomly zeroing WMF factors, the model is forced to learn scoring from content features alone — enabling cold-start inference by setting WMF factors to zero.
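The scoring network can be sketched in PyTorch as follows. Several details are assumptions: N(0, 0.01) is read as a standard deviation of 0.01, the Bernoulli mask is drawn once per row (so each sample keeps or loses its whole WMF vector), and the class name `DropoutGNetSketch` is hypothetical.

```python
import torch
import torch.nn as nn

class DropoutGNetSketch(nn.Module):
    def __init__(self, wmf_dim=200, content_dim=774, gnn_dim=32,
                 hidden=500, latent=10, p_drop=0.5):
        super().__init__()
        self.p_drop = p_drop
        self.user_branch = nn.Sequential(
            nn.Linear(wmf_dim + content_dim, hidden),            # 974 -> 500
            nn.BatchNorm1d(hidden), nn.Tanh(),
            nn.Linear(hidden, latent))                           # 500 -> 10
        self.item_branch = nn.Sequential(
            nn.Linear(wmf_dim + content_dim + gnn_dim, hidden),  # 1006 -> 500
            nn.BatchNorm1d(hidden), nn.Tanh(),
            nn.Linear(hidden, latent))
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)  # assumed std
                nn.init.zeros_(module.bias)

    def _mask(self, factors):
        """Zero whole WMF vectors with probability p_drop (training only)."""
        if not self.training:
            return factors
        keep_prob = torch.full((factors.shape[0], 1), 1.0 - self.p_drop,
                               device=factors.device)
        return factors * torch.bernoulli(keep_prob)

    def forward(self, u, v, u_phi, v_phi, gnn_feats):
        f_u = self.user_branch(torch.cat([self._mask(u), u_phi], dim=1))
        f_v = self.item_branch(torch.cat([self._mask(v), v_phi, gnn_feats], dim=1))
        return (f_u * f_v).sum(dim=1)        # per-pair dot product

torch.manual_seed(0)
B = 8
model = DropoutGNetSketch()
u, v = torch.randn(B, 200), torch.randn(B, 200)
u_phi, v_phi, gnn_feats = torch.randn(B, 774), torch.randn(B, 774), torch.randn(B, 32)

target = (u * v).sum(dim=1)                  # WMF score from the unmasked factors
scores = model(u, v, u_phi, v_phi, gnn_feats)
loss = nn.functional.mse_loss(scores, target)
```

Because the target is computed from the unmasked factors, rows whose WMF input was zeroed must reproduce the same score from content and graph features alone.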
Privacy-Preserving Inference
The privacy mechanism operates through content basket aggregation:
- Client-side: The user samples up to 10 items from their watch history, computes BERT encodings locally, and transmits only the masked average embedding: a single 774-dim vector that does not reveal the individual items watched.
- Server-side: At cold-start inference, WMF factors are set to zero (v = np.zeros_like(v)), and the model scores using only the content basket embedding and the GNN item representations.
- Training simulation: The eip/eup (eval item/user drop probability) parameters control masking at evaluation time. When eup=0.0, user WMF factors are completely zeroed, simulating full privacy-preserving inference.
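On the server side, the cold-start user input can be assembled roughly as below. This is a sketch under the dimensions stated earlier (200 WMF factors, 774 content features); the function name `cold_start_user_input` is hypothetical.

```python
import numpy as np

WMF_DIM, CONTENT_DIM = 200, 774

def cold_start_user_input(basket):
    """Build the user-branch input when no WMF factors exist for the user."""
    assert basket.shape == (CONTENT_DIM,)
    u = np.zeros(WMF_DIM)                  # preference factors are zeroed
    return np.concatenate([u, basket])     # [974]: only the basket carries signal

received_basket = np.random.default_rng(0).normal(size=CONTENT_DIM)
phi_u = cold_start_user_input(received_basket)
print(phi_u.shape)  # (974,)
```

The only user-specific data the server handles in this path is the averaged 774-dim vector.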
Dataset: AnimeULike
Web scraping pipeline (via urllib + BeautifulSoup with multiprocessing.Pool):
| Data Source | Scale | Fields |
|---|---|---|
| MAL anime pages | 10,000 titles | Synopsis, score, rank, popularity, members, favorites |
| MAL reviews | 50 pages/anime | Full review text (concatenated with synopsis) |
| MAL recommendations | User-generated pairs | (show1, show2, helpfulness, recommendation text) |
| MAL user profiles | ~13,000 users | Per-user anime ratings |
Data split (random_state=0): ~9,915 train / 50 validation / 35 test anime titles. Train items: 8,000 (with index mapping). WMF pre-computed at rank=200, regularization=1, α=10, 10 iterations, with z-score standardization and clipping to [−5, 5].
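The standardization and clipping step applied to the WMF factors might look like the following. Whether the z-score is taken per latent dimension or over the whole matrix is not stated; this sketch assumes per dimension, and the function name is hypothetical.

```python
import numpy as np

def standardize_and_clip(factors, limit=5.0, eps=1e-8):
    """Z-score each latent dimension, then clip to [-limit, limit]."""
    z = (factors - factors.mean(axis=0)) / (factors.std(axis=0) + eps)
    return np.clip(z, -limit, limit)

rng = np.random.default_rng(0)
raw_factors = rng.standard_cauchy(size=(1000, 200))   # heavy-tailed stand-in data
wmf_factors = standardize_and_clip(raw_factors)
print(wmf_factors.shape)  # (1000, 200)
```

Clipping bounds the influence of outlier factors on the MSE target.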
Training Configuration
| Parameter | Value |
|---|---|
| Effective batch size | 100 (batch × accumulation) |
| Optimizer | Adam (lr=1e-4) or SGD (lr=0.005) |
| LR scheduling | Exponential decay (rate 0.9–1.0) |
| Epochs | 250 (anime), 5 (CiteULike pretrain) |
| WMF rank | 200 |
| Projection latent dim | 10 |
| Hidden dims | [500] |
| Dropout (user/item) | 0.5 / 0.5 |
| Activation | Tanh + BatchNorm |
| BERT max_length | 512 tokens |
| Max shows per user | 10 |
| Train samples | 50K positive + 50K negative |
| DataParallel | Multi-GPU |
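To give a feel for the learning-rate schedule in the table, the helper below evaluates exponential decay at a given epoch. It assumes the decay is applied once per epoch, which the table does not state; `exponential_lr` is a hypothetical name.

```python
def exponential_lr(base_lr, rate, epoch):
    """Learning rate after `epoch` decay steps: base_lr * rate ** epoch."""
    return base_lr * rate ** epoch

# Adam's base rate from the table with the two ends of the decay range.
for rate in (0.9, 1.0):
    schedule = [exponential_lr(1e-4, rate, epoch) for epoch in (0, 10, 50)]
    print(rate, schedule)
```

A rate of 1.0 keeps the learning rate constant, while 0.9 shrinks it by roughly a factor of three every 10 epochs, so over 250 epochs the choice within the 0.9–1.0 range matters a great deal.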
Evaluation-time caching: Both TransformerNet (u_phi_cache/v_phi_cache) and DropoutGNet (f_u_cache/f_v_cache) maintain embedding caches. A UserAnimeCacheID “perimeter” indexing strategy processes all items for user 0, then only the first item for subsequent users, ensuring full user×item coverage with minimal redundant computation.
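The "perimeter" idea can be sketched as a pair generator: walk the first row and first column of the user×item grid so that every user and every item is encoded at least once. The function name `perimeter_pairs` is an assumption, not the project's UserAnimeCacheID implementation.

```python
def perimeter_pairs(num_users, num_items):
    """(user, item) pairs along the top row and left column of the grid."""
    pairs = [(0, item) for item in range(num_items)]      # user 0 with every item
    pairs += [(user, 0) for user in range(1, num_users)]  # other users with item 0
    return pairs

pairs = perimeter_pairs(num_users=3, num_items=4)
print(len(pairs))  # 6
```

This takes num_items + num_users - 1 forward passes instead of num_users × num_items. Once every user and item embedding is cached, the full score matrix follows from a single matrix product.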
Evaluation
Metrics: Item Recall@100 (fraction of relevant items in top-100 per user) and User Recall@100 (fraction of preferring users who receive the item in top-100). Cold-start evaluation zeros out item WMF factors and reconstructs R̂ = U_hat · V_hat^T from cached embeddings.
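Item Recall@k as defined above can be computed from a reconstructed score matrix along these lines. This is a sketch: the function name is hypothetical, and users with no relevant items are skipped, which is an assumption.

```python
import numpy as np

def item_recall_at_k(scores, relevant, k=100):
    """Mean over users of (relevant items in the top-k) / (relevant items)."""
    # argpartition places the k highest-scoring items in the first k columns.
    top_k = np.argpartition(-scores, k - 1, axis=1)[:, :k]
    hits = np.take_along_axis(relevant, top_k, axis=1).sum(axis=1)
    totals = relevant.sum(axis=1)
    valid = totals > 0                     # skip users with nothing to recall
    return float((hits[valid] / totals[valid]).mean())

# Toy example: 2 users, 4 items, k=2. In practice scores = U_hat @ V_hat.T.
scores = np.array([[0.9, 0.1, 0.8, 0.2],
                   [0.1, 0.9, 0.2, 0.8]])
relevant = np.array([[1, 0, 0, 1],
                     [0, 1, 0, 0]])
print(item_recall_at_k(scores, relevant, k=2))  # 0.75
```

User 0 recovers one of two relevant items (0.5) and user 1 recovers its single relevant item (1.0), giving a mean of 0.75.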
The system achieves state-of-the-art cold-start Recall@100 on CiteULike and outperforms both the WMF baseline and DropoutNet on AnimeULike. The content basket approach enables competitive cold-start performance without exposing individual viewing history.
Tech Stack
Python (3.8+), PyTorch (1.8.1), PyTorch Geometric (1.7.0 — GINEConv, GCNConv), HuggingFace Transformers (4.5.1 — BERT, DistilBERT), scikit-learn (TF-IDF, MinMaxScaler), Pandas, NumPy, SciPy, BeautifulSoup (web scraping), Matplotlib, NetworkX, Google Colab