[Model] RoHe #220

Open
wants to merge 18 commits into main
30 binary files not shown.
42 changes: 42 additions & 0 deletions examples/rohehan/readme.md
@@ -0,0 +1,42 @@
# Robust Heterogeneous Graph Neural Network (RoHeHAN)

This is an implementation of `RoHeHAN`, a robust heterogeneous graph neural network designed to defend against adversarial attacks on heterogeneous graphs.

- Paper link: [https://cdn.aaai.org/ojs/20357/20357-13-24370-1-2-20220628.pdf](https://cdn.aaai.org/ojs/20357/20357-13-24370-1-2-20220628.pdf)
- Original paper title: *Robust Heterogeneous Graph Neural Networks against Adversarial Attacks*
- Implemented with the `tensorlayerx` and `gammagl` libraries.

## Usage

To reproduce the RoHeHAN results on the ACM dataset, run the following command:

```bash
TL_BACKEND="torch" python rohehan_trainer.py --num_epochs 100 --gpu 0
```
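
TensorLayerX also supports other backends (e.g., TensorFlow, PaddlePaddle, MindSpore). Assuming the corresponding backend package is installed, the same script can be run by switching `TL_BACKEND`:

```bash
# e.g., with the TensorFlow backend installed
TL_BACKEND="tensorflow" python rohehan_trainer.py --num_epochs 100 --gpu 0
```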

## Performance

Reference performance numbers for the ACM dataset:

| Dataset | Clean (no attack) | Attack (1 perturbation) | Attack (3 perturbations) | Attack (5 perturbations) |
| ------- | ----------------- | ---------------------- | ----------------------- | ----------------------- |
| ACM | 0.930 | 0.915 | 0.905 | 0.895 |
Contributor
The readme.md should report the results of our implementation under each backend, along with the results reported in the authors' paper.


ACM dataset link: [https://github.com/Jhy1993/HAN/raw/master/data/acm/ACM.mat](https://github.com/Jhy1993/HAN/raw/master/data/acm/ACM.mat)

### Example Commands

You can adjust training settings, such as the number of epochs, learning rate, and dropout rate, with the following command:

```bash
TL_BACKEND="torch" python rohehan_trainer.py --num_epochs 200 --lr 0.005 --dropout 0.6 --gpu 0
```

## Notes

- Early stopping is used to prevent overfitting during training.
- The `settings` in the RoHeGAT layer control the attention purifier mechanism, which improves robustness against adversarial attacks by pruning unreliable neighbors (see the sketch after these notes).

This implementation builds on the idea of using metapath-based transiting probability and attention purification to improve the robustness of heterogeneous graph neural networks (HGNNs).
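
As a reference for the note on `settings`, the sketch below mirrors the per-edge-type purifier configuration from `rohehan_trainer.py`; `T` is the pruning threshold and `TransM` is filled in later with the metapath-based transition matrix:

```python
# Per-edge-type purifier settings (as in rohehan_trainer.py).
settings = {
    ('paper', 'author', 'paper'): {'T': 3, 'TransM': None},  # PAP metapath
    ('paper', 'field', 'paper'):  {'T': 5, 'TransM': None},  # PFP metapath
}
# One settings dict per RoheHAN layer:
layer_settings = [settings, settings]
```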
260 changes: 260 additions & 0 deletions examples/rohehan/rohehan_trainer.py
@@ -0,0 +1,260 @@
# -*- coding: UTF-8 -*-
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR
import argparse
import numpy as np
import tensorlayerx as tlx
from gammagl.models import RoheHAN
from utils import *
import pickle as pkl
from gammagl.datasets.acm4rohe import ACM4Rohe

class SemiSpvzLoss(tlx.nn.Module):
def __init__(self, net, loss_fn):
super(SemiSpvzLoss, self).__init__()
self.net = net
self.loss_fn = loss_fn

def forward(self, data, y):
logits = self.net(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
train_logits = tlx.gather(logits['paper'], data['train_idx'])
train_y = tlx.gather(y, data['train_idx'])
loss = self.loss_fn(train_logits, train_y)
return loss

def main(args):
download_attack_data_files()
# Load ACM raw dataset
dataname = 'acm'
    dataset = ACM4Rohe(root="./")
g = dataset[0]
features_dict = {ntype: g[ntype].x for ntype in g.node_types if hasattr(g[ntype], 'x')}
labels = g['paper'].y
train_mask = g['paper'].train_mask
val_mask = g['paper'].val_mask
test_mask = g['paper'].test_mask

# Compute number of classes
num_classes = int(tlx.reduce_max(labels)) + 1

# Get train_idx, val_idx, test_idx from masks
train_idx = np.where(train_mask)[0]
val_idx = np.where(val_mask)[0]
test_idx = np.where(test_mask)[0]
Contributor
You can use the mask_to_index method from gammagl.utils for this conversion.
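
A minimal sketch of the suggested replacement, assuming `mask_to_index` accepts a boolean mask and returns the corresponding index tensor, as in other GammaGL example trainers:

```python
from gammagl.utils import mask_to_index

# Convert boolean masks to index tensors instead of using np.where.
train_idx = mask_to_index(train_mask)
val_idx = mask_to_index(val_mask)
test_idx = mask_to_index(test_mask)
```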


x_dict = features_dict
y = labels
features = features_dict['paper']

# Define meta-paths (PAP, PFP)
meta_paths = [[('paper', 'pa', 'author'), ('author', 'ap', 'paper')],
[('paper', 'pf', 'field'), ('field', 'fp', 'paper')]]

# Define initial settings for each edge type
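    # 'T' caps how many neighbors the attention purifier keeps for each metapath
    # (per the RoHe paper); 'TransM' is filled in below with the metapath-based
    # transition matrix used to score neighbor reliability.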
settings = {
('paper', 'author', 'paper'): {'T': 3, 'TransM': None},
('paper', 'field', 'paper'): {'T': 5, 'TransM': None},
}

# Prepare adjacency matrices
hete_adjs = {
'pa': edge_index_to_adj_matrix(g['paper', 'pa', 'author'].edge_index, g['paper'].num_nodes, g['author'].num_nodes),
'ap': edge_index_to_adj_matrix(g['author', 'ap', 'paper'].edge_index, g['author'].num_nodes, g['paper'].num_nodes),
'pf': edge_index_to_adj_matrix(g['paper', 'pf', 'field'].edge_index, g['paper'].num_nodes, g['field'].num_nodes),
'fp': edge_index_to_adj_matrix(g['field', 'fp', 'paper'].edge_index, g['field'].num_nodes, g['paper'].num_nodes)
}
meta_g = get_hg(dataname, hete_adjs, features_dict, labels, train_mask, val_mask, test_mask)
# Prepare edge index and node count dictionaries
edge_index_dict = {etype: meta_g[etype].edge_index for etype in meta_g.edge_types}
num_nodes_dict = {ntype: meta_g[ntype].num_nodes for ntype in meta_g.node_types}

# Compute edge transformation matrices
trans_edge_weights_list = get_transition(hete_adjs, meta_paths, edge_index_dict, meta_g.metadata()[1])
for i, edge_type in enumerate(meta_g.metadata()[1]):
settings[edge_type]['TransM'] = trans_edge_weights_list[i]

layer_settings = [settings, settings]

# Initialize the RoheHAN model
model = RoheHAN(
metadata=meta_g.metadata(),
in_channels=features.shape[1],
hidden_size=args.hidden_units,
out_size=num_classes,
num_heads=args.num_heads,
dropout_rate=args.dropout,
settings=layer_settings
)

# Define optimizer and loss function
optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.weight_decay)
loss_func = tlx.losses.softmax_cross_entropy_with_logits
semi_spvz_loss = SemiSpvzLoss(model, loss_func)

# Prepare training components
train_weights = model.trainable_weights
train_one_step = tlx.model.TrainOneStep(semi_spvz_loss, optimizer, train_weights)

# Prepare data dictionary
data = {
"x_dict": x_dict,
"edge_index_dict": edge_index_dict,
"num_nodes_dict": num_nodes_dict,
"train_idx": tlx.convert_to_tensor(train_idx, dtype=tlx.int64),
"val_idx": tlx.convert_to_tensor(val_idx, dtype=tlx.int64),
"test_idx": tlx.convert_to_tensor(test_idx, dtype=tlx.int64),
"y": y
}

# Ensure the best model path exists
if not os.path.exists(args.best_model_path):
os.makedirs(args.best_model_path)
Contributor
The best model weights are normally saved to the ./ directory by default; that path is guaranteed to exist, so these two lines are not needed.


# Training loop
best_val_acc = 0.0

for epoch in range(args.num_epochs):
model.set_train()
# Forward and backward pass
loss = train_one_step(data, y)

# Evaluate on validation set
model.set_eval()
val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, data, y, val_idx, loss_func)

print(f"Epoch {epoch+1} | Train Loss: {loss.item():.4f} | Val Micro-F1: {val_micro_f1:.4f} | Val Macro-F1: {val_macro_f1:.4f}")

# Save best model
if val_acc > best_val_acc:
best_val_acc = val_acc
# Save model weights
model.save_weights(os.path.join(args.best_model_path, 'best_model.npz'), format='npz_dict')

# Load the best model
model.load_weights(os.path.join(args.best_model_path, 'best_model.npz'), format='npz_dict')

# Test the model
test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, data, y, test_idx, loss_func)
print(f"Test Micro-F1: {test_micro_f1:.4f} | Test Macro-F1: {test_macro_f1:.4f}")

# Load target node IDs
print("Loading target nodes")
tar_idx = []
for i in range(1):
with open(f'data/preprocess/target_nodes/{dataname}_r_target{i}.pkl', 'rb') as f:
tar_tmp = np.sort(pkl.load(f))
Contributor
This path depends on where your dataset is saved; it should not be hard-coded.
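
One hedged way to address this, assuming the dataset object exposes the `root` directory it was constructed with (hypothetical here):

```python
# Hypothetical fix: build the path from the dataset root instead of
# hard-coding the "data/" prefix.
target_file = os.path.join(dataset.root, 'preprocess', 'target_nodes',
                           f'{dataname}_r_target{i}.pkl')
with open(target_file, 'rb') as f:
    tar_tmp = np.sort(pkl.load(f))
```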

tar_idx.extend(tar_tmp)
Contributor
This for loop seems to run only once; is the loop necessary?
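
If it only ever reads one file, the loop could be dropped entirely; a minimal sketch of that simplification (combine with the path fix suggested above as needed):

```python
# Hypothetical simplification: read the single target-node file directly.
with open(f'data/preprocess/target_nodes/{dataname}_r_target0.pkl', 'rb') as f:
    tar_idx = list(np.sort(pkl.load(f)))
```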


# Evaluate on target nodes
model.set_eval()
logits_dict = model(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
logits_clean = tlx.gather(logits_dict['paper'], tlx.convert_to_tensor(tar_idx, dtype=tlx.int64))
labels_clean = tlx.gather(y, tlx.convert_to_tensor(tar_idx, dtype=tlx.int64))
_, tar_micro_f1_clean, tar_macro_f1_clean = score(logits_clean, labels_clean)
print(f"Clean data: Micro-F1: {tar_micro_f1_clean:.4f} | Macro-F1: {tar_macro_f1_clean:.4f}")

# Load adversarial attacks
n_perturbation = 1
adv_filename = f'data/generated_attacks/adv_acm_pap_pa_{n_perturbation}.pkl'
with open(adv_filename, 'rb') as f:
modified_opt = pkl.load(f)
Contributor
Same as above: this path should be derived from the dataset location rather than hard-coded.


# Apply adversarial attack
logits_adv_list = []
labels_adv_list = []
for items in modified_opt:
target_node = items[0]
del_list = items[2]
add_list = items[3]
if target_node not in tar_idx:
continue

# Modify adjacency matrices for the attack
mod_hete_adj_dict = {}
for key in hete_adjs.keys():
mod_hete_adj_dict[key] = hete_adjs[key].tolil()

# Delete and add edges
for edge in del_list:
mod_hete_adj_dict['pa'][edge[0], edge[1]] = 0
mod_hete_adj_dict['ap'][edge[1], edge[0]] = 0
for edge in add_list:
mod_hete_adj_dict['pa'][edge[0], edge[1]] = 1
mod_hete_adj_dict['ap'][edge[1], edge[0]] = 1

for key in mod_hete_adj_dict.keys():
mod_hete_adj_dict[key] = mod_hete_adj_dict[key].tocsc()

# Update edge index dictionary for the attack
edge_index_dict_atk = {}
meta_path_atk = [('paper', 'author', 'paper'), ('paper', 'field', 'paper')]
for idx, edge_type in enumerate(meta_path_atk):
# Recompute adjacency matrices for the attack
if edge_type == ('paper', 'author', 'paper'):
adj_matrix = mod_hete_adj_dict['pa'].dot(mod_hete_adj_dict['ap'])
elif edge_type == ('paper', 'field', 'paper'):
adj_matrix = mod_hete_adj_dict['pf'].dot(mod_hete_adj_dict['fp'])
else:
raise KeyError(f"Unknown edge type {edge_type}")

src, dst = adj_matrix.nonzero()
edge_index = np.vstack((src, dst))
edge_index_dict_atk[edge_type] = edge_index

# Update transformation matrices for the attack
trans_edge_weights_list = get_transition(mod_hete_adj_dict, meta_paths, edge_index_dict_atk, meta_path_atk)

for i, edge_type in enumerate(meta_path_atk):
key = '__'.join(edge_type)
if key in model.layer_list[0].gat_layers:
model.layer_list[0].gat_layers[key].settings['TransM'] = trans_edge_weights_list[i]
else:
raise KeyError(f"Edge type key '{key}' not found in gat_layers.")

# Prepare modified graph and data
mod_features_dict = {'paper': features}
g_atk = get_hg(dataname, mod_hete_adj_dict, mod_features_dict, y, train_mask, val_mask, test_mask)
data_atk = {
"x_dict": g_atk.x_dict,
"edge_index_dict": {etype: g_atk[etype].edge_index for etype in g_atk.edge_types},
"num_nodes_dict": {ntype: g_atk[ntype].num_nodes for ntype in g_atk.node_types},
}

# Run the model on the attacked graph
model.set_eval()
        with no_grad():  # inference without gradients; no_grad presumably comes from the utils wildcard import
logits_dict_atk = model(data_atk['x_dict'], data_atk['edge_index_dict'], data_atk['num_nodes_dict'])
logits_atk = logits_dict_atk['paper']
logits_adv = tlx.gather(logits_atk, tlx.convert_to_tensor([target_node], dtype=tlx.int64))
label_adv = tlx.gather(y, tlx.convert_to_tensor([target_node], dtype=tlx.int64))

logits_adv_list.append(logits_adv)
labels_adv_list.append(label_adv)

logits_adv = tlx.concat(logits_adv_list, axis=0)
labels_adv = tlx.concat(labels_adv_list, axis=0)

# Evaluate adversarial attack
_, tar_micro_f1_atk, tar_macro_f1_atk = score(logits_adv, labels_adv)
print(f"Attacked data: Micro-F1: {tar_micro_f1_atk:.4f} | Macro-F1: {tar_macro_f1_atk:.4f}")

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=2, help="Random seed.")
parser.add_argument("--lr", type=float, default=0.005, help="Learning rate.")
parser.add_argument("--num_heads", type=int, default=[8], help="Number of attention heads.")
parser.add_argument("--hidden_units", type=int, default=8, help="Hidden units.")
parser.add_argument("--dropout", type=float, default=0.6, help="Dropout rate.")
parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay.")
parser.add_argument("--num_epochs", type=int, default=100, help="Number of training epochs.")
parser.add_argument("--gpu", type=int, default=0, help="GPU index. Use -1 for CPU.")
parser.add_argument("--best_model_path", type=str, default='./', help="Path to save the best model.")
args = parser.parse_args()

# Setup configuration
args = setup(args)

main(args)