Skip to content

Commit

Permalink
Modifications as required
Browse files Browse the repository at this point in the history
  • Loading branch information
taorann committed Jul 30, 2024
1 parent a3fc25f commit dc2580f
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 32 deletions.
26 changes: 7 additions & 19 deletions examples/fatragnn/config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
bail:
epochs: 400
g_epochs: 5
a_epochs: 5
cla_epochs: 2
dic_epochs: 2
a_epochs: 4
cla_epochs: 10
dic_epochs: 8
dtb_epochs: 5
d_lr: 0.001
c_lr: 0.001
e_lr: 0.001
g_lr: 0.01
c_lr: 0.005
e_lr: 0.005
g_lr: 0.05
drope_rate: 0.1
credit:
epochs: 600
Expand All @@ -21,16 +21,4 @@ credit:
c_lr: 0.01
e_lr: 0.01
g_lr: 0.05
drope_rate: 0.1
pokec:
epochs: 400
g_epochs: 5
a_epochs: 5
cla_epochs: 5
dic_epochs: 2
dtb_epochs: 1
d_lr: 0.001
c_lr: 0.001
e_lr: 0.001
g_lr: 0.01
drope_rate: 0.1
drope_rate: 0.1
15 changes: 6 additions & 9 deletions examples/fatragnn/fatragnn_trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TL_BACKEND'] = 'tensorflow'
os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR
import tensorlayerx as tlx
Expand All @@ -13,7 +13,7 @@
import yaml
from gammagl.datasets import Bail
from gammagl.datasets import Credit
from gammagl.datasets import Pokec


def fair_metric(pred, labels, sens):
idx_s0 = sens == 0
Expand Down Expand Up @@ -118,9 +118,6 @@ def main(args):
elif args.dataset == 'credit':
dataset = Credit(args.dataset_path, args.dataset)

elif args.dataset == 'pokec':
dataset = Pokec(args.dataset_path, args.dataset)

graphs = dataset.data
data = {
'x':graphs[0].x,
Expand Down Expand Up @@ -226,6 +223,7 @@ def main(args):

for epoch_g in range(0, args.dtb_epochs):
edt_loss = edt_train_one_step(data=data, label=data['y'])

# shift align
data['flag'] = 5
if epoch > args.start:
Expand All @@ -240,12 +238,11 @@ def main(args):
for i in range(args.test_set_num):
data_tem = data_test[i]
acc[i],auc_roc[i], parity[i], equality[i] = evaluate_ged3(net, data_tem['x'], data_tem['edge_index'], data_tem['y'], data_tem['test_mask'], data_tem['sens'])

return acc, auc_roc, parity, equality

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='pokec')
parser.add_argument('--dataset', type=str, default='bail')
parser.add_argument('--start', type=int, default=50)
parser.add_argument('--epochs', type=int, default=400)
parser.add_argument('--dic_epochs', type=int, default=5)
Expand All @@ -262,9 +259,9 @@ def main(args):
parser.add_argument('--e_lr', type=float, default=0.005)
parser.add_argument('--e_wd', type=float, default=0)
parser.add_argument('--hidden', type=int, default=128)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--seed', type=int, default=3)
parser.add_argument('--top_k', type=int, default=10)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--gpu', type=int, default=1)
parser.add_argument('--drope_rate', type=float, default=0.1)
parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")

Expand Down
63 changes: 63 additions & 0 deletions examples/fatragnn/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Graph Fairness Learning under Distribution Shifts

