Model Training and Deployment Pipeline: A Simple Graph Neural Network as an Example

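Workflow

  1. Acquire graph data: obtain labeled data, either by paying for it or from free public sources.

  2. Preprocess the graph data:

    • Convert the raw data into a graph structure: node features, edge features, and connectivity.

    • Normalize or standardize the node and edge features.

  3. Build the graph neural network: define the layers and the model.

  4. Train the model: set the hyperparameters and the optimizer, and obtain the trained weights.

  5. Evaluate the model: measure the trained model on the test set, e.g. classification accuracy and F1 score.

  6. Optimize the model: save the model and apply optimizations such as quantization and pruning to reduce its size and speed up inference.

  7. Deploy and maintain the model: deploy it to production and monitor its runtime status and performance metrics there.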
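
Step 2 above can be sketched in a few lines of NumPy. This is only an illustration under assumptions: the toy edge list, the feature matrix, and the helper names build_edge_index and standardize are made up for the example and are not part of the pipeline used later, which loads a ready-made dataset.

```python
import numpy as np

def build_edge_index(edges, undirected=True):
    """Turn a list of (src, dst) pairs into a 2 x E integer array."""
    edge_index = np.asarray(edges, dtype=np.int64).T  # shape (2, E)
    if undirected:
        # Add the reverse of every edge, then drop duplicate columns.
        both = np.concatenate([edge_index, edge_index[::-1]], axis=1)
        edge_index = np.unique(both, axis=1)
    return edge_index

def standardize(features, eps=1e-12):
    """Shift each feature column to zero mean and scale it to unit variance."""
    mean = features.mean(axis=0, keepdims=True)
    std = features.std(axis=0, keepdims=True)
    return (features - mean) / (std + eps)

# Hypothetical toy graph: 3 nodes, 2 raw edges, 2 features per node.
raw_edges = [(0, 1), (1, 2)]
raw_features = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])

edge_index = build_edge_index(raw_edges)
x = standardize(raw_features)
print(edge_index.shape)
print(x)
```

The 2 x E layout matches the edge_index tensor that the layers below expect.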

Training

We use the GraphConv layer that we implemented ourselves earlier.

graph_conv.py

import torch
from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
torch.manual_seed(42)

class GraphConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.bias = torch.nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x: Tensor, edge_index: Tensor):
        # Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Start propagating messages.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, norm=norm)

    def message(self, x_j: Tensor, norm: Tensor) -> Tensor:
        # Normalize node features.
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out: Tensor):
        # Apply the linear layer, then add the extra bias after aggregation.
        biased = self.lin(aggr_out) + self.bias
        return biased
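
To make the normalization in forward and message concrete, the following NumPy sketch redoes the same arithmetic on a hypothetical 3-node path graph and checks it against the dense form D^{-1/2} (A + I) D^{-1/2} X. It leaves out the linear layer and the bias, and the graph and feature values are invented for illustration.

```python
import numpy as np

# Hypothetical undirected path graph 0 - 1 - 2, stored as directed edge pairs.
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
x = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [2.0, 2.0]])
num_nodes = x.shape[0]

# Add self-loops, as add_self_loops does.
loops = np.arange(num_nodes)
edge_index = np.concatenate([edge_index, np.stack([loops, loops])], axis=1)
row, col = edge_index

# Degree of each target node, then the per-edge coefficient 1 / sqrt(deg_row * deg_col).
deg = np.bincount(col, minlength=num_nodes).astype(x.dtype)
norm = deg[row] ** -0.5 * deg[col] ** -0.5

# Message: scale the source features. Aggregate: sum the messages at the target node.
messages = norm[:, None] * x[row]
out = np.zeros_like(x)
np.add.at(out, col, messages)

# Dense reference: D^{-1/2} (A + I) D^{-1/2} X.
adj = np.zeros((num_nodes, num_nodes))
adj[col, row] = 1.0
d_inv_sqrt = np.diag(deg ** -0.5)
dense = d_inv_sqrt @ adj @ d_inv_sqrt @ x

print(np.allclose(out, dense))
```

np.add.at plays the role of the 'add' aggregation: it accumulates every message into the row of its target node, even when several edges share the same target.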

train.py

This script loads the dataset, sets the hyperparameters, trains, evaluates, and saves the model.

import sys
sys.path.append('./')
from graph_conv import GraphConv

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.data import Data

class GCN(torch.nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = GraphConv(in_channels, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, out_channels)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

def train(model: GCN, data: Data, lr: float, epochs: int) -> None:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # evaluate() switches to eval mode, so re-enable training mode (dropout) each epoch.
        model.train()
        optimizer.zero_grad()
        output = model(data.x, data.edge_index)
        loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        acc = evaluate(model, data)
        print(f'Epoch: {epoch:03d}, Loss: {loss.item():.4f}, Test Accuracy: {acc:.4f}')

@torch.no_grad()
def evaluate(model: GCN, data: Data) -> float:
    # Accuracy on the test nodes, computed without tracking gradients.
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=1)
    correct = int((pred[data.test_mask] == data.y[data.test_mask]).sum())
    return correct / int(data.test_mask.sum())

def main():
    dataset_path = './dataset'
    dataset = Planetoid(root=dataset_path, name='Cora')
    data = dataset[0]

    hidden_dim = 16
    lr = 0.001
    epochs = 100

    model = GCN(dataset.num_features, hidden_dim, dataset.num_classes)

    train(model, data, lr, epochs)

    test_acc = evaluate(model, data)
    print(f'Test Accuracy: {test_acc:.4f}')

    torch.save(model.state_dict(), './gcn_cora.pth')

if __name__ == '__main__':
    main()
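
evaluate reports accuracy only. The F1 score mentioned in the workflow could be computed from the same predictions; here is a minimal NumPy sketch of macro-averaged F1. The function name macro_f1 and the toy labels are assumptions for illustration; in practice a library routine such as scikit-learn's f1_score does the same job.

```python
import numpy as np

def macro_f1(y_true, y_pred, num_classes):
    """Average the per-class F1 scores, treating each class as one-vs-rest."""
    scores = []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        # A class that never occurs and is never predicted contributes 0.
        scores.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(scores))

# Hypothetical labels and predictions for a 3-class problem.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
print(macro_f1(y_true, y_pred, num_classes=3))
```

With the model above, y_pred would be the argmax of the output on the test nodes and y_true the matching entries of data.y, both converted to NumPy arrays.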

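Inference model

The .pth file produced by training is more of a checkpoint than a model meant for inference: it holds only the weights (the state_dict), not the computation graph. It therefore still needs to be exported.

Exporting the inference model

In PyTorch, torch.onnx.export exports a model to the ONNX format, which is widely supported.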

$ pip3 install onnx

def main():
    ...
    export_onnx(model, data)

def export_onnx(model: GCN, data: Data) -> None:
    model.eval()
    x = data.x
    edge_index = data.edge_index
    torch.onnx.export(model, (x, edge_index), 'gcn_cora.onnx', input_names=['x', 'edge_index'], output_names=['output'], opset_version=11)

    print('ONNX model exported successfully')

Inspecting the model structure

The onnx library provides an API for exploring the model structure:

import onnx
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True)
    args = parser.parse_args()

    model = onnx.load(args.model_path)

    def dim_str(dims):
        return ', '.join(str(dim.dim_value) if dim.WhichOneof('value') == 'dim_value' else '?' for dim in dims)

    print("Inputs:")
    for input in model.graph.input:
        print(input.name, dim_str(input.type.tensor_type.shape.dim))
    print("Outputs:")
    for output in model.graph.output:
        print(output.name, dim_str(output.type.tensor_type.shape.dim))

    print("Nodes:")
    for node in model.graph.node:
        print(node.name, node.op_type, ":", ", ".join(node.input), "=>", ", ".join(node.output))

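The most convenient option, though, is the onnx.helper.printable_graph function, which renders the whole graph as text.

Netron, a model viewer that supports ONNX, shows the structure of an ONNX model directly in the browser. Visit: https://netron.app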

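The nodes in the graph are ops. Leaving aside the common ones, here is a brief introduction to the rest:

  1. Gather: picks entries from the input tensor along a given axis according to an index tensor. Typically used to extract part of a large tensor, such as selected samples from a batch.

    • Inputs: a data tensor and an index tensor.

    • Output: the sub-tensor selected by the indices.

  2. Expand: broadcasts the input tensor to a target shape, repeating values along dimensions of size 1. Typically used to grow a small tensor to match another tensor's shape for a later computation.

    • Inputs: a tensor and a target shape.

    • Output: the expanded tensor.

  3. Gemm: general matrix multiplication, Y = alpha * A * B + beta * C, with optional transposition of A and B. This is the usual linear transformation in neural networks.

    • Inputs: two matrices and an optional bias.

    • Output: the result of the matrix multiplication.

  4. Squeeze: removes dimensions of size 1 from the input tensor, either the ones listed in axes or all of them when axes is omitted. Typically used to drop unneeded dimensions before later computation.

    • Input: a tensor, with optional axes.

    • Output: the tensor without those size-1 dimensions.

  5. ScatterElements: returns a copy of the data tensor in which the positions given by an index tensor are overwritten with values from an update tensor. Typically used to update selected elements of a tensor; in this graph it appears in the degree computation and in the neighbor aggregation.

    • Inputs: a data tensor, an index tensor, and an update tensor.

    • Output: the updated data tensor.

  6. Tile: repeats the input tensor a given number of times along each dimension. It resembles Expand, but Tile repeats the whole tensor by an explicit count per axis, whereas Expand only broadcasts dimensions of size 1 to a target shape.

    • Inputs: a tensor and a tensor of repeat counts.

    • Output: the tiled tensor.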
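
For intuition, each of these ops has a close NumPy counterpart. The mapping below is an approximate analogy on made-up arrays, not the ONNX specification; in particular, Gemm is reduced to the transB = 1, alpha = beta = 1 form that appears in this model.

```python
import numpy as np

data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Gather(axis=0): pick rows by index.
gathered = np.take(data, np.array([2, 0]), axis=0)        # rows 2 and 0

# Expand: broadcast size-1 dimensions to a target shape.
expanded = np.broadcast_to(np.array([[1], [2]]), (2, 3))  # each value repeated across 3 columns

# Tile: repeat the whole tensor a given number of times per axis.
tiled = np.tile(np.array([[1, 2]]), (2, 2))               # shape (2, 4)

# Squeeze(axes=[0]): drop a size-1 dimension.
squeezed = np.squeeze(np.zeros((1, 5)), axis=0)           # shape (5,)

# ScatterElements(axis=0): start from the data, then overwrite the indexed positions.
scattered = np.zeros(4)
np.put_along_axis(scattered, np.array([1, 3]), np.array([7.0, 9.0]), axis=0)

# Gemm with transB=1, alpha=beta=1: Y = A @ B.T + C, i.e. a linear layer.
a = np.ones((3, 2))
w = np.array([[1.0, 2.0]])   # shape (out_features, in_features)
b = np.array([0.5])
y = a @ w.T + b              # shape (3, 1)

for name, value in [("gathered", gathered), ("expanded", expanded), ("tiled", tiled),
                    ("squeezed", squeezed), ("scattered", scattered), ("y", y)]:
    print(name, value.shape)
```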

Model compression

From the previous steps we know that the model has the following structure:

graph torch_jit (
  %x[FLOAT, 2708x1433]
  %edge_index[INT64, 2x10556]
) initializers (
  %conv1.bias[FLOAT, 16]
  %conv1.lin.weight[FLOAT, 16x1433]
  %conv1.lin.bias[FLOAT, 16]
  %conv2.bias[FLOAT, 7]
  %conv2.lin.weight[FLOAT, 7x16]
  %conv2.lin.bias[FLOAT, 7]
) {
  %/conv1/Constant_output_0 = Constant[value = <Scalar Tensor []>]()
  %/conv1/Constant_1_output_0 = Constant[value = <Tensor>]()
  %onnx::Tile_10 = Constant[value = <Tensor>]()
  %/conv1/Constant_2_output_0 = Constant[value = <Tensor>]()
  %/conv1/ConstantOfShape_output_0 = ConstantOfShape[value = <Tensor>](%/conv1/Constant_2_output_0)
  %/conv1/Expand_output_0 = Expand(%/conv1/Constant_1_output_0, %/conv1/ConstantOfShape_output_0)
  %/conv1/Tile_output_0 = Tile(%/conv1/Expand_output_0, %onnx::Tile_10)
  %/conv1/Constant_3_output_0 = Constant[value = <Scalar Tensor []>]()
  %/conv1/Concat_output_0 = Concat[axis = 1](%edge_index, %/conv1/Tile_output_0)
  %/conv1/Split_output_0, %/conv1/Split_output_1 = Split[axis = 0, split = [1, 1]](%/conv1/Concat_output_0)
  %/conv1/Squeeze_output_0 = Squeeze[axes = [0]](%/conv1/Split_output_0)
  %/conv1/Squeeze_1_output_0 = Squeeze[axes = [0]](%/conv1/Split_output_1)
  %/conv1/Constant_4_output_0 = Constant[value = <Tensor>]()
  %/conv1/Constant_5_output_0 = Constant[value = <Tensor>]()
  %/conv1/ScatterElements_output_0 = ScatterElements[axis = 0](%/conv1/Constant_5_output_0, %/conv1/Squeeze_1_output_0, %/conv1/Constant_4_output_0)
  %/conv1/Constant_6_output_0 = Constant[value = <Tensor>]()
  %/conv1/Add_output_0 = Add(%/conv1/Constant_6_output_0, %/conv1/ScatterElements_output_0)
  %/conv1/Constant_7_output_0 = Constant[value = <Scalar Tensor []>]()
  %/conv1/Pow_output_0 = Pow(%/conv1/Add_output_0, %/conv1/Constant_7_output_0)
  %/conv1/Gather_output_0 = Gather[axis = 0](%/conv1/Pow_output_0, %/conv1/Squeeze_output_0)
  %/conv1/Gather_1_output_0 = Gather[axis = 0](%/conv1/Pow_output_0, %/conv1/Squeeze_1_output_0)
  %/conv1/Mul_output_0 = Mul(%/conv1/Gather_output_0, %/conv1/Gather_1_output_0)
  %/conv1/Gather_2_output_0 = Gather[axis = 0](%/conv1/Concat_output_0, %/conv1/Constant_3_output_0)
  %/conv1/Gather_3_output_0 = Gather[axis = 0](%/conv1/Concat_output_0, %/conv1/Constant_output_0)
  %/conv1/Gather_4_output_0 = Gather[axis = -2](%x, %/conv1/Gather_3_output_0)
  %/conv1/Constant_8_output_0 = Constant[value = <Tensor>]()
  %/conv1/Reshape_output_0 = Reshape(%/conv1/Mul_output_0, %/conv1/Constant_8_output_0)
  %/conv1/Mul_1_output_0 = Mul(%/conv1/Reshape_output_0, %/conv1/Gather_4_output_0)
  %/conv1/aggr_module/Constant_output_0 = Constant[value = <Tensor>]()
  %/conv1/aggr_module/Reshape_output_0 = Reshape(%/conv1/Gather_2_output_0, %/conv1/aggr_module/Constant_output_0)
  %/conv1/aggr_module/Shape_output_0 = Shape(%/conv1/Mul_1_output_0)
  %/conv1/aggr_module/Expand_output_0 = Expand(%/conv1/aggr_module/Reshape_output_0, %/conv1/aggr_module/Shape_output_0)
  %/conv1/aggr_module/Constant_1_output_0 = Constant[value = <Tensor>]()
  %/conv1/aggr_module/ScatterElements_output_0 = ScatterElements[axis = 0](%/conv1/aggr_module/Constant_1_output_0, %/conv1/aggr_module/Expand_output_0, %/conv1/Mul_1_output_0)
  %/conv1/aggr_module/Constant_2_output_0 = Constant[value = <Tensor>]()
  %/conv1/aggr_module/Add_output_0 = Add(%/conv1/aggr_module/Constant_2_output_0, %/conv1/aggr_module/ScatterElements_output_0)
  %/conv1/lin/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transB = 1](%/conv1/aggr_module/Add_output_0, %conv1.lin.weight, %conv1.lin.bias)
  %/conv1/Add_1_output_0 = Add(%/conv1/lin/Gemm_output_0, %conv1.bias)
  %/Relu_output_0 = Relu(%/conv1/Add_1_output_0)
  %/conv2/Constant_output_0 = Constant[value = <Tensor>]()
  %/conv2/Constant_1_output_0 = Constant[value = <Tensor>]()
  %/conv2/ConstantOfShape_output_0 = ConstantOfShape[value = <Tensor>](%/conv2/Constant_1_output_0)
  %/conv2/Expand_output_0 = Expand(%/conv2/Constant_output_0, %/conv2/ConstantOfShape_output_0)
  %/conv2/Tile_output_0 = Tile(%/conv2/Expand_output_0, %onnx::Tile_10)
  %/conv2/Concat_output_0 = Concat[axis = 1](%edge_index, %/conv2/Tile_output_0)
  %/conv2/Split_output_0, %/conv2/Split_output_1 = Split[axis = 0, split = [1, 1]](%/conv2/Concat_output_0)
  %/conv2/Squeeze_output_0 = Squeeze[axes = [0]](%/conv2/Split_output_0)
  %/conv2/Squeeze_1_output_0 = Squeeze[axes = [0]](%/conv2/Split_output_1)
  %/conv2/Constant_2_output_0 = Constant[value = <Tensor>]()
  %/conv2/Constant_3_output_0 = Constant[value = <Tensor>]()
  %/conv2/ScatterElements_output_0 = ScatterElements[axis = 0](%/conv2/Constant_3_output_0, %/conv2/Squeeze_1_output_0, %/conv2/Constant_2_output_0)
  %/conv2/Constant_4_output_0 = Constant[value = <Tensor>]()
  %/conv2/Add_output_0 = Add(%/conv2/Constant_4_output_0, %/conv2/ScatterElements_output_0)
  %/conv2/Constant_5_output_0 = Constant[value = <Scalar Tensor []>]()
  %/conv2/Pow_output_0 = Pow(%/conv2/Add_output_0, %/conv2/Constant_5_output_0)
  %/conv2/Gather_output_0 = Gather[axis = 0](%/conv2/Pow_output_0, %/conv2/Squeeze_output_0)
  %/conv2/Gather_1_output_0 = Gather[axis = 0](%/conv2/Pow_output_0, %/conv2/Squeeze_1_output_0)
  %/conv2/Mul_output_0 = Mul(%/conv2/Gather_output_0, %/conv2/Gather_1_output_0)
  %/conv2/Gather_2_output_0 = Gather[axis = 0](%/conv2/Concat_output_0, %/conv1/Constant_3_output_0)
  %/conv2/Gather_3_output_0 = Gather[axis = 0](%/conv2/Concat_output_0, %/conv1/Constant_output_0)
  %/conv2/Gather_4_output_0 = Gather[axis = -2](%/Relu_output_0, %/conv2/Gather_3_output_0)
  %/conv2/Constant_6_output_0 = Constant[value = <Tensor>]()
  %/conv2/Reshape_output_0 = Reshape(%/conv2/Mul_output_0, %/conv2/Constant_6_output_0)
  %/conv2/Mul_1_output_0 = Mul(%/conv2/Reshape_output_0, %/conv2/Gather_4_output_0)
  %/conv2/aggr_module/Constant_output_0 = Constant[value = <Tensor>]()
  %/conv2/aggr_module/Reshape_output_0 = Reshape(%/conv2/Gather_2_output_0, %/conv2/aggr_module/Constant_output_0)
  %/conv2/aggr_module/Shape_output_0 = Shape(%/conv2/Mul_1_output_0)
  %/conv2/aggr_module/Expand_output_0 = Expand(%/conv2/aggr_module/Reshape_output_0, %/conv2/aggr_module/Shape_output_0)
  %/conv2/aggr_module/Constant_1_output_0 = Constant[value = <Tensor>]()
  %/conv2/aggr_module/ScatterElements_output_0 = ScatterElements[axis = 0](%/conv2/aggr_module/Constant_1_output_0, %/conv2/aggr_module/Expand_output_0, %/conv2/Mul_1_output_0)
  %/conv2/aggr_module/Constant_2_output_0 = Constant[value = <Tensor>]()
  %/conv2/aggr_module/Add_output_0 = Add(%/conv2/aggr_module/Constant_2_output_0, %/conv2/aggr_module/ScatterElements_output_0)
  %/conv2/lin/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transB = 1](%/conv2/aggr_module/Add_output_0, %conv2.lin.weight, %conv2.lin.bias)
  %/conv2/Add_1_output_0 = Add(%/conv2/lin/Gemm_output_0, %conv2.bias)
  %output = LogSoftmax[axis = 1](%/conv2/Add_1_output_0)
  return %output
}

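This is a two-layer graph convolutional network. Each layer consists of a graph convolution (the normalized neighbor aggregation) followed by a fully connected layer (the Gemm node), and the second layer feeds a LogSoftmax output layer.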

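Common compression techniques:

  1. Parameter pruning

    • Prune the weight matrices of the fully connected layers by removing weights with small absolute values.

    • Prune the weight vectors of the graph convolution layers.

  2. Structural pruning

    • Remove one layer's graph convolution and fully connected layer to reduce the model's depth.

  3. Quantization

    • Quantize weights and activations from FP32 to lower-bit fixed-point types such as INT8.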
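
As a concrete picture of the first technique, magnitude pruning zeroes the weights with the smallest absolute values. The sketch below does this for one weight matrix in NumPy. The function magnitude_prune, the sparsity level, and the random matrix (shaped like conv2.lin.weight) are illustrative assumptions, and a real workflow would typically fine-tune the model afterwards.

```python
import numpy as np

def magnitude_prune(weight, sparsity):
    """Return a copy of weight with the smallest-magnitude fraction set to zero."""
    k = int(weight.size * sparsity)          # number of weights to remove
    if k == 0:
        return weight.copy()
    # The k-th smallest absolute value is the pruning threshold.
    threshold = np.partition(np.abs(weight).ravel(), k - 1)[k - 1]
    mask = np.abs(weight) > threshold
    return np.where(mask, weight, 0.0)

rng = np.random.default_rng(0)
w = rng.normal(size=(7, 16))                 # same shape as conv2.lin.weight
pruned = magnitude_prune(w, sparsity=0.5)
print(f"zeroed fraction: {np.mean(pruned == 0):.2f}")
```

Zeroed weights shrink the model only when the storage format or the runtime exploits sparsity, which is one reason quantization is often the easier first step.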

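Below we apply quantization with ONNX Runtime's quantization tools. They are well packaged, so there is no need to know the internals; calling the library is enough: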

from __future__ import annotations
import argparse
import os
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.quantization.shape_inference import quant_pre_process
from torch_geometric.datasets import Planetoid

def evaluate_model(model: onnxruntime.InferenceSession, data) -> float:
    """
    Evaluate the accuracy of the given model on the test set.

    Args:
        model (InferenceSession): The ONNX model to be evaluated.
        data: The dataset containing features, edge indices, labels, and test mask.

    Returns:
        float: The accuracy of the model on the test set.
    """
    x, edge_index, y = data.x, data.edge_index, data.y
    # Convert the mask to NumPy so it can index the NumPy output below.
    test_mask = data.test_mask.cpu().numpy()
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name
    output = model.run([output_name], {input_name: x.cpu().numpy(), 'edge_index': edge_index.cpu().numpy()})[0]

    output_test = output[test_mask]
    pred = output_test.argmax(axis=1)
    correct = (pred == y.cpu().numpy()[test_mask]).sum()
    acc = float(correct) / int(test_mask.sum())
    return acc

def load_data(dataset_path: str):
    dataset = Planetoid(root=dataset_path, name='Cora')
    return dataset[0]

def main(model_path: str, dataset_path: str):
    """
    Main function to evaluate, preprocess, and quantize the ONNX model.

    Args:
        model_path (str): The path to the ONNX model file.
        dataset_path (str): The path to the dataset directory.
    """
    # Load data
    data = load_data(dataset_path)

    # Evaluate original model
    session = onnxruntime.InferenceSession(model_path)
    original_acc = evaluate_model(session, data)
    print(f"Original model accuracy on test set: {original_acc:.4f}")

    # Strip only the extension, so paths such as './gcn_cora.onnx' keep their directory.
    prefix = os.path.splitext(model_path)[0]
    # Preprocess model
    preprocessed_model_path = prefix + '_preprocessed.onnx'
    quant_pre_process(model_path, preprocessed_model_path)
    preprocessed_session = onnxruntime.InferenceSession(preprocessed_model_path)
    preprocessed_acc = evaluate_model(preprocessed_session, data)
    print(f"Preprocessed model accuracy on test set: {preprocessed_acc:.4f}")

    # Quantize model
    quantized_model_path = prefix + '_quantized.onnx'
    quantize_dynamic(preprocessed_model_path, quantized_model_path, weight_type=QuantType.QUInt8)
    quantized_session = onnxruntime.InferenceSession(quantized_model_path)
    quantized_acc = evaluate_model(quantized_session, data)
    print(f"Quantized model accuracy on test set: {quantized_acc:.4f}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate and quantize ONNX model.")
    parser.add_argument("--model_path", required=False, type=str, help="Path to the ONNX model file.", default="gcn_cora.onnx")
    parser.add_argument("--dataset_path", required=False, type=str, help="Path to the dataset directory.", default="./dataset")
    args = parser.parse_args()

    main(args.model_path, args.dataset_path)
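
What a quantizer does to a weight tensor can be illustrated with plain arithmetic. The NumPy sketch below applies an asymmetric 8-bit affine mapping and measures the round-trip error. It is a simplified illustration with made-up helper names and random weights, not ONNX Runtime's actual implementation.

```python
import numpy as np

def quantize_uint8(w):
    """Map float values to uint8 with a scale and zero point covering [min, max]."""
    lo, hi = min(float(w.min()), 0.0), max(float(w.max()), 0.0)  # the range must include 0
    scale = (hi - lo) / 255.0 or 1.0       # avoid a zero scale for all-zero input
    zero_point = int(round(-lo / scale))
    q = np.round(w.astype(np.float64) / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8), scale, zero_point

def dequantize(q, scale, zero_point):
    """Recover approximate float values from the quantized representation."""
    return ((q.astype(np.float32) - zero_point) * scale).astype(np.float32)

rng = np.random.default_rng(0)
w = rng.normal(scale=0.1, size=(16, 1433)).astype(np.float32)  # shaped like conv1.lin.weight

q, scale, zero_point = quantize_uint8(w)
w_hat = dequantize(q, scale, zero_point)

print("storage ratio:", q.nbytes / w.nbytes)
print("max abs error:", float(np.abs(w - w_hat).max()), "scale:", scale)
```

Each weight moves by at most about half a quantization step, which is why accuracy usually changes little while storage for the weights drops to a quarter.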

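Deployment and inference

Finally, the model can be deployed to specific hardware and used for inference. Taking ONNX Runtime (ORT) as an example: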

session = ort.InferenceSession("optimized_model.onnx", sess_options, providers=["CUDAExecutionProvider"])
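
The line above presumably relies on ort being onnxruntime imported under that alias, on sess_options being an ort.SessionOptions object, and on the CUDA provider being installed. A common pattern is to fall back to the CPU provider when it is not. The helper below is a hypothetical sketch of that selection logic, written as a pure function so it runs without onnxruntime; in real code the list of available providers would come from ort.get_available_providers(). The names x_np and edge_index_np in the comment are placeholders for NumPy input arrays.

```python
def pick_providers(available, preferred=("CUDAExecutionProvider", "CPUExecutionProvider")):
    """Keep the preferred execution providers that are actually available, in order."""
    chosen = [p for p in preferred if p in available]
    if not chosen:
        raise RuntimeError(f"none of {list(preferred)} is available: {list(available)}")
    return chosen

# Stand-in for ort.get_available_providers() on a CPU-only machine.
available = ["AzureExecutionProvider", "CPUExecutionProvider"]
print(pick_providers(available))

# Intended use (requires onnxruntime and a model file, so it is left as a comment):
#   import onnxruntime as ort
#   providers = pick_providers(ort.get_available_providers())
#   session = ort.InferenceSession("optimized_model.onnx", providers=providers)
#   output = session.run(["output"], {"x": x_np, "edge_index": edge_index_np})[0]
```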

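TensorRT, developed by NVIDIA, is another option. In addition, many chip companies in China have developed their own inference chips, which has also created a large number of jobs.

ONNX is quite general, but operator support is not uniform across frameworks and runtimes. Every company wants to define its own IR, which makes migration very difficult. MLIR, now part of the LLVM project, addresses this with dialects: each IR can be expressed as an MLIR dialect and converted step by step to lower-level dialects (this conversion is called lowering), and eventually translated to machine code for execution.

onnx/onnx-mlir implements the lowering of ONNX to MLIR.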

Appendix: DGL version

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from dgl.nn import GraphConv
from dgl import AddSelfLoop

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = F.dropout(h, training=self.training)
        h = self.conv2(g, h)
        return F.log_softmax(h, dim=1)

def evaluate(model, g, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

def main():
    # load and preprocess dataset
    transform = AddSelfLoop()
    data = CoraGraphDataset(transform=transform)
    g = data[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    in_feats = features.shape[1]
    n_classes = data.num_classes
    n_edges = g.num_edges()
    print("""----Data statistics------
      Edges %d
      Classes %d
      Train samples %d
      Val samples %d
      Test samples %d""" %
          (n_edges, n_classes,
              train_mask.int().sum().item(),
              val_mask.int().sum().item(),
              test_mask.int().sum().item()))

    # create GCN model
    model = GCN(in_feats, 16, n_classes)

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # training loop
    for epoch in range(100):
        model.train()
        # forward
        logits = model(g, features)
        loss = F.nll_loss(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accuracy on the validation nodes
        acc = evaluate(model, g, features, labels, val_mask)
        print("Epoch {:05d} | Loss {:.4f} | Val Accuracy {:.4f}"
              .format(epoch, loss.item(), acc))

    print()
    acc = evaluate(model, g, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))

if __name__ == '__main__':
    main()