溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

怎么使用Pytorch+PyG實現GraphSAGE

發布時間:2023-04-21 17:13:55 來源:億速云 閱讀:146 作者:iii 欄目:開發技術

這篇文章主要講解了“怎么使用Pytorch+PyG實現GraphSAGE”,文中的講解內容簡單清晰,易于學習與理解,下面請大家跟著小編的思路慢慢深入,一起來研究和學習“怎么使用Pytorch+PyG實現GraphSAGE”吧!

GraphSAGE簡介

GraphSAGE(Graph Sampling and Aggregation)是一種常見的圖神經網絡模型,主要用于結點級別的表征學習。該模型基于采樣和聚合策略,將一個結點及其鄰居節點信息融合在一起,得到其表征表示,并通過多輪迭代更新來提高表征的精度。

實現步驟

數據準備

在本次實現中,我們仍然使用Cora數據集作為示例進行測試,由于GraphSage主要聚焦于單一節點特征的更新,因此這里不需要對數據集做特別處理,只需要將數據轉化成PyG格式即可。

import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx, to_networkx
# 加載cora數據集
dataset = Planetoid(root='./cora', name='Cora')
data = dataset[0]
# 將nx.Graph形式的圖轉換成PyG需要的格式
graph = to_networkx(data)
data = from_networkx(graph)
# 獲取節點數量和特征向量維度
num_nodes = data.num_nodes
num_features = dataset.num_features
num_classes = dataset.num_classes
# 建立需要訓練的節點分割數據集
data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.val_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[:num_nodes - 1000] = True
data.test_mask[-1000:] = True
data.val_mask[num_nodes - 2000: num_nodes - 1000] = True

實現模型

接下來,我們需要定義GraphSAGE模型。與傳統的GCN中只需要一層卷積操作不同,GraphSAGE包含兩層卷積和采樣(也稱“聚合”)操作。

from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers):
        super(GraphSAGE, self).__init__()
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_channels = hidden_channels if i != 0 else num_features
            out_channels = num_classes if i == num_layers - 1 else hidden_channels
            self.convs.append(SAGEConv(in_channels, out_channels))
    def forward(self, x, edge_index):
        for _, conv in enumerate(self.convs[:-1]):
            x = F.relu(conv(x, edge_index))
        # 最后一層不用激活函數
        x = self.convs[-1](x, edge_index)
        return F.log_softmax(x, dim=-1)

在上述代碼中,我們實現了多層GraphSAGE卷積和相應的聚合函數,并使用ReLU和softmax函數來進行特征提取和分類分數的輸出。

模型訓練

定義好模型之后,就可以開始針對Cora數據集進行模型訓練。首先還是需要先指定優化器和損失函數,并設定一些參數用于記錄訓練過程中的信息,如Epochs、Batch size、學習率等。

# 初始化GraphSage并指定參數
num_layers = 2
hidden_channels = 256
model = GraphSAGE(hidden_channels, num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = nn.CrossEntropyLoss()
# 訓練過程
for epoch in range(500):
    model.train()
    optimizer.zero_grad()
    out = model(data.x.to(device), data.edge_index.to(device))
    loss = loss_func(out[data.train_mask], data.y.to(device)[data.train_mask])
    loss.backward()
    optimizer.step()
    # 在各個測試階段檢測一下準確率
    if epoch % 10 == 0:
        with torch.no_grad():
            _, pred = model(data.x.to(device), data.edge_index.to(device)).max(dim=1)
            correct = float(pred[data.test_mask].eq(data.y.to(device)[data.test_mask]).sum().item())
            acc = correct / data.test_mask.sum().item()
            print("Epoch {:03d}, Train Loss {:.4f}, Test Acc {:.4f}".format(
                epoch, loss.item(), acc))

在上述代碼中,我們使用有標記的訓練數據擬合GraphSAGE模型,在各個驗證階段測試準確率,并通過梯度下降法優化損失函數。

感謝各位的閱讀,以上就是“怎么使用Pytorch+PyG實現GraphSAGE”的內容了,經過本文的學習后,相信大家對怎么使用Pytorch+PyG實現GraphSAGE這一問題有了更深刻的體會,具體使用情況還需要大家實踐驗證。這里是億速云,小編將為大家推送更多相關知識點的文章,歡迎關注!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

亚洲午夜精品一区二区_中文无码日韩欧免_久久香蕉精品视频_欧美主播一区二区三区美女