diff --git a/gammagl/layers/conv/gat_conv.py b/gammagl/layers/conv/gat_conv.py index db2348a6..e98123e0 100644 --- a/gammagl/layers/conv/gat_conv.py +++ b/gammagl/layers/conv/gat_conv.py @@ -1,8 +1,7 @@ import tensorlayerx as tlx from gammagl.layers.conv import MessagePassing from gammagl.utils import segment_softmax - - +from gammagl.mpops import bspmm class GATConv(MessagePassing): @@ -79,10 +78,14 @@ def __init__(self, self.linear = tlx.layers.Linear(out_features=self.out_channels * self.heads, in_features=self.in_channels, b_init=None) + + init_weight = tlx.initializers.TruncatedNormal() + self.w = tlx.nn.Parameter( + init_weight((in_channels, self.out_channels * self.heads))) initor = tlx.initializers.TruncatedNormal() - self.att_src = self._get_weights("att_src", shape=(1, self.heads, self.out_channels), init=initor, order=True) - self.att_dst = self._get_weights("att_dst", shape=(1, self.heads, self.out_channels), init=initor, order=True) + self.att = tlx.nn.Parameter( + initor((1, self.heads, self.out_channels * 2))) self.leaky_relu = tlx.layers.LeakyReLU(negative_slope) self.dropout = tlx.layers.Dropout(self.dropout_rate) @@ -91,22 +94,23 @@ def __init__(self, self.bias = self._get_weights("bias", shape=(self.heads * self.out_channels,), init=initor) elif self.add_bias and not concat: self.bias = self._get_weights("bias", shape=(self.out_channels,), init=initor) - - def message(self, x, edge_index, edge_weight=None, num_nodes=None): + + def forward(self, x, edge_index, num_nodes=None): + x = tlx.matmul(x, self.w) + x = tlx.reshape(x, shape=(-1, self.heads, self.out_channels)) node_src = edge_index[0, :] node_dst = edge_index[1, :] - weight_src = tlx.gather(tlx.reduce_sum(x * self.att_src, -1), node_src) - weight_dst = tlx.gather(tlx.reduce_sum(x * self.att_dst, -1), node_dst) - weight = self.leaky_relu(weight_src + weight_dst) + feat_src = tlx.gather(x, node_src) + feat_dst = tlx.gather(x, node_dst) + feat = tlx.concat((feat_src, feat_dst), axis=-1) + feat = tlx.reshape(feat, shape=(-1, self.heads, self.out_channels * 2)) + e = tlx.reduce_sum(feat * self.att, axis = -1) - alpha = self.dropout(segment_softmax(weight, node_dst, num_nodes)) - x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1) - return x * edge_weight if edge_weight else x + e = self.leaky_relu(e) + alpha = self.dropout(segment_softmax(e, node_dst, num_nodes)) - - def forward(self, x, edge_index, num_nodes=None): - x = tlx.reshape(self.linear(x), shape=(-1, self.heads, self.out_channels)) - x = self.propagate(x, edge_index, num_nodes=num_nodes) + x = self.propagate(x, edge_index, num_nodes=num_nodes, edge_weight=alpha) + # x = bspmm(edge_index, weight=alpha, x=x, reduce='sum') if self.concat: x = tlx.reshape(x, (-1, self.heads * self.out_channels)) diff --git a/gammagl/mpops/mindspore.py b/gammagl/mpops/mindspore.py index 815369a6..24931538 100644 --- a/gammagl/mpops/mindspore.py +++ b/gammagl/mpops/mindspore.py @@ -60,3 +60,6 @@ def segment_max(x, segment_ids, num_segments=None): def gspmm(index, weight=None, x=None, reduce='sum'): pass + +def bspmm(index, weight=None, x=None, reduce='sum'): + pass diff --git a/gammagl/mpops/paddle.py b/gammagl/mpops/paddle.py index 25f28891..1fec648a 100644 --- a/gammagl/mpops/paddle.py +++ b/gammagl/mpops/paddle.py @@ -222,3 +222,6 @@ def _scatter(x, index, updates, overwrite=True): def gspmm(index, weight=None, x=None, reduce='sum'): pass + +def bspmm(index, weight=None, x=None, reduce='sum'): + pass diff --git a/gammagl/mpops/tensorflow.py b/gammagl/mpops/tensorflow.py index 574380d0..1b86b081 100644 --- a/gammagl/mpops/tensorflow.py +++ b/gammagl/mpops/tensorflow.py @@ -60,3 +60,6 @@ def segment_min(x, segment_ids, num_segments=None): def gspmm(index, weight=None, x=None, reduce='sum'): pass + +def bspmm(index, weight=None, x=None, reduce='sum'): + pass diff --git a/gammagl/mpops/torch.py b/gammagl/mpops/torch.py index 2f775f7b..b5a4524a 100644 --- a/gammagl/mpops/torch.py +++ b/gammagl/mpops/torch.py @@ -1,7 +1,7 @@ import torch use_ext = False try: - from .torch_ext._torch_ext import c_segment_sum, c_segment_mean, c_segment_max, c_spmm_sum, c_spmm_mean, c_spmm_max + from .torch_ext._torch_ext import c_segment_sum, c_segment_mean, c_segment_max, c_spmm_sum, c_spmm_mean, c_spmm_max, c_bspmm_sum use_ext = True except: pass @@ -297,3 +297,17 @@ def gspmm(index, weight=None, x=None, reduce='sum'): return c_spmm_max(index, weight, x) else: raise Exception("Unsupported reduce type, please choose from ['sum', 'mean', 'max'].") + + +def bspmm(index, weight=None, x=None, reduce='sum'): + if weight == None: + weight = torch.ones(size=(index.shape[1], ), dtype=torch.float32) + if reduce == 'sum': + return c_bspmm_sum(index, weight, x) + # elif reduce == 'mean': + # return c_spmm_mean(index, weight, x) + # elif reduce == 'max': + # return c_spmm_max(index, weight, x) + else: + # raise Exception("Unsupported reduce type, please choose from ['sum', 'mean', 'max'].") + raise Exception("Unsupported reduce type, please choose from ['sum'].") diff --git a/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.cpp b/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.cpp new file mode 100644 index 00000000..ddee82b6 --- /dev/null +++ b/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.cpp @@ -0,0 +1,102 @@ +#include "./bspmm_sum_cpu.h" +#include +#include "ATen/core/TensorBody.h" + +torch::Tensor bspmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x){ + if (!x.is_contiguous()) { + x = x.contiguous(); + } + if (!weight.is_contiguous()) { + weight = weight.contiguous(); + } + if (!index.is_contiguous()) { + index = index.contiguous(); + } + + int num_nodes = x.size(0); + int heads = x.size(1); + int out_channels = x.size(2); + + torch::Tensor out = torch::zeros_like(x, x.options()); + auto E = index.size(1); + // auto K = x.numel() / x.size(0); + + auto index_data = index.data_ptr(); + using scalar_t = float; + auto x_data = x.data_ptr(); + auto out_data = out.data_ptr(); + auto weight_data = weight.data_ptr(); + +#ifdef COMPILE_WITH_OMP +#pragma omp parallel for +#endif + for (auto e = 0; e < E; ++e) { + auto src = index_data[e]; + auto dst = index_data[e + E]; + + for (auto h = 0; h < heads; ++h){ + for (auto k = 0; k < out_channels; ++k){ +#ifdef COMPILE_WITH_OMP +#pragma omp atomic +#endif + out_data[dst * out_channels * heads + h * out_channels + k] += + weight_data[e * heads + h] * x_data[src * out_channels * heads + h * out_channels + k]; + } + } + } + return out; +} + +std::tuple bspmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x, torch::Tensor &grad) { + if (!grad.is_contiguous()) { + grad = grad.contiguous(); + } + if (!weight.is_contiguous()) { + weight = weight.contiguous(); + } + if (!index.is_contiguous()) { + index = index.contiguous(); + } + + int num_nodes = grad.size(0); + int heads = grad.size(1); + int out_channels = grad.size(2); + + torch::Tensor grad_x = torch::zeros_like(grad, grad.options()); + torch::Tensor grad_weight = torch::zeros_like(weight, weight.options()); + auto E = index.size(1); + // auto K = grad.numel() / grad.size(0); + + auto index_data = index.data_ptr(); + using scalar_t = float; + auto grad_data = grad.data_ptr(); + auto grad_x_data = grad_x.data_ptr(); + auto grad_weight_data = grad_weight.data_ptr(); + auto x_data = x.data_ptr(); + auto weight_data = weight.data_ptr(); + +// 计算反向传播的梯度 +#ifdef COMPILE_WITH_OMP +#pragma omp parallel for +#endif + for (auto e = 0; e < E; ++e) { + auto src = index_data[e]; + auto dst = index_data[e + E]; + + for (auto h = 0; h < heads; ++h){ + for (auto k = 0; k < out_channels; ++k){ +#ifdef COMPILE_WITH_OMP +#pragma omp atomic +#endif + grad_x_data[src * out_channels * heads + h * out_channels + k] += + weight_data[e * heads + h] * grad_data[dst * out_channels * heads + h * out_channels + k]; + + grad_weight_data[e * heads + h] += x_data[src * out_channels * heads + h * out_channels + k] * + grad_data[dst * out_channels * heads + h * out_channels + k]; + + } + } + } + // return {grad_x, grad_weight}; + return std::make_tuple(grad_x, grad_weight); +} \ No newline at end of file diff --git a/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.h b/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.h new file mode 100644 index 00000000..478b67b8 --- /dev/null +++ b/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.h @@ -0,0 +1,6 @@ +#include + +torch::Tensor bspmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight, + torch::Tensor &x); +std::tuple bspmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x, + torch::Tensor &grad); diff --git a/gammagl/mpops/torch_ext/include/gspmm.h b/gammagl/mpops/torch_ext/include/gspmm.h index 7b4c3cbb..10f332f9 100644 --- a/gammagl/mpops/torch_ext/include/gspmm.h +++ b/gammagl/mpops/torch_ext/include/gspmm.h @@ -24,3 +24,11 @@ class SpMMMax : public torch::autograd::Function { static std::vector backward(torch::autograd::AutogradContext *ctx, std::vector grad_outs); }; + +class BSpMMSum : public torch::autograd::Function { + public: + static torch::Tensor forward(torch::autograd::AutogradContext *ctx, torch::Tensor index, + torch::Tensor weight, torch::Tensor x); + static std::vector backward(torch::autograd::AutogradContext *ctx, + std::vector grad_outs); +}; diff --git a/gammagl/mpops/torch_ext/src/gspmm.cpp b/gammagl/mpops/torch_ext/src/gspmm.cpp index 9842860d..0840ac7e 100644 --- a/gammagl/mpops/torch_ext/src/gspmm.cpp +++ b/gammagl/mpops/torch_ext/src/gspmm.cpp @@ -9,6 +9,8 @@ #include "../cpu/spmm_sum_cpu.h" #include "../cpu/spmm_mean_cpu.h" #include "../cpu/spmm_max_cpu.h" +#include "../cpu/bspmm_sum_cpu.h" + #ifdef COMPILE_WITH_CUDA #include "../cuda/spmm_sum_cuda.h" #endif @@ -171,3 +173,53 @@ std::vector SpMMMax::backward(torch::autograd::AutogradContext *c return {torch::Tensor(), torch::Tensor(), grad_x}; } + + +torch::Tensor BSpMMSum::forward(torch::autograd::AutogradContext *ctx, torch::Tensor index, + torch::Tensor weight, torch::Tensor x) { + ctx->save_for_backward({index, weight, x}); + ctx->mark_non_differentiable({index, weight}); + torch::Tensor out; + // CUDA + if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) { + // #ifdef COMPILE_WITH_CUDA + // out = bspmm_sum_cuda_forward(index, weight, x); + // #else + AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU."); + // #endif + } + // CPU + else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) { + out = bspmm_sum_cpu_forward(index, weight, x); + } else { + AT_ERROR("Tensor device inconsistent error."); + } + + return out; +} + +std::vector BSpMMSum::backward(torch::autograd::AutogradContext *ctx, std::vector grad_outs) { + auto saved = ctx->get_saved_variables(); + auto index = saved[0], weight = saved[1], x = saved[2]; + auto grad = grad_outs[0]; + torch::Tensor grad_x, grad_weight; + + // CUDA + if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) { + // #ifdef COMPILE_WITH_CUDA + // grad_x = bspmm_sum_cuda_backward(index, weight, grad); + // #else + AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU."); + // #endif + } + // CPU + else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) { + auto result = bspmm_sum_cpu_backward(index, weight, x, grad); + grad_x = std::get<0>(result); + grad_weight = std::get<1>(result); + } else { + AT_ERROR("Tensor device inconsistent error."); + } + + return {torch::Tensor(), grad_weight, grad_x}; +} diff --git a/gammagl/mpops/torch_ext/src/operators.cpp b/gammagl/mpops/torch_ext/src/operators.cpp index e22043d7..5f99cd92 100644 --- a/gammagl/mpops/torch_ext/src/operators.cpp +++ b/gammagl/mpops/torch_ext/src/operators.cpp @@ -39,6 +39,12 @@ torch::Tensor spmm_max(torch::Tensor index, torch::Tensor weight, torch::Tensor return SpMMMax::apply(index, weight, x); } +torch::Tensor bspmm_sum(torch::Tensor index, torch::Tensor weight, + torch::Tensor x) { + auto result = BSpMMSum::apply(index, weight, x); + return result; +} + PYBIND11_MODULE(_torch_ext, m) { m.def("c_segment_max", segment_max); m.def("c_segment_sum", segment_sum); @@ -46,4 +52,5 @@ PYBIND11_MODULE(_torch_ext, m) { m.def("c_spmm_sum", spmm_sum); m.def("c_spmm_mean", spmm_mean); m.def("c_spmm_max", spmm_max); + m.def("c_bspmm_sum", bspmm_sum); }