From dc2580f17d8284b262aa19224cfca3b0ba563c6e Mon Sep 17 00:00:00 2001 From: taorann Date: Tue, 30 Jul 2024 22:01:22 +0800 Subject: [PATCH] Modifications as required --- examples/fatragnn/config.yaml | 26 +++-------- examples/fatragnn/fatragnn_trainer.py | 15 +++---- examples/fatragnn/readme.md | 63 +++++++++++++++++++++++++++ gammagl/models/fatragnn.py | 22 ++++++++-- 4 files changed, 94 insertions(+), 32 deletions(-) create mode 100644 examples/fatragnn/readme.md diff --git a/examples/fatragnn/config.yaml b/examples/fatragnn/config.yaml index 0739b6d5..3ad95fad 100644 --- a/examples/fatragnn/config.yaml +++ b/examples/fatragnn/config.yaml @@ -1,14 +1,14 @@ bail: epochs: 400 g_epochs: 5 - a_epochs: 5 - cla_epochs: 2 - dic_epochs: 2 + a_epochs: 4 + cla_epochs: 10 + dic_epochs: 8 dtb_epochs: 5 d_lr: 0.001 - c_lr: 0.001 - e_lr: 0.001 - g_lr: 0.01 + c_lr: 0.005 + e_lr: 0.005 + g_lr: 0.05 drope_rate: 0.1 credit: epochs: 600 @@ -21,16 +21,4 @@ credit: c_lr: 0.01 e_lr: 0.01 g_lr: 0.05 - drope_rate: 0.1 -pokec: - epochs: 400 - g_epochs: 5 - a_epochs: 5 - cla_epochs: 5 - dic_epochs: 2 - dtb_epochs: 1 - d_lr: 0.001 - c_lr: 0.001 - e_lr: 0.001 - g_lr: 0.01 - drope_rate: 0.1 + drope_rate: 0.1 \ No newline at end of file diff --git a/examples/fatragnn/fatragnn_trainer.py b/examples/fatragnn/fatragnn_trainer.py index 74d34af6..71f6bbe5 100644 --- a/examples/fatragnn/fatragnn_trainer.py +++ b/examples/fatragnn/fatragnn_trainer.py @@ -1,6 +1,6 @@ import os # os.environ['CUDA_VISIBLE_DEVICES'] = '0' -os.environ['TL_BACKEND'] = 'tensorflow' +os.environ['TL_BACKEND'] = 'torch' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR import tensorlayerx as tlx @@ -13,7 +13,7 @@ import yaml from gammagl.datasets import Bail from gammagl.datasets import Credit -from gammagl.datasets import Pokec + def fair_metric(pred, labels, sens): idx_s0 = sens == 0 @@ -118,9 +118,6 @@ def main(args): elif args.dataset == 'credit': dataset = Credit(args.dataset_path, args.dataset) - elif args.dataset == 'pokec': - dataset = Pokec(args.dataset_path, args.dataset) - graphs = dataset.data data = { 'x':graphs[0].x, @@ -226,6 +223,7 @@ def main(args): for epoch_g in range(0, args.dtb_epochs): edt_loss = edt_train_one_step(data=data, label=data['y']) + # shift align data['flag'] = 5 if epoch > args.start: @@ -240,12 +238,11 @@ def main(args): for i in range(args.test_set_num): data_tem = data_test[i] acc[i],auc_roc[i], parity[i], equality[i] = evaluate_ged3(net, data_tem['x'], data_tem['edge_index'], data_tem['y'], data_tem['test_mask'], data_tem['sens']) - return acc, auc_roc, parity, equality if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--dataset', type=str, default='pokec') + parser.add_argument('--dataset', type=str, default='bail') parser.add_argument('--start', type=int, default=50) parser.add_argument('--epochs', type=int, default=400) parser.add_argument('--dic_epochs', type=int, default=5) @@ -262,9 +259,9 @@ def main(args): parser.add_argument('--e_lr', type=float, default=0.005) parser.add_argument('--e_wd', type=float, default=0) parser.add_argument('--hidden', type=int, default=128) - parser.add_argument('--seed', type=int, default=1) + parser.add_argument('--seed', type=int, default=3) parser.add_argument('--top_k', type=int, default=10) - parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--gpu', type=int, default=1) parser.add_argument('--drope_rate', type=float, 
default=0.1)
     parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")
diff --git a/examples/fatragnn/readme.md b/examples/fatragnn/readme.md
new file mode 100644
index 00000000..0fa5ee69
--- /dev/null
+++ b/examples/fatragnn/readme.md
@@ -0,0 +1,63 @@
+# Graph Fairness Learning under Distribution Shifts
+
+- Paper link: [https://arxiv.org/abs/2401.16784](https://arxiv.org/abs/2401.16784)
+- Author's code repo: [https://github.com/BUPT-GAMMA/FatraGNN](https://github.com/BUPT-GAMMA/FatraGNN). Note that the original code for the paper is implemented in PyTorch.
+
+# Dataset Statistics
+
+
+| Dataset  | # Nodes | # Edges | # Classes |
+|----------|---------|---------|-----------|
+| Bail_B0  | 4,686   | 153,942 | 2         |
+| Bail_B1  | 2,214   | 49,124  | 2         |
+| Bail_B2  | 2,395   | 88,091  | 2         |
+| Bail_B3  | 1,536   | 57,838  | 2         |
+| Bail_B4  | 1,193   | 30,319  | 2         |
+| Credit_C0| 4,184   | 45,718  | 2         |
+| Credit_C1| 2,541   | 18,949  | 2         |
+| Credit_C2| 3,796   | 28,936  | 2         |
+| Credit_C3| 2,068   | 15,314  | 2         |
+| Credit_C4| 3,420   | 26,048  | 2         |
+
+Refer to [Credit](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Credit) and [Bail](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Bail).
+
+Results
+-------
+
+
+
+
+```bash
+TL_BACKEND="torch" python fatragnn_trainer.py --dataset credit --epochs 600 --g_epochs 5 --a_epochs 2 --cla_epochs 12 --dic_epochs 5 --dtb_epochs 5 --c_lr 0.01 --e_lr 0.01
+TL_BACKEND="torch" python fatragnn_trainer.py --dataset bail --epochs 400 --g_epochs 5 --a_epochs 4 --cla_epochs 10 --dic_epochs 8 --dtb_epochs 5 --c_lr 0.005 --e_lr 0.005
+
+
+TL_BACKEND="tensorflow" python fatragnn_trainer.py --dataset credit --epochs 600 --g_epochs 5 --a_epochs 2 --cla_epochs 12 --dic_epochs 5 --dtb_epochs 5 --c_lr 0.01 --e_lr 0.01
+TL_BACKEND="tensorflow" python fatragnn_trainer.py --dataset bail --epochs 400 --g_epochs 5 --a_epochs 4 --cla_epochs 10 --dic_epochs 8 --dtb_epochs 5 --c_lr 0.005 --e_lr 0.005
+```
+ACC:
+| Dataset    | Paper      | Our(torch)   | Our(tensorflow) |
+| ---------- | ---------- | ------------ | --------------- |
+| Credit_C1  | 77.31±0.10 | 77.08(±0.08) | 77.06(±0.10)    |
+| Credit_C2  | 77.12±0.28 | 77.26(±0.13) | 77.22(±0.11)    |
+| Credit_C3  | 71.81±0.39 | 70.86(±0.15) | 71.02(±0.12)    |
+| Credit_C4  | 72.15±0.42 | 70.91(±0.10) | 71.08(±0.09)    |
+| Bail_B1    | 74.59±0.93 | 72.13(±0.97) | 72.08(±0.98)    |
+| Bail_B2    | 70.46±0.44 | 78.55(±0.94) | 79.02(±0.31)    |
+| Bail_B3    | 71.65±4.65 | 79.77(±0.70) | 78.96(±0.76)    |
+| Bail_B4    | 72.59±3.39 | 80.35(±1.73) | 79.91(±0.64)    |
+
+
+
+
+Equality:
+| Dataset    | Paper     | Our(torch)  | Our(tensorflow) |
+| ---------- | --------- | ----------- | --------------- |
+| Credit_C1  | 0.71±0.03 | 0.53(±0.05) | 0.41(±0.02)     |
+| Credit_C2  | 0.95±0.7  | 0.13(±0.10) | 0.30(±0.39)     |
+| Credit_C3  | 0.81±0.56 | 1.81(±1.68) | 2.51(±1.92)     |
+| Credit_C4  | 1.16±0.13 | 0.14(±0.07) | 0.18(±0.13)     |
+| Bail_B1    | 2.38±3.19 | 4.38(±2.87) | 1.28(±1.04)     |
+| Bail_B2    | 0.43±1.14 | 4.48(±2.52) | 3.51(±1.92)     |
+| Bail_B3    | 2.43±4.94 | 2.62(±2.55) | 2.13(±0.43)     |
+| Bail_B4    | 2.45±6.67 | 1.16(±1.40) | 3.03(±1.22)     |
diff --git a/gammagl/models/fatragnn.py b/gammagl/models/fatragnn.py
index 6f1f7c4e..5e8d0ade 100644
--- a/gammagl/models/fatragnn.py
+++ b/gammagl/models/fatragnn.py
@@ -43,6 +43,20 @@ def forward(self, h):
 
 
 class FatraGNNModel(tlx.nn.Module):
+    r"""FatraGNN from `"Graph Fairness Learning under Distribution Shifts"
+    <https://arxiv.org/abs/2401.16784>`_ paper.
+
+    Parameters
+    ----------
+    in_features: int
+        input feature dimension.
+    hidden: int
+        hidden dimension.
+    out_features: int
+        output feature dimension.
+    drop_rate: float
+        dropout rate.
+    """
     def __init__(self, args):
         super(FatraGNNModel, self).__init__()
         self.classifier = MLP_classifier(args)
@@ -100,7 +114,7 @@ def modify_structure1(self, edge_index, A2_edge, sens, nodes_num, drop=0.8, add=
         random.seed(self.seed)
         src_node, targ_node = edge_index[0], edge_index[1]
         matching = tlx.gather(sens, src_node) == tlx.gather(sens, targ_node)
-        # 去掉异配边
+        # drop edges whose two endpoints differ on the sensitive attribute
         yipei = mask_to_index(matching == False)
         drop_index = tlx.convert_to_tensor(random.sample(range(yipei.shape[0]), int(yipei.shape[0] * drop)))
         yipei_drop = tlx.gather(yipei, drop_index)
@@ -108,7 +122,7 @@ def modify_structure1(self, edge_index, A2_edge, sens, nodes_num, drop=0.8, add=
         keep_indices = tlx.scatter_update(keep_indices0, yipei_drop, tlx.zeros((yipei_drop.shape), dtype=tlx.bool))
         n_src_node = src_node[keep_indices]
         n_targ_node = targ_node[keep_indices]
-        # 加同配
+        # add edges whose two endpoints share the sensitive attribute
         src_node2, targ_node2 = A2_edge[0], A2_edge[1]
         matching2 = tlx.gather(sens, src_node2) == tlx.gather(sens, targ_node2)
         matching3 = src_node2 == targ_node2
@@ -131,7 +145,7 @@ def modify_structure2(self, edge_index, A2_edge, sens, nodes_num, drop=0.6, add=
         random.seed(self.seed)
         src_node, targ_node = edge_index[0], edge_index[1]
         matching = tlx.gather(sens, src_node) == tlx.gather(sens, targ_node)
-        # 去掉异配边
+        # drop edges whose two endpoints differ on the sensitive attribute
         yipei = mask_to_index(matching == False)
 
         yipei_np = tlx.convert_to_numpy(yipei)
@@ -144,7 +158,7 @@ def modify_structure2(self, edge_index, A2_edge, sens, nodes_num, drop=0.6, add=
         keep_indices = tlx.scatter_update(keep_indices0, yipei_drop, tlx.zeros((yipei_drop.shape), dtype=tlx.bool))
         n_src_node = src_node[keep_indices]
         n_targ_node = targ_node[keep_indices]
-        # 加同配
+        # add new edges drawn from the candidate set A2_edge
         src_node2, targ_node2 = A2_edge[0], A2_edge[1]
         matching2 = tlx.gather(sens, src_node2) != tlx.gather(sens, targ_node2)
         matching3 = src_node2 == targ_node2
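The graph editing performed by `modify_structure1`/`modify_structure2` boils down to two moves: remove a fraction of the edges whose endpoints differ on the sensitive attribute, then add candidate edges drawn from `A2_edge`. Below is a minimal NumPy sketch of the removal step only, for illustration; the function name `drop_cross_group_edges` is hypothetical and not part of GammaGL, and `edge_index` / `sens` simply mirror the arrays used in the trainer's data dict.

```python
import numpy as np

def drop_cross_group_edges(edge_index, sens, drop=0.8, seed=0):
    """Illustrative only: remove a fraction `drop` of the edges whose two
    endpoints have different sensitive-attribute values."""
    rng = np.random.default_rng(seed)
    src, dst = edge_index[0], edge_index[1]
    cross = sens[src] != sens[dst]              # True for cross-group edges
    cross_idx = np.flatnonzero(cross)
    n_drop = int(cross_idx.shape[0] * drop)
    dropped = rng.choice(cross_idx, size=n_drop, replace=False)
    keep = np.ones(edge_index.shape[1], dtype=bool)
    keep[dropped] = False                       # mask out the sampled cross-group edges
    return edge_index[:, keep]

# Tiny example: nodes 0, 1 are in group 0 and nodes 2, 3 in group 1,
# so with drop=1.0 every edge crossing the two groups is removed.
edge_index = np.array([[0, 0, 1, 2],
                       [1, 2, 3, 3]])
sens = np.array([0, 0, 1, 1])
print(drop_cross_group_edges(edge_index, sens, drop=1.0))
# -> [[0 2]
#     [1 3]]
```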