-
Notifications
You must be signed in to change notification settings - Fork 78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] Implement model Unimp #83
base: main
Are you sure you want to change the base?
Conversation
support ogbn dataset
get ogb node dataset via OgbNodeDataset class |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
增加一些测试代码,对于graph,link,node 各自选择一个最小的数据集写一下单元测试的代码 放在 https://github.com/BUPT-GAMMA/GammaGL/tree/main/tests/datasets 下面
gammagl/datasets/ogb_graph.py
Outdated
import os.path as osp | ||
import numpy as np | ||
from gammagl.data import InMemoryDataset | ||
from gammgl.utils.ogb_url import decide_download, download_url, extract_zip |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gammgl -> gammagl
gammagl/utils/ogb_url.py
Outdated
@@ -0,0 +1,91 @@ | |||
import urllib.request as ur |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gammagl/datasets/ogb_graph.py
Outdated
|
||
# check if previously-downloaded folder exists. | ||
# If so, use that one. | ||
if osp.exists(osp.join(root, self.dir_name + '_pyg')): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pyg相关的代码改成gammagl
tests/datasets/test_ogb_graph.py
Outdated
print(data) | ||
print(data[0]) | ||
|
||
test() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replace the function name "test" with "test_ogbgraphdataset", which makes it easier to distinguish the functions from other files.
tests/datasets/test_ogb_link.py
Outdated
data=OgbLinkDataset('ogbl-ppa') | ||
print(data[0]) | ||
|
||
test() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
tests/datasets/test_ogb_node.py
Outdated
data=OgbNodeDataset('ogbn-arxiv') | ||
print(data[0]) | ||
|
||
test() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
gammagl/datasets/ogb_node.py
Outdated
from gammagl.data import InMemoryDataset | ||
from gammagl.data.download import download_url | ||
from gammagl.data.extract import extract_zip | ||
from gammagl.io.read_ogb import read_node_label_hetero, read_graph, read_heterograph, read_nodesplitidx_split_hetero |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
中文逗号
print(data) | ||
print(data[0]) | ||
|
||
test() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Replace the function name "test" with "test_ogbgraphdataset", which makes it easier to distinguish the functions from other files.
- Using "assert" to check the correctness of some variable. e.g. the "feature.shape[0]" and "num_nodes" should be equal.
data=OgbLinkDataset('ogbl-ppa') | ||
print(data[0]) | ||
|
||
test() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
data=OgbNodeDataset('ogbn-arxiv') | ||
print(data[0]) | ||
|
||
test() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
examples/unimp/unimp_trainer.py
Outdated
return loss | ||
|
||
|
||
class MultiHead(MessagePassing): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put this section into gammagl.layers.conv.unimp_conv.py
. If you put this into this file, users will not be able to use this function.
examples/unimp/unimp_trainer.py
Outdated
return x | ||
|
||
|
||
class Unimp(tlx.nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put this section into gammagl.models.unimp.py
. If you put this into this file, users will not be able to use this function.
super(Unimp, self).__init__() | ||
|
||
out_layer1=int(dataset.num_node_features/2) | ||
self.layer1=MultiHead(dataset.num_node_features+1, out_layer1, 4,dataset[0].num_nodes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think users can choose their n_heads, so change this to let users choose their own n_heads instead of a fixed one.
self.norm1=nn.LayerNorm(out_layer1) | ||
self.relu1=nn.ReLU() | ||
|
||
self.layer2=MultiHead(out_layer1, dataset.num_classes, 4,dataset[0].num_nodes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
import tlx.nn as nn | ||
from gammagl.layers import MultiHead | ||
|
||
class Unimp(tlx.nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add rst
doc here refer to this link: https://docs.qq.com/pdf/DUXRTTU9tUnB1WnFB.
examples/unimp/unimp_trainer.py
Outdated
return loss | ||
|
||
|
||
def forward(self, x, edge_index): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why there are two forward functions?
gammagl/layers/conv/multi_head.py
Outdated
|
||
alpha = self.dropout(segment_softmax(weight, node_dst, self.num_nodes)) | ||
x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1) | ||
return x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this function means Unimp
? So what is the difference between this model and model GAT
?
support ogbn dataset
Description
Checklist
Please feel free to remove inapplicable items for your PR.
or have been fixed to be compatible with this change
Changes