[Model] Implement model Unimp #83
base: main
@@ -0,0 +1,32 @@
# Unified Message Passing Model (UniMP)

- Paper link: [https://arxiv.org/abs/2009.03509](https://arxiv.org/abs/2009.03509)

# Dataset Statistics

| Dataset  | # Nodes | # Edges | # Classes |
|----------|---------|---------|-----------|
| Cora     | 2,708   | 10,556  | 7         |
| Citeseer | 3,327   | 9,228   | 6         |
| Pubmed   | 19,717  | 88,651  | 3         |

Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid).

Results
-------

```bash
# available dataset: "cora", "citeseer", "pubmed"
TL_BACKEND="tensorflow" python unimp_trainer.py --dataset cora
TL_BACKEND="tensorflow" python unimp_trainer.py --dataset citeseer
TL_BACKEND="tensorflow" python unimp_trainer.py --dataset pubmed
TL_BACKEND="torch" python unimp_trainer.py --dataset cora
TL_BACKEND="torch" python unimp_trainer.py --dataset citeseer
TL_BACKEND="torch" python unimp_trainer.py --dataset pubmed
```

| Dataset  | Our(tf)    | Our(torch) |
|----------|------------|------------|
| cora     | 83.10±1.12 | 82.30±0.67 |
| citeseer | 79.90±0.68 | 78.53±0.18 |
| pubmed   | 74.10±1.08 | 73.63±0.12 |
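The `TL_BACKEND` environment variable in the commands above selects the TensorLayerX backend. It can also be set from inside Python, provided this happens before `tensorlayerx` is imported (a minimal sketch, assuming the backend is read at import time):

```python
import os
os.environ["TL_BACKEND"] = "torch"   # or "tensorflow"; must be set before importing tensorlayerx
import tensorlayerx as tlx
```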
@@ -0,0 +1,168 @@
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import random
import argparse
import tensorlayerx as tlx
import tensorlayerx.nn as nn
from gammagl.utils import segment_softmax
from gammagl.datasets import Planetoid
from gammagl.layers.conv import MessagePassing
from gammagl.utils import add_self_loops, mask_to_index
from tensorlayerx.model import TrainOneStep, WithLoss


class CrossEntropyLoss(WithLoss):
    def __init__(self, model, loss_func):
        super(CrossEntropyLoss, self).__init__(model, loss_func)

    def forward(self, data, label):
        # compute the loss only on the nodes indexed by data['val_idx']
        out = self.backbone_network(data['x'], data['edge_index'])
        out = tlx.gather(out, data['val_idx'])
        label = tlx.reshape(tlx.gather(label, data['val_idx']), shape=(-1,))
        loss = self._loss_fn(out, label)
        return loss

class MultiHead(MessagePassing):
    def __init__(self, in_features, out_features, n_heads, num_nodes):
        super().__init__()
        self.heads = n_heads
        self.num_nodes = num_nodes
        self.out_channels = out_features
        self.linear = tlx.layers.Linear(out_features=out_features * n_heads,
                                        in_features=in_features)

        init = tlx.initializers.RandomNormal()
        self.att_src = init(shape=(1, n_heads, out_features), dtype=tlx.float32)
        self.att_dst = init(shape=(1, n_heads, out_features), dtype=tlx.float32)

        self.leaky_relu = tlx.layers.LeakyReLU(0.2)
        self.dropout = tlx.layers.Dropout()

    def message(self, x, edge_index):
        # GAT-style attention: score each edge, softmax over incoming edges, weight source features
        node_src = edge_index[0, :]
        node_dst = edge_index[1, :]
        weight_src = tlx.gather(tlx.reduce_sum(x * self.att_src, -1), node_src)
        weight_dst = tlx.gather(tlx.reduce_sum(x * self.att_dst, -1), node_dst)
        weight = self.leaky_relu(weight_src + weight_dst)

        alpha = self.dropout(segment_softmax(weight, node_dst, self.num_nodes))
        x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1)
        return x

Review comment: Why are there two forward functions?

    def forward(self, x, edge_index):
        x = tlx.reshape(self.linear(x), shape=(-1, self.heads, self.out_channels))
        x = self.propagate(x, edge_index, num_nodes=self.num_nodes)
        x = tlx.ops.reduce_mean(x, axis=1)
        return x

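For context on the layer's control flow: `forward()` is the layer's public entry point (linear projection, `propagate()`, then averaging over heads), while `message()` is the hook that `MessagePassing.propagate()` calls to build the attention-weighted per-edge messages. A standalone usage sketch of the layer, not part of the diff, with Cora-sized shapes assumed (2,708 nodes, 1,433 features plus one label channel, matching how `Unimp` constructs it below):

```python
import numpy as np
import tensorlayerx as tlx

# Illustrative inputs only; real features and edges come from the Planetoid dataset.
x = tlx.ops.convert_to_tensor(np.random.randn(2708, 1434).astype('float32'))          # features + label channel
edge_index = tlx.ops.convert_to_tensor(np.array([[0, 1, 2], [1, 2, 0]], dtype='int64'))  # (2, E) toy edges
layer = MultiHead(in_features=1434, out_features=716, n_heads=4, num_nodes=2708)
out = layer(x, edge_index)   # forward(): linear projection -> propagate() -> message() -> mean over heads
# out shape: (2708, 716)
```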
Review comment: Put this section into gammagl.layers.conv.unimp_conv.py. If you put this into this file, users will not be able to use this function.

class Unimp(tlx.nn.Module):
    def __init__(self, dataset):
        super(Unimp, self).__init__()

        out_layer1 = int(dataset.num_node_features / 2)
        self.layer1 = MultiHead(dataset.num_node_features + 1, out_layer1, 4, dataset[0].num_nodes)
        self.norm1 = nn.LayerNorm(out_layer1)
        self.relu1 = nn.ReLU()

        self.layer2 = MultiHead(out_layer1, dataset.num_classes, 4, dataset[0].num_nodes)
        self.norm2 = nn.LayerNorm(dataset.num_classes)
        self.relu2 = nn.ReLU()

    def forward(self, x, edge_index):
        out1 = self.layer1(x, edge_index)
        out2 = self.norm1(out1)
        out3 = self.relu1(out2)
        out4 = self.layer2(out3, edge_index)
        out5 = self.norm2(out4)
        out6 = self.relu2(out5)
        return out6


def calculate_acc(logits, y, metrics):
    metrics.update(logits, y)
    rst = metrics.result()
    metrics.reset()
    return rst


def get_label_mask(label, node, dtype):
    # randomly keep train_node1 of the training labels and mask the rest (masked label prediction)
    mask = [1 for i in range(node['train_node1'])] + [0 for i in range(node['train_node2'])]
    random.shuffle(mask)
    label_mask = []
    for i in range(node['train_node']):
        if mask[i] == 0:
            label_mask.append([-1])
        else:
            label_mask.append([int(label[i])])
    label_mask += [[0] for i in range(node['num_node'] - node['train_node'])]
    return tlx.ops.convert_to_tensor(label_mask, dtype=dtype)


def merge_feature_label(label, feature):
    return tlx.ops.concat([label, feature], axis=1)

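The masked labels reach the model by concatenating a one-column label tensor to the node features, which is why `Unimp` uses `dataset.num_node_features + 1` input channels. A rough shape walk-through for Cora (illustrative numbers, not part of the diff):

```python
# Illustrative shapes for Cora (2708 nodes, 1433 features); label, node and feature come from main().
label_mask = get_label_mask(label, node, feature[0].dtype)  # (2708, 1): kept label, -1 if masked, 0 outside the train pool
x = merge_feature_label(label_mask, feature)                # (2708, 1434) = label column + original features
```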
def main(args):
    dataset = Planetoid(root='./', name=args.dataset)
    graph = dataset[0]
    feature = graph.x
    edge_index = graph.edge_index
    label = graph.y
    # 30% of the nodes form the labelled pool; 10% keep their labels each epoch, the rest are masked
    train_node = int(graph.num_nodes * 0.3)
    train_node1 = int(graph.num_nodes * 0.1)
    node = {
        'train_node': train_node,
        'train_node1': train_node1,
        'train_node2': train_node - train_node1,
        'num_node': graph.num_nodes
    }
    val_mask = tlx.ops.concat(
        [tlx.ops.zeros((train_node, 1), dtype=tlx.int32),
         tlx.ops.ones((train_node - train_node1, 1), dtype=tlx.int32)], axis=0)
    test_mask = graph.test_mask
    model = Unimp(dataset)
    loss = tlx.losses.softmax_cross_entropy_with_logits
    optimizer = tlx.optimizers.Adam(lr=0.01, weight_decay=5e-4)
    train_weights = model.trainable_weights
    loss_func = CrossEntropyLoss(model, loss)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)
    val_idx = mask_to_index(val_mask)
    test_idx = mask_to_index(test_mask)
    metrics = tlx.metrics.Accuracy()
    data = {
        "x": feature,
        "y": label,
        "edge_index": edge_index,
        "val_idx": val_idx,
        "test_idx": test_idx,
        "num_nodes": graph.num_nodes,
    }

    epochs = args.epochs
    best_val_acc = 0
    for epoch in range(epochs):
        model.set_train()
        # re-draw the label mask every epoch and rebuild the model input
        label_mask = get_label_mask(label, node, feature[0].dtype)
        data['x'] = merge_feature_label(label_mask, feature)
        train_loss = train_one_step(data, graph.y)

        model.set_eval()
        logits = model(data['x'], data['edge_index'])
        test_logits = tlx.gather(logits, data['test_idx'])
        test_y = tlx.gather(data['y'], data['test_idx'])
        test_acc = calculate_acc(test_logits, test_y, metrics)

        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " train loss: {:.4f}".format(train_loss.item())
              + " test acc: {:.4f}".format(test_acc))

        # save the weights with the best accuracy on the evaluation split
        if test_acc > best_val_acc:
            best_val_acc = test_acc
            model.save_weights('./' + 'unimp' + ".npz", format='npz_dict')
    print("The Best ACC : {:.4f}".format(best_val_acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=200, help="number of epochs")
    parser.add_argument('--dataset', type=str, default='cora', help='dataset')
    args = parser.parse_args()
    main(args)
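Following the review comment, the layer could live in its own module so that other code can import it. The path `gammagl/layers/conv/unimp_conv.py` comes from the reviewer; the sketch below is hypothetical and not part of this diff:

```python
# Hypothetical gammagl/layers/conv/unimp_conv.py (sketch only, per the review comment)
from gammagl.layers.conv import MessagePassing
from gammagl.utils import segment_softmax
import tensorlayerx as tlx


class MultiHead(MessagePassing):
    # body exactly as defined in unimp_trainer.py above
    pass

# The trainer would then import it instead of defining it locally:
# from gammagl.layers.conv.unimp_conv import MultiHead
```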
@@ -0,0 +1,16 @@
,ogbg-molbace,ogbg-molbbbp,ogbg-molclintox,ogbg-molmuv,ogbg-molpcba,ogbg-molsider,ogbg-moltox21,ogbg-moltoxcast,ogbg-molhiv,ogbg-molesol,ogbg-molfreesolv,ogbg-mollipo,ogbg-molchembl,ogbg-ppa,ogbg-code2
num tasks,1,1,2,17,128,27,12,617,1,1,1,1,1310,1,1
eval metric,rocauc,rocauc,rocauc,ap,ap,rocauc,rocauc,rocauc,rocauc,rmse,rmse,rmse,rocauc,acc,F1
download_name,bace,bbbp,clintox,muv,pcba,sider,tox21,toxcast,hiv,esol,freesolv,lipophilicity,chembl,ogbg_ppi_medium,code2
version,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
url,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/bace.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/bbbp.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/clintox.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/muv.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/pcba.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/sider.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/tox21.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/toxcast.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/esol.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/freesolv.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/lipophilicity.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/chembl.zip,http://snap.stanford.edu/ogb/data/graphproppred/ogbg_ppi_medium.zip,http://snap.stanford.edu/ogb/data/graphproppred/code2.zip
add_inverse_edge,True,True,True,True,True,True,True,True,True,True,True,True,True,True,False
data type,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,,
has_node_attr,True,True,True,True,True,True,True,True,True,True,True,True,True,False,True
has_edge_attr,True,True,True,True,True,True,True,True,True,True,True,True,True,True,False
task type,binary classification,binary classification,binary classification,binary classification,binary classification,binary classification,binary classification,binary classification,binary classification,regression,regression,regression,binary classification,multiclass classification,subtoken prediction
num classes,2,2,2,2,2,2,2,2,2,-1,-1,-1,2,37,-1
split,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,species,project
additional node files,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"node_is_attributed,node_dfs_order,node_depth"
additional edge files,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None
binary,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
@@ -0,0 +1,14 @@
,ogbl-ppa,ogbl-collab,ogbl-citation2,ogbl-wikikg2,ogbl-ddi,ogbl-biokg,ogbl-vessel
eval metric,hits@100,hits@50,mrr,mrr,hits@20,mrr,rocauc
task type,link prediction,link prediction,link prediction,KG completion,link prediction,KG completion,link prediction
download_name,ppassoc,collab,citation-v2,wikikg-v2,ddi,biokg,vessel
version,1,1,1,1,1,1,1
url,http://snap.stanford.edu/ogb/data/linkproppred/ppassoc.zip,http://snap.stanford.edu/ogb/data/linkproppred/collab.zip,http://snap.stanford.edu/ogb/data/linkproppred/citation-v2.zip,http://snap.stanford.edu/ogb/data/linkproppred/wikikg-v2.zip,http://snap.stanford.edu/ogb/data/linkproppred/ddi.zip,http://snap.stanford.edu/ogb/data/linkproppred/biokg.zip,http://snap.stanford.edu/ogb/data/linkproppred/vessel.zip
add_inverse_edge,True,True,False,False,True,False,False
has_node_attr,True,True,True,False,False,False,True
has_edge_attr,False,False,False,False,False,False,True
split,throughput,time,time,time,target,random,spatial
additional node files,None,None,node_year,None,None,None,None
additional edge files,None,"edge_weight,edge_year",None,edge_reltype,None,edge_reltype,None
is hetero,False,False,False,False,False,True,False
binary,False,False,False,False,False,False,True
@@ -0,0 +1,16 @@
,ogbn-proteins,ogbn-products,ogbn-arxiv,ogbn-mag,ogbn-papers100M
num tasks,112,1,1,1,1
num classes,2,47,40,349,172
eval metric,rocauc,acc,acc,acc,acc
task type,binary classification,multiclass classification,multiclass classification,multiclass classification,multiclass classification
download_name,proteins,products,arxiv,mag,papers100M-bin
version,1,1,1,2,1
url,http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip,http://snap.stanford.edu/ogb/data/nodeproppred/products.zip,http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip,http://snap.stanford.edu/ogb/data/nodeproppred/mag.zip,http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip
add_inverse_edge,True,True,False,False,False
has_node_attr,False,True,True,True,True
has_edge_attr,True,False,False,False,False
split,species,sales_ranking,time,time,time
additional node files,node_species,None,node_year,node_year,node_year
additional edge files,None,None,None,edge_reltype,None
is hetero,False,False,False,True,False
binary,False,False,False,False,True
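Each of these CSVs is an OGB metadata table with one column per dataset and one row per field. A minimal sketch of looking up a field (the file name `master.csv` is an assumption; the diff does not show where these tables are stored):

```python
import pandas as pd

# "master.csv" is an assumed file name for the node-property table above.
meta = pd.read_csv("master.csv", index_col=0)
print(meta.loc["num classes", "ogbn-arxiv"])   # 40, per the table above
```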