本章详细介绍了图神经网络(GNN)在节点级、边级和图级三大任务中的应用,并结合具体案例进行讲解。核心概念包括节点分类、链接预测和图分类。节点分类任务通过半监督学习,利用部分节点标签预测其他节点标签,常用于引文网络和社交网络用户画像。链接预测则关注预测两个节点之间是否存在边,广泛应用于推荐系统和知识图谱补全。图分类任务对整图进行分类预测,常用于分子属性预测和蛋白质功能预测。读者将学习到如何选择合适的GNN模型和评估指标,并掌握PyTorch Geometric和DGL等常用工具库的使用。完成本章后,读者能够应用GNN解决实际问题,如设计推荐系统、进行药物发现和交通预测等。
GNN 应用:节点 / 边 / 图三大任务
GNN 落地主要有三类任务: 节点级 (每个节点一个预测), 边级 (每条边一个预测), 图级 (整图一个预测)。
这一章拆解三大任务和实战案例。
1. 节点级任务 (Node-level)
1.1 节点分类 (Node Classification)
最经典的任务, 已知部分节点标签, 预测其他节点:
- Cora 引文网络 (2708 论文, 7 类)
- Reddit 帖子分类 (232K 帖子, 41 类)
- 社交网络用户画像 (用户分群)
训练方式: 半监督, 只需要部分标签, 消息传递会传播到全图。
评估: 准确率 / Macro-F1。
from torch_geometric.nn import GATConv
class NodeClassifier(torch.nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.gat1 = GATConv(in_dim, hidden_dim, heads=8)
self.gat2 = GATConv(hidden_dim * 8, out_dim, heads=1)
def forward(self, x, edge_index):
x = self.gat1(x, edge_index).relu()
x = self.gat2(x, edge_index)
return x
1.2 节点回归 (Node Regression)
预测节点的连续值:
- 分子原子能量 (DFT 计算昂贵, GNN 预测)
- 城市交通流量预测
- 蛋白质残基属性
1.3 节点聚类 (Node Clustering)
把节点分群, 无监督:
- 社区检测 (Louvain / 谱聚类)
- 用 GNN embedding + KMeans
from torch_geometric.nn import GCNConv
from sklearn.cluster import KMeans
# 训 GCN, 拿 embedding
model = GCN(in_dim, hidden_dim, out_dim)
# ... 训练 ...
with torch.no_grad():
embeddings = model.conv1(data.x, data.edge_index).cpu().numpy()
# KMeans 聚类
clusters = KMeans(n_clusters=10).fit_predict(embeddings)
2. 边级任务 (Edge-level)
2.1 链接预测 (Link Prediction)
预测两个节点之间是否会有边, 推荐系统核心:
- 好友推荐 (预测用户 u 和 v 是否会成为好友)
- 物品推荐 (预测用户 u 是否会喜欢物品 i)
- 知识图谱补全 (预测三元组 (h, r, ?) 缺什么)
核心思路: 学一个评分函数 f(h, r, t), 给真实三元组高分, 假三元组低分。
from torch_geometric.nn import GCNConv
# 用 GCN 学节点 embedding
node_emb = model(data.x, data.edge_index)
# 边存在性预测 (u, v): 拼接 u 和 v embedding, 过 MLP
edge_pred = torch.sigmoid(MLP(torch.cat([node_emb[u], node_emb[v]], dim=1)))
2.2 三种评分函数
DistMult (最简单):
f(h, r, t) = <h, r, t> = sum_i h_i * r_i * t_i
TransE (平移不变, 几何直观):
f(h, r, t) = -||h + r - t||_2
RotatE (复数空间旋转, 强大):
t = h ∘ r (复数乘法 = 旋转)
f = -||t - t'||
2.3 负采样
只训练正样本会过拟合, 要负采样生成负例:
- 随机负采样: 随机替换头或尾
- 难负采样 (Hard Negative): 用相似但错的样本 (e.g. 同一类型但不同实体)
- 对比学习: InfoNCE 损失
# 随机负采样
def negative_sample(edge_index, num_nodes):
neg_u = torch.randint(0, num_nodes, (edge_index.size(1),))
neg_v = torch.randint(0, num_nodes, (edge_index.size(1),))
return torch.stack([neg_u, neg_v])
2.4 实战:好友推荐
import torch
from torch_geometric.nn import GCNConv
class FriendRecommender(torch.nn.Module):
def __init__(self, in_dim, hidden_dim):
super().__init__()
self.conv1 = GCNConv(in_dim, hidden_dim)
self.conv2 = GCNConv(hidden_dim, hidden_dim)
def encode(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
return self.conv2(x, edge_index)
def predict(self, z, edge_label_index):
# 用余弦相似度
u = z[edge_label_index[0]]
v = z[edge_label_index[1]]
return (u * v).sum(dim=-1)
3. 图级任务 (Graph-level)
3.1 图分类 (Graph Classification)
整图一个预测, 每个图独立:
- 分子属性预测 (毒理 / 溶解度)
- 蛋白质功能预测
- 社交网络异常检测
关键: 需要Readout 把节点 embedding 聚合成图 embedding:
from torch_geometric.nn import global_mean_pool, global_max_pool
# 训 GCN 拿节点 embedding
node_emb = gcn(data.x, data.edge_index)
# Readout: 整图 batch 所有节点 pool 到一个向量
graph_emb = global_mean_pool(node_emb, data.batch) # (B, d)
# 分类
logits = classifier(graph_emb)
Readout 方式:
global_mean_pool: 所有节点取平均global_max_pool: 所有节点取最大global_add_pool: 所有节点求和- Set2Set: 用 LSTM 聚合
- SortPool: 排序后取前 K 个
3.2 图回归 (Graph Regression)
整图预测连续值, 如分子能量:
# 同样 readout + 回归
graph_emb = global_mean_pool(node_emb, data.batch)
energy = regressor(graph_emb) # 标量
3.3 图生成 (Graph Generation)
生成新图, 分子设计 / 药物发现:
- GraphVAE: VAE 变体, 生成节点 / 边
- GraphGAN: GAN 思路
- Diffusion: 图扩散模型, 当前 SOTA
- MolGPT: Transformer 自回归生成 SMILES
# 用 RDKit + 预训练 GNN 生成新分子
from torch_geometric.nn import GINConv
from rdkit import Chem
generator = MolecularGenerator(node_dim=..., edge_dim=...)
new_smiles = generator.sample(n=1000)
valid = [s for s in new_smiles if Chem.MolFromSmiles(s) is not None]
4. 三大任务对比
| 任务 | 输入 | 输出 | 例子 |
|---|---|---|---|
| 节点分类 | 1 个节点 (在全图) | 标签 | 引文网络 |
| 链接预测 | 2 个节点 | 是否有边 | 好友推荐 |
| 图分类 | 整图 | 标签 | 分子属性 |
5. 工具与库
PyTorch Geometric (PyG)
学术最常用:
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_add_pool
dataset = TUDataset(root='data/MUTAG', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
class GraphClassifier(torch.nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.conv1 = GINConv(MLP(in_dim, hidden_dim))
self.conv2 = GINConv(MLP(hidden_dim, hidden_dim))
self.classifier = Linear(hidden_dim, out_dim)
def forward(self, x, edge_index, batch):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index).relu()
x = global_add_pool(x, batch)
return self.classifier(x)
DGL
工业界更常用, 多 backend (PyTorch / MXNet / TF):
import dgl
import dgl.nn as dglnn
g = dgl.graph(([0, 1, 2], [1, 2, 3]))
conv = dglnn.GraphConv(in_dim, out_dim)
h = conv(g, h)
6. 工业案例
6.1 推荐系统 (Pinterest / Uber Eats)
- 把用户-物品交互当二部图
- PinSage (Pinterest 2018): GraphSAGE 变体, 亿级物品图
- LightGCN (He et al. 2020): 简化 GCN, 去掉特征变换, 适合推荐
6.2 药物发现 (AlphaFold / ESM)
- 分子图预测毒理 / 溶解度
- 蛋白质图 (残基 + 接触) 预测功能
- AlphaFold 3: 用 GNN 预测蛋白质-DNA-小分子复合物结构
6.3 交通预测 (Google Maps / DiDi)
- 路网 = 图, 路口 = 节点, 路 = 边
- 预测每条路 5-30 分钟后的车速
- STGCN / DCRNN: 时空 GNN
6.4 金融风控 (蚂蚁 / PayPal)
- 用户-设备-账户构成异质图
- GNN 检测欺诈团伙 / 套现路径
7. 评估指标
| 任务 | 指标 |
|---|---|
| 节点分类 | Accuracy / Macro-F1 |
| 链接预测 | AUC / Hits@K / MRR |
| 图分类 | Accuracy / ROC-AUC |
| 图回归 | MAE / RMSE |
8. 数据集速查
- Cora / CiteSeer / Pubmed: 节点分类 baseline
- OGB (Open Graph Benchmark): 大规模标准 benchmark
- MUTAG / PROTEINS: 图分类小数据
- ZINC: 分子图, 15K-250K 分子
- OAG: 学术图谱, 异质图
总结
GNN 应用 = 节点 / 边 / 图 三大任务, 每个都有标准 baseline 和数据集。
最后一章Graph Transformer, 看 GNN 与 Transformer 结合的新方向。
章末小测验
检验你对《GNN 应用:节点 / 边 / 图三大任务》的掌握程度。
链接预测的核心目标是?
图分类任务中 Readout 操作用来?
TransE 的核心思想是?