- Paper link: [https://arxiv.org/abs/2401.16784](https://arxiv.org/abs/2401.16784)
- Author's code repo: [https://github.com/BUPT-GAMMA/FatraGNN](https://github.com/BUPT-GAMMA/FatraGNN). Note that the original code is implemented with Torch for the paper.

# Dataset Statics


| Dataset | # Nodes | # Edges | # Classes |
|----------|---------|---------|-----------|
| Bail_B0 | 4,686 | 153,942 | 2 |
| Bail_B1 | 2,214 | 49,124 | 2 |
| Bail_B2 | 2,395 | 88,091 | 2 |
| Bail_B3 | 1,536 | 57,838 | 2 |
| Bail_B4 | 1,193 | 30,319 | 2 |
| Credit_C0| 4,184 | 45,718 | 2 |
| Credit_C1| 2,541 | 18,949 | 2 |
| Credit_C2| 3,796 | 28,936 | 2 |
| Credit_C3| 2,068 | 15,314 | 2 |
| Credit_C4| 3,420 | 26,048 | 2 |

Refer to [Credit](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Credit) and [Bail](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Bail).

Results
-------




```bash
TL_BACKEND="torch" python fatragnn_trainer.py --dataset credit --epochs 600 --g_epochs 5 --a_epochs 2 --cla_epochs 12 --dic_epochs 5 --dtb_epochs 5 --c_lr 0.01 --e_lr 0.01
TL_BACKEND="torch" python fatragnn_trainer.py --dataset bail --epochs 400 --g_epochs 5 --a_epochs 4 --cla_epochs 10 --dic_epochs 8 --dtb_epochs 5 --c_lr 0.005 --e_lr 0.005


TL_BACKEND="tensorflow" python fatragnn_trainer.py --dataset credit --epochs 600 --g_epochs 5 --a_epochs 2 --cla_epochs 12 --dic_epochs 5 --dtb_epochs 5 --c_lr 0.01 --e_lr 0.01
TL_BACKEND="tensorflow" python fatragnn_trainer.py --dataset bail --epochs 400 --g_epochs 5 --a_epochs 4 --cla_epochs 10 --dic_epochs 8 --dtb_epochs 5 --c_lr 0.005 --e_lr 0.005
```
ACC:
| Dataset | Paper | Our(torch) | Our(tensorflow) |
| ---------- | ----------- | ---------------- | --------------- |
| Credit_C1 | 77.31±0.10 | 77.08(±0.08) | 77.06(±0.10) |
| Credit_C2 | 77.12±0.28 | 77.26(±0.13) | 77.22(±0.11) |
| Credit_C3 | 71.81±0.39 | 70.86(±0.15) | 71.02(±0.12) |
| Credit_C4 | 72.15±0.42 | 70.91(±0.10) | 71.08(±0.09) |
| Bail_B1 | 74.59±0.93 | 72.13(±0.97) | 72.08(±0.98) |
| Bail_B2 | 70.46±0.44 | 78.55(±0.94) | 79.02(±0.31) |
| Bail_B3 | 71.65±4.65 | 79.77(±0.70) | 78.96(±0.76) |
| Bail_B4 | 72.59±3.39 | 80.35(±1.73) | 79.91(±0.64) |




equality:
| Dataset | Paper | Our(torch) | Our(tensorflow) |
| ---------- | ---------- | ---------------- | --------------- |
| Credit_C1 | 0.71±0.03 | 0.53(±0.05) | 0.41(±0.02) |
| Credit_C2 | 0.95±0.7 | 0.13(±0.10) | 0.30(±0.39) |
| Credit_C3 | 0.81±0.56 | 1.81(±1.68) | 2.51(±1.92) |
| Credit_C4 | 1.16±0.13 | 0.14(±0.07) | 0.18(±0.13) |
| Bail_B1 | 2.38±3.19 | 4.38(±2.87) | 1.28(±1.04) |
| Bail_B2 | 0.43±1.14 | 4.48(±2.52) | 3.51(±1.92) |
| Bail_B3 | 2.43±4.94 | 2.62(±2.55) | 2.13(±0.43) |
| Bail_B4 | 2.45±6.67 | 1.16(±1.40) | 3.03(±1.22) |
22 changes: 18 additions & 4 deletions gammagl/models/fatragnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ def forward(self, h):


class FatraGNNModel(tlx.nn.Module):
r"""FatraGNN from `"Graph Fairness Learning under Distribution Shifts"
<https://arxiv.org/abs/2401.16784>`_ paper.
Parameters
----------
in_features: int
input feature dimension.
hidden: int
hidden dimension.
out_features: int
number of output feature dimension.
drop_rate: float
dropout rate.
"""
def __init__(self, args):
super(FatraGNNModel, self).__init__()
self.classifier = MLP_classifier(args)
Expand Down Expand Up @@ -100,15 +114,15 @@ def modify_structure1(self, edge_index, A2_edge, sens, nodes_num, drop=0.8, add=
random.seed(self.seed)
src_node, targ_node = edge_index[0], edge_index[1]
matching = tlx.gather(sens, src_node) == tlx.gather(sens, targ_node)
# 去掉异配边

yipei = mask_to_index(matching == False)
drop_index = tlx.convert_to_tensor(random.sample(range(yipei.shape[0]), int(yipei.shape[0] * drop)))
yipei_drop = tlx.gather(yipei, drop_index)
keep_indices0 = tlx.ones(src_node.shape, dtype=tlx.bool)
keep_indices = tlx.scatter_update(keep_indices0, yipei_drop, tlx.zeros((yipei_drop.shape), dtype=tlx.bool))
n_src_node = src_node[keep_indices]
n_targ_node = targ_node[keep_indices]
# 加同配

src_node2, targ_node2 = A2_edge[0], A2_edge[1]
matching2 = tlx.gather(sens, src_node2) == tlx.gather(sens, targ_node2)
matching3 = src_node2 == targ_node2
Expand All @@ -131,7 +145,7 @@ def modify_structure2(self, edge_index, A2_edge, sens, nodes_num, drop=0.6, add=
random.seed(self.seed)
src_node, targ_node = edge_index[0], edge_index[1]
matching = tlx.gather(sens, src_node) == tlx.gather(sens, targ_node)
# 去掉异配边


yipei = mask_to_index(matching == False)
yipei_np = tlx.convert_to_numpy(yipei)
Expand All @@ -144,7 +158,7 @@ def modify_structure2(self, edge_index, A2_edge, sens, nodes_num, drop=0.6, add=
keep_indices = tlx.scatter_update(keep_indices0, yipei_drop, tlx.zeros((yipei_drop.shape), dtype=tlx.bool))
n_src_node = src_node[keep_indices]
n_targ_node = targ_node[keep_indices]
# 加同配

src_node2, targ_node2 = A2_edge[0], A2_edge[1]
matching2 = tlx.gather(sens, src_node2) != tlx.gather(sens, targ_node2)
matching3 = src_node2 == targ_node2
Expand Down

0 comments on commit dc2580f

Please sign in to comment.