转摘PyG利用GCN实现Cora、Citeseer、Pubmed引用论文节点分类
文章目录
- 前言
- 一、导入相关库
- 二、Cora、Citeseer、Pubmed数据集
- 三、定义配置类
- 四、定义工具类
- 五、加载数据集
- 六、定义GCN网络
- 七、定义模型
- 八、模型训练
- 九、模型验证
- 八、结果
- 完整代码
前言
大家好,我是阿光。
本专栏内包含基于GNN的项目实战案例(PyG实现),以及研究多年遇到的问题和一些总结与注意事项,理论与实践相结合,每一个代码实例都附带有完整的代码。
正在更新中~ ✨

🚨 我的项目环境:
- 平台:Windows10
- 语言环境:python3.7
- 编译器:PyCharm
- PyTorch版本:1.11.0
- PyG版本:2.1.0
💥 项目专栏:[【GNN图神经网络项目实战案例目录】](https://blog.csdn.net/m0_47256162/article/details/128961064)
本文我们将使用Pytorch + Pytorch Geometric来复现论文 [《SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS》](https://arxiv.org/pdf/1609.02907.pdf)
的实验部分,采用GCN实现Cora、Citeseer、Pubmed引用论文节点分类。
一、导入相关库
本项目我们需要结合两个库,一个是Pytorch,因为还需要按照torch的网络搭建模型进行书写,第二个是PyG,因为在torch中并没有关于图网络层的定义,所以需要torch_geometric这个库来定义一些图层。
prism language-python
import torch
import torch.nn as nn
from torch_geometric.datasets import Planetoid
from utils import *
from config import Config
from model import GCN
二、Cora、Citeseer、Pubmed数据集
由于本文要实现的是基于GCN实现Cora、Citeseer、Pubmed引用论文节点分类,所以需要获得对应的图数据,对于这三种图数据在 PyG
中已提供接口进行调用。
PyG中Planetoid数据集:

参数列表:
- root:加载数据集的路径,如果本地没有会自动下载
- name:导入的数据集名称,可选
Cora
、Citeseer
、Pubmed
- transform:对
torch_geometric.data.Data
图数据进行转换操作
对于 Cora
、Citeseer
、Pubmed
这三个数据集是比较经典的引用论文图数据集,常用于图的节点分类任务,图中的节点代表一篇论文,边关系代表不同论文之间的引用情况,节点特征是基于文档词袋获得的。
Dataset | Type | Nodes | Edges | Features | Classes |
---|---|---|---|---|---|
Cora | Citation Network | 2708 | 5429 | 1433 | 7 |
Citeseer | Citation Network | 3327 | 4732 | 3703 | 6 |
Pubmed | Citation Network | 19717 | 44338 | 500 | 3 |
在调用PyG接口调用图数据时可能由于网络原因下载失败,可以百度下载这些图数据,然后将其保存到本地,然后再进行导入即可。
三、定义配置类
由于我们在搭建模型以及训练时有很多超参数以及常量,为了编程规范,这里定义了一个配置类,将所有参数保存在这个类中,然后将这个类实例化,调用其中的参数,代码如下:
prism language-python
"""
* Created with PyCharm
* 作者: 阿光
* 日期: 2023/2/10
* 时间: 18:11
* 描述: 配置类
"""
import torch
class Config:
dataset_path = './data' # 数据集保存路径
dataset_name = 'Citeseer' # 数据集名称,可选Cora、Citeseer、Pubmed
epochs = 200 # 训练轮数
lr = 0.01 # 学习率
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
这里都是一些较为常见的参数,小伙伴可以根据自己的编程习惯以及项目需求去定义其它参数。
四、定义工具类
为了能够持久化模型训练过程中的精度、损失等指标,我们可以自己实现一个函数,在模型训练过程中将训练日志写入到本地,日志文件如下:

代码实现如下:
prism language-python
"""
* Created with PyCharm
* 作者: 阿光
* 日期: 2023/2/10
* 时间: 18:04
* 描述: 工具类
"""
import datetime
# 打开日志文件
def open_log_file(file_name=None):
file = open('./results/' + file_name, 'w', encoding='utf-8')
return file
# 关闭日志文件
def close_log_file(file=None):
file.close()
# 打印日志信息
def log(msg='', file=None, print_msg=True, end='\n'):
if print_msg:
print(msg) # 控制台打印信息
now = datetime.datetime.now() # 获取当前时间
t = str(now.year) + '/' + str(now.month) + '/' + str(now.day) + ' ' \
+ str(now.hour).zfill(2) + ':' + str(now.minute).zfill(2) + ':' + str(now.second).zfill(2)
if isinstance(msg, str):
lines = msg.split('\n')
else:
lines = [msg]
for line in lines:
if line == lines[-1]:
file.write('[' + t + '] ' + str(line) + end)
else:
file.write('[' + t + '] ' + str(line))
五、加载数据集
该环节我们可以利用 PyG
中提供的接口 Planetoid
来调用Cora、Citeseer、Pubmed三个图数据集,关于该函数的介绍,前面已经说明。
prism language-python
# 2.加载图数据集
dataset = Planetoid(root=args.dataset_path, name=args.dataset_name)
num_node_features = dataset.num_node_features # 每个节点的特征数
num_classes = dataset.num_classes # 每个节点的类别数
data = dataset[0].to(args.device) # Cora的一张图
六、定义GCN网络
这里我们就不重点介绍GCN网络了,相信大家能够掌握基本原理,本文我们使用的是PyG定义网络层,在PyG中已经定义好了GCNConv这个层,该层采用的就是GCN机制。

对于GCNConv的常用参数:
- in_channels:每个样本的输入维度,就是每个节点的特征维度
- out_channels:经过
GCNConv
后映射成的新的维度,就是经过GCNConv
后每个节点的维度长度 - normalize:是否添加自环,并且是否归一化,默认为True
- add_self_loops:为图添加自环,是否考虑自身节点的信息
- bias:训练一个偏置b
prism language-python
"""
* Created with PyCharm
* 作者: 阿光
* 日期: 2023/2/10
* 时间: 18:07
* 描述: 定义GCN模型
"""
from torch import nn
import torch_geometric.nn as pyg_nn
import torch.nn.functional as F
# 定义GCNConv网络
class GCN(nn.Module):
def __init__(self, num_node_features, num_classes):
super(GCN, self).__init__()
self.conv1 = pyg_nn.GCNConv(num_node_features, 16)
self.conv2 = pyg_nn.GCNConv(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index) # [num_nodes, 16]
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index) # [num_nodes, num_classes]
return F.log_softmax(x, dim=1) # [num_nodes, num_classes]
上面网络我们定义了两个 GCNConv
层,第一层的参数的输入维度就是初始每个节点的特征维度,输出维度是16。
第二个层的输入维度为16,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。
原论文中给出的模型结构图如下:

Z = f ( X , A ) = s o f t m a x ( A ^ R e L U ( A ^ X W ( 0 ) ) W ( 1 ) ) Z=f(X,A)=softmax(\hat A ReLU(\hat A XW^{(0)})W^{(1)}) Z=f(X,A)=softmax(A^ReLU(A^XW(0))W(1))
该模型结构很简单就是利用两层的GCN机制,然后将第二层的输出结果送入到softmax中进行分类。
七、定义模型
下面就是定义了一些模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。
prism language-python
# 3.定义模型
model = GCN(num_node_features, num_classes).to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) # 优化器
loss_function = nn.NLLLoss() # 损失函数
八、模型训练
模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。
prism language-python
# 训练模式
model.train()
for epoch in range(args.epochs):
optimizer.zero_grad()
pred = model(data)
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目
acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度
loss_train.backward()
optimizer.step()
model.eval()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask])
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
log("【EPOCH: {:3d}/{:d}】".format(epoch + 1, args.epochs) + '训练损失为: 【{:.4f}】'.format(
loss_train.item()) + ' 训练精度为: 【{:.4f}】'.format(
acc_train) + ' 测试集损失为: 【{:.4f}】'.format(loss_test.item()) + ' 测试集精度为: 【{:.4f}】'.format(acc_test),
file,
False)
if epoch % 20 == 0:
print("【EPOCH: 】%s" % str(epoch + 1))
print('训练集损失为: 【{:.4f}】'.format(loss_train.item()), '训练集精度为: 【{:.4f}】'.format(acc_train),
'测试集损失为: 【{:.4f}】'.format(loss_test.item()), '测试集精度为: 【{:.4f}】'.format(acc_test))
log('【Finished Training!!!】', file)
九、模型验证
下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集。
prism language-python
# 模型验证
model.eval()
pred = model(data)
# 训练集(使用了掩码)
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()
acc_train = correct_count_train / data.train_mask.sum().item()
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()
# 测试集
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()
log('Train Accuracy: 【{:.4f}】'.format(acc_train) + ' Train Loss: 【{:.4f}】'.format(loss_train), file)
log('Test Accuracy: 【{:.4f}】'.format(acc_test) + ' Test Loss: 【{:.4f}】'.format(loss_test), file)
# 关闭日志文件
close_log_file(file)
八、结果
下面为模型训练过程中的一些指标数据以及最终基于三个数据集的训练结果
prism language-python
【EPOCH: 】1
训练集损失为: 【1.7934】 训练集精度为: 【0.1833】 测试集损失为: 【1.7926】 测试集精度为: 【0.1570】
【EPOCH: 】21
训练集损失为: 【0.0293】 训练集精度为: 【1.0000】 测试集损失为: 【1.2539】 测试集精度为: 【0.6230】
【EPOCH: 】41
训练集损失为: 【0.0023】 训练集精度为: 【1.0000】 测试集损失为: 【1.5319】 测试集精度为: 【0.6240】
【EPOCH: 】61
训练集损失为: 【0.0010】 训练集精度为: 【1.0000】 测试集损失为: 【1.5962】 测试集精度为: 【0.6320】
【EPOCH: 】81
训练集损失为: 【0.0007】 训练集精度为: 【1.0000】 测试集损失为: 【1.6156】 测试集精度为: 【0.6350】
【EPOCH: 】101
训练集损失为: 【0.0006】 训练集精度为: 【1.0000】 测试集损失为: 【1.6325】 测试集精度为: 【0.6320】
【EPOCH: 】121
训练集损失为: 【0.0005】 训练集精度为: 【1.0000】 测试集损失为: 【1.6484】 测试集精度为: 【0.6340】
【EPOCH: 】141
训练集损失为: 【0.0004】 训练集精度为: 【1.0000】 测试集损失为: 【1.6633】 测试集精度为: 【0.6350】
【EPOCH: 】161
训练集损失为: 【0.0004】 训练集精度为: 【1.0000】 测试集损失为: 【1.6774】 测试集精度为: 【0.6340】
【EPOCH: 】181
训练集损失为: 【0.0003】 训练集精度为: 【1.0000】 测试集损失为: 【1.6907】 测试集精度为: 【0.6330】
【Finished Training!!!】
>>>Train Accuracy: 【1.0000】 Train Loss: 【0.0003】
>>>Test Accuracy: 【0.6300】 Test Loss: 【1.7034】
Citeseer | Cora | Pubmed | |
---|---|---|---|
Accuracy | 0.6300 | 0.7870 | 0.7670 |
Loss | 1.7034 | 0.7819 | 0.7245 |
上表中的统计指标都是基于测试数据集的
完整代码
1️⃣ 工具类
prism language-python
"""
* Created with PyCharm
* 作者: 阿光
* 日期: 2023/2/10
* 时间: 18:04
* 描述: 工具类
"""
import datetime
# 打开日志文件
def open_log_file(file_name=None):
file = open('./results/' + file_name, 'w', encoding='utf-8')
return file
# 关闭日志文件
def close_log_file(file=None):
file.close()
# 打印日志信息
def log(msg='', file=None, print_msg=True, end='\n'):
if print_msg:
print(msg) # 控制台打印信息
now = datetime.datetime.now() # 获取当前时间
t = str(now.year) + '/' + str(now.month) + '/' + str(now.day) + ' ' \
+ str(now.hour).zfill(2) + ':' + str(now.minute).zfill(2) + ':' + str(now.second).zfill(2)
if isinstance(msg, str):
lines = msg.split('\n')
else:
lines = [msg]
for line in lines:
if line == lines[-1]:
file.write('[' + t + '] ' + str(line) + end)
else:
file.write('[' + t + '] ' + str(line))
2️⃣ 配置类
prism language-python
"""
* Created with PyCharm
* 作者: 阿光
* 日期: 2023/2/10
* 时间: 18:11
* 描述: 配置类
"""
import torch
class Config:
dataset_path = './data' # 数据集保存路径
dataset_name = 'Citeseer' # 数据集名称,可选Cora、Citeseer、Pubmed
epochs = 200 # 训练轮数
lr = 0.01 # 学习率
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
3️⃣ 搭建模型
prism language-python
"""
* Created with PyCharm
* 作者: 阿光
* 日期: 2023/2/10
* 时间: 18:07
* 描述: 定义GCN模型
"""
from torch import nn
import torch_geometric.nn as pyg_nn
import torch.nn.functional as F
# 定义GCNConv网络
class GCN(nn.Module):
def __init__(self, num_node_features, num_classes):
super(GCN, self).__init__()
self.conv1 = pyg_nn.GCNConv(num_node_features, 16)
self.conv2 = pyg_nn.GCNConv(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index) # [num_nodes, 16]
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index) # [num_nodes, num_classes]
return F.log_softmax(x, dim=1) # [num_nodes, num_classes]
4️⃣ 模型训练
prism language-python
"""
* Created with PyCharm
* 作者: 阿光
* 日期: 2023/2/10
* 时间: 18:07
* 描述: 使用GCN对Citeseer、Cora、Pubmed进行节点分类
"""
import torch
import torch.nn as nn
from torch_geometric.datasets import Planetoid
from utils import *
from config import Config
from model import GCN
args = Config()
# 1.打开日志文件
file = open_log_file(args.dataset_name)
# 2.加载图数据集
dataset = Planetoid(root=args.dataset_path, name=args.dataset_name)
num_node_features = dataset.num_node_features # 每个节点的特征数
num_classes = dataset.num_classes # 每个节点的类别数
data = dataset[0].to(args.device) # Cora的一张图
# 3.定义模型
model = GCN(num_node_features, num_classes).to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) # 优化器
loss_function = nn.NLLLoss() # 损失函数
# 训练模式
model.train()
for epoch in range(args.epochs):
optimizer.zero_grad()
pred = model(data)
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目
acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度
loss_train.backward()
optimizer.step()
model.eval()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask])
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
log("【EPOCH: {:3d}/{:d}】".format(epoch + 1, args.epochs) + '训练损失为: 【{:.4f}】'.format(
loss_train.item()) + ' 训练精度为: 【{:.4f}】'.format(
acc_train) + ' 测试集损失为: 【{:.4f}】'.format(loss_test.item()) + ' 测试集精度为: 【{:.4f}】'.format(acc_test),
file,
False)
if epoch % 20 == 0:
print("【EPOCH: 】%s" % str(epoch + 1))
print('训练集损失为: 【{:.4f}】'.format(loss_train.item()), '训练集精度为: 【{:.4f}】'.format(acc_train),
'测试集损失为: 【{:.4f}】'.format(loss_test.item()), '测试集精度为: 【{:.4f}】'.format(acc_test))
log('【Finished Training!!!】', file)
# 模型验证
model.eval()
pred = model(data)
# 训练集(使用了掩码)
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()
acc_train = correct_count_train / data.train_mask.sum().item()
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()
# 测试集
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()
log('Train Accuracy: 【{:.4f}】'.format(acc_train) + ' Train Loss: 【{:.4f}】'.format(loss_train), file)
log('Test Accuracy: 【{:.4f}】'.format(acc_test) + ' Test Loss: 【{:.4f}】'.format(loss_test), file)
# 关闭日志文件
close_log_file(file)
===========================
【来源: CSDN】
【作者: 海洋.之心】
【原文链接】 https://weibaohang.blog.csdn.net/article/details/128977236
声明:转载此文是出于传递更多信息之目的。若有来源标注错误或侵犯了您的合法权益,请作者持权属证明与本网联系,我们将及时更正、删除,谢谢。