[Model] RoHe #220

Open: wants to merge 18 commits into main

@n1108 commented Sep 30, 2024

Robust Heterogeneous Graph Neural Network (RoHeHAN)

This is an implementation of RoHeHAN, a robust heterogeneous graph neural network designed to defend against adversarial attacks on heterogeneous graphs.

Usage

To reproduce the RoHeHAN results on the ACM dataset, run one of the following commands, depending on the backend:

TL_BACKEND="torch" python rohehan_trainer.py --num_epochs 100 --gpu 0
TL_BACKEND="tensorflow" python rohehan_trainer.py --num_epochs 100 --gpu 0

Performance

Reference performance numbers for the ACM dataset:

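| Backend | Clean (no attack) | Attack (1 perturbation) | Attack (3 perturbations) | Attack (5 perturbations) |
| ------- | ----------------- | ----------------------- | ------------------------ | ------------------------ |
| torch | 0.955 | 0.950 | 0.940 | 0.905 |
| tensorflow | 0.965 | 0.935 | 0.910 | 0.905 |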

ACM dataset link: https://github.com/Jhy1993/HAN/raw/master/data/acm/ACM.mat

Example Commands

You can adjust training settings, such as the number of epochs, learning rate, and dropout rate, as in the following command:

TL_BACKEND="torch" python rohehan_trainer.py --num_epochs 200 --lr 0.005 --dropout 0.6 --gpu 0 --seed 0

Notes

  • The settings in the RoHeGAT layer control the attention purifier mechanism, which ensures robustness against adversarial attacks by pruning unreliable neighbors.

This implementation builds on the idea of using metapath-based transiting probability and attention purification to improve the robustness of heterogeneous graph neural networks (HGNNs).

Original Paper Results

The original paper reports the following performance metrics under clean and adversarial settings:

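| Dataset | Clean (no attack) | Attack (1 perturbation) | Attack (3 perturbations) | Attack (5 perturbations) |
| ------- | ----------------- | ----------------------- | ------------------------ | ------------------------ |
| ACM | 0.920 | 0.904 | 0.902 | 0.882 |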

@gyzhou2000 changed the title from "Han Rohe 牛宇韬" to "[Model] RoHe" on Oct 9, 2024
@gyzhou2000
Contributor

After adding the layers/conv/rohehan_conv.py file, you should also update the __init__.py file under layers/conv so that the new content can be imported. The same applies to models/rohehan.py.

}

# Training loop
best_model_saver = BestModelSaver()
Contributor

This utility class is not needed; just delete this code and the corresponding code in utils.py. In gammagl, the loss or val_acc is normally used to decide whether to update the saved weights; see how the other files do it.
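The pattern the reviewer describes can be sketched as follows: track the best validation accuracy seen so far and save weights only when it improves. This is a minimal sketch under stated assumptions, not the gammagl implementation; `train_one_epoch`, `evaluate`, and `save_weights` are hypothetical stand-ins for the trainer's own calls.

```python
def train_with_best_checkpoint(num_epochs, train_one_epoch, evaluate, save_weights):
    """Run training and save weights whenever validation accuracy improves.

    All three callables are hypothetical placeholders for the trainer's own
    training step, validation step, and weight-saving call.
    """
    best_val_acc = 0.0
    best_epoch = -1
    for epoch in range(num_epochs):
        train_one_epoch(epoch)
        val_acc = evaluate(epoch)
        # Only overwrite the checkpoint on a strict improvement.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            save_weights(epoch)
    return best_val_acc, best_epoch


# Usage with dummy callables standing in for a real model.
val_history = [0.60, 0.72, 0.70, 0.81, 0.79]
saved_epochs = []
best_acc, best_epoch = train_with_best_checkpoint(
    num_epochs=len(val_history),
    train_one_epoch=lambda epoch: None,
    evaluate=lambda epoch: val_history[epoch],
    save_weights=saved_epochs.append,
)
```

Tracking the best value inline like this removes the need for a separate saver class; the same loop works with validation loss by flipping the comparison.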

Comment on lines 120 to 215
def load_acm_raw():
    data_path = 'acm/ACM.mat'  # Path to the dataset
    data = sio.loadmat(data_path)
    p_vs_f = data['PvsL']  # Paper-field adjacency
    p_vs_a = data['PvsA']  # Paper-author adjacency
    p_vs_t = data['PvsT']  # Paper-term feature matrix
    p_vs_c = data['PvsC']  # Paper-conference labels

    # Assigning classes to specific conferences
    conf_ids = [0, 1, 9, 10, 13]
    label_ids = [0, 1, 2, 2, 1]

    # Filter papers with conference labels
    p_vs_c_filter = p_vs_c[:, conf_ids]
    p_selected = np.nonzero(p_vs_c_filter.sum(1))[0]
    p_vs_f = p_vs_f[p_selected]
    p_vs_a = p_vs_a[p_selected]
    p_vs_t = p_vs_t[p_selected]
    p_vs_c = p_vs_c[p_selected]

    # Construct edge indices
    edge_index_pa = np.vstack(p_vs_a.nonzero())
    edge_index_ap = edge_index_pa[[1, 0]]
    edge_index_pf = np.vstack(p_vs_f.nonzero())
    edge_index_fp = edge_index_pf[[1, 0]]

    # Create node features dictionary
    features = tlx.convert_to_tensor(p_vs_t.toarray(), dtype=tlx.float32)
    features_dict = {'paper': features}

    # Process labels
    pc_p, pc_c = p_vs_c.nonzero()
    labels = np.zeros(len(p_selected), dtype=np.int64)
    for conf_id, label_id in zip(conf_ids, label_ids):
        labels[pc_p[pc_c == conf_id]] = label_id
    labels = tlx.convert_to_tensor(labels, dtype=tlx.int64)

    num_classes = 3

    # Create train, val, and test indices
    float_mask = np.zeros(len(pc_p))
    for conf_id in conf_ids:
        pc_c_mask = (pc_c == conf_id)
        float_mask[pc_p[pc_c_mask]] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum()))
    train_idx = np.where(float_mask <= 0.2)[0]
    val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
    test_idx = np.where(float_mask > 0.3)[0]

    num_nodes = features.shape[0]
    train_mask = np.zeros(num_nodes, dtype=bool)
    train_mask[train_idx] = True
    val_mask = np.zeros(num_nodes, dtype=bool)
    val_mask[val_idx] = True
    test_mask = np.zeros(num_nodes, dtype=bool)
    test_mask[test_idx] = True

    # Create the raw heterogeneous graph
    graph = HeteroGraph()
    graph['paper'].x = features_dict['paper']
    graph['paper'].num_nodes = num_nodes
    graph['author'].num_nodes = p_vs_a.shape[1]
    graph['field'].num_nodes = p_vs_f.shape[1]

    # Add edges to the graph
    graph['paper', 'pa', 'author'].edge_index = edge_index_pa
    graph['author', 'ap', 'paper'].edge_index = edge_index_ap
    graph['paper', 'pf', 'field'].edge_index = edge_index_pf
    graph['field', 'fp', 'paper'].edge_index = edge_index_fp

    # Assign labels and masks to paper nodes
    graph['paper'].y = labels
    graph['paper'].train_mask = train_mask
    graph['paper'].val_mask = val_mask
    graph['paper'].test_mask = test_mask

    # Create meta-path graphs (PAP and PFP)
    pap_adj = p_vs_a.dot(p_vs_a.T)
    pap_edge_index = np.vstack(pap_adj.nonzero())
    pfp_adj = p_vs_f.dot(p_vs_f.T)
    pfp_edge_index = np.vstack(pfp_adj.nonzero())

    # Build the meta-path graph
    meta_graph = HeteroGraph()
    meta_graph['paper'].x = features_dict['paper']
    meta_graph['paper'].num_nodes = num_nodes
    meta_graph['paper', 'author', 'paper'].edge_index = pap_edge_index
    meta_graph['paper', 'field', 'paper'].edge_index = pfp_edge_index

    # Add labels and masks to the meta-path graph
    meta_graph['paper'].y = labels
    meta_graph['paper'].train_mask = train_mask
    meta_graph['paper'].val_mask = val_mask
    meta_graph['paper'].test_mask = test_mask

    return graph, meta_graph, features_dict, labels, num_classes, train_idx, val_idx, test_idx, \
        train_mask, val_mask, test_mask
Contributor

This code appears to load the ACM dataset. gammagl already includes an ACM dataset, so please check whether the existing one matches the ACM dataset produced by your processing. If they match, use the existing gammagl dataset directly. If they differ, create a new Python file named acm4rohe.py under gammagl/datasets/ and implement the processing there, so that rohehan_trainer.py can import the dataset directly through that interface.
There is one more issue: dataset files must not be uploaded with the code. Specify the download location of the files in acm4rohe.py instead of uploading them.
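One way to carry out the requested consistency check is to compare the two datasets edge type by edge type, ignoring edge order, and then compare labels. The sketch below is illustrative only: `edge_set` and `graphs_match` are hypothetical helpers, and the inputs are assumed to be plain Python sequences rather than gammagl graph objects.

```python
def edge_set(edge_index):
    """Turn a 2 x E edge index (two equal-length sequences) into a set of (src, dst) pairs."""
    src, dst = edge_index
    return set(zip(src, dst))


def graphs_match(edges_a, edges_b, labels_a, labels_b):
    """Return True when both graphs have the same edges per type and the same labels.

    edges_a / edges_b: dict mapping an edge-type name to a 2 x E edge index.
    Edge order is ignored; label order is not.
    """
    if set(edges_a) != set(edges_b):
        return False
    for edge_type in edges_a:
        if edge_set(edges_a[edge_type]) != edge_set(edges_b[edge_type]):
            return False
    return list(labels_a) == list(labels_b)


# Toy data: same edges in a different order, identical labels.
existing = {"pa": [[0, 1, 2], [5, 6, 7]]}
processed = {"pa": [[2, 0, 1], [7, 5, 6]]}
same = graphs_match(existing, processed, [0, 1, 2], [0, 1, 2])
different = graphs_match(existing, {"pa": [[0, 1], [5, 6]]}, [0, 1, 2], [0, 1, 2])
```

A fuller comparison would also cover node features and the train/val/test split, which this sketch leaves out.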

Comment on lines 92 to 117
def setup_log_dir(args, sampling=False):
    date_postfix = get_date_postfix()
    log_dir = os.path.join(args.log_dir, f'{args.dataset}_{date_postfix}')

    if sampling:
        log_dir += '_sampling'

    mkdir_p(log_dir)
    return log_dir

# Set random seed and device settings based on input arguments
def setup(args):
    tlx.set_seed(args.seed)
    args.dataset = 'ACMRaw'
    args.log_dir = setup_log_dir(args)
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")
    return args

# Create a binary mask from a list of indices
def get_binary_mask(total_size, indices):
    mask = np.zeros(total_size, dtype=bool)
    mask[indices] = True
    return mask
Contributor

When coding, check whether every function in utils.py is really necessary and whether anything can be trimmed. As far as possible, avoid copying the original authors' code wholesale.

Comment on lines 121 to 124
for i in range(1):
    with open(f'data/preprocess/target_nodes/{dataname}_r_target{i}.pkl', 'rb') as f:
        tar_tmp = np.sort(pkl.load(f))
    tar_idx.extend(tar_tmp)
Contributor

This for loop appears to run only once. Is the loop necessary?
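Without the single-iteration loop, the target-node file can be loaded directly. The sketch below is a minimal illustration, not the PR's final code: `load_target_nodes` is a hypothetical helper, and the built-in `sorted` stands in for `np.sort` so the example needs only the standard library.

```python
import os
import pickle as pkl
import tempfile


def load_target_nodes(path):
    """Load one pickled collection of target node ids and return them sorted."""
    with open(path, 'rb') as f:
        return sorted(pkl.load(f))


# Usage: write a small pickle to a temporary directory, then load it back.
tmp_dir = tempfile.mkdtemp()
target_path = os.path.join(tmp_dir, 'acm_r_target0.pkl')
with open(target_path, 'wb') as f:
    pkl.dump([42, 7, 19], f)

tar_idx = load_target_nodes(target_path)
```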

Comment on lines 17 to 23
## Performance

Reference performance numbers for the ACM dataset:

| Dataset | Clean (no attack) | Attack(1 perturbation) | Attack(3 perturbations) | Attack(5 perturbations) |
| ------- | ----------------- | ---------------------- | ----------------------- | ----------------------- |
| ACM | 0.930 | 0.915 | 0.905 | 0.895 |
Contributor

The readme.md file needs to state the results of our implementation under each backend, and also give the results reported in the authors' paper.

Comment on lines 24 to 34
def evaluate(model, data, labels, mask, loss_func):
    model.set_eval()
    logits = model(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
    logits = logits['paper']  # Focus evaluation on 'paper' nodes
    mask_indices = mask  # Assuming mask is an array of indices
    logits_masked = tlx.gather(logits, tlx.convert_to_tensor(mask_indices, dtype=tlx.int64))
    labels_masked = tlx.gather(labels, tlx.convert_to_tensor(mask_indices, dtype=tlx.int64))
    loss = loss_func(logits_masked, labels_masked)

    accuracy, micro_f1, macro_f1 = score(logits_masked, labels_masked)
    return loss, accuracy, micro_f1, macro_f1
Contributor

The evaluate function can be moved into rohehan_trainer.py.

Comment on lines 37 to 52
def get_hg(dataname, given_adj_dict, features_dict, labels=None, train_mask=None, val_mask=None, test_mask=None):
    meta_graph = HeteroGraph()
    meta_graph['paper'].x = features_dict['paper']
    meta_graph['paper'].num_nodes = features_dict['paper'].shape[0]

    # Add meta-path-based edges
    meta_graph['paper', 'author', 'paper'].edge_index = np.array(given_adj_dict['pa'].dot(given_adj_dict['ap']).nonzero())
    meta_graph['paper', 'field', 'paper'].edge_index = np.array(given_adj_dict['pf'].dot(given_adj_dict['fp']).nonzero())

    # Add labels and masks
    meta_graph['paper'].y = labels
    meta_graph['paper'].train_mask = train_mask
    meta_graph['paper'].val_mask = val_mask
    meta_graph['paper'].test_mask = test_mask

    return meta_graph
Contributor

The logic of this function should go in datasets/acm4rohe.py.

Comment on lines 118 to 143
def download_attack_data_files():
    """Download necessary ACM data files from GitHub if they are missing."""
    # Define the base URL for raw files in the repository
    base_url = "https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/main/Code/data"

    # Files to download
    files_to_download = [
        "generated_attacks/adv_acm_pap_pa_1.pkl",
        "generated_attacks/adv_acm_pap_pa_3.pkl",
        "generated_attacks/adv_acm_pap_pa_5.pkl",
        "preprocess/target_nodes/acm_r_target0.pkl",
        "preprocess/target_nodes/acm_r_target1.pkl",
        "preprocess/target_nodes/acm_r_target2.pkl",
        "preprocess/target_nodes/acm_r_target3.pkl",
        "preprocess/target_nodes/acm_r_target4.pkl"
    ]

    # Download each file if it does not exist locally
    for file_path in files_to_download:
        # Construct the full URL and local save path
        file_url = f"{base_url}/{file_path}"
        save_folder = os.path.join("data", os.path.dirname(file_path))

        # Ensure directories exist and download the file
        if not os.path.exists(os.path.join(save_folder, os.path.basename(file_path))):
            download_url(file_url, save_folder)
Contributor

The logic of this function should go in datasets/acm4rohe.py.

Comment on lines 55 to 61
def setup(args):
    tlx.set_seed(args.seed)
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")
    return args
Contributor

This logic should go in rohehan_trainer.py; see how the other trainers are written.

Comment on lines 94 to 100
def to_item(tensor):
    if tlx.BACKEND == 'torch':
        return tensor.item()
    elif tlx.BACKEND == 'tensorflow':
        return tensor.numpy().item()
    else:
        raise NotImplementedError(f"Unsupported backend: {tlx.BACKEND}")
Contributor

There is no need to write this function separately; the tlx.convert2numpy interface provided by the tensorlayerx framework is enough.

Comment on lines 87 to 91
def edge_index_to_adj_matrix(edge_index, num_src_nodes, num_dst_nodes):
    src, dst = edge_index
    data = np.ones(src.shape[0])
    adj_matrix = sp.csc_matrix((data, (src, dst)), shape=(num_src_nodes, num_dst_nodes))
    return adj_matrix
Contributor

This function can go under gammagl/utils/ as a general-purpose utility.

Comment on lines 9 to 10
class ACM4Rohe(InMemoryDataset):
    url = "https://github.com/Jhy1993/HAN/raw/master/data/acm/ACM.mat"
Contributor

rst documentation needs to be written for this class.

dataname = 'acm'
dataset = ACM4Rohe(root = "./")
g = dataset[0]
dataset.download_attack_data_files()
Contributor

There is no need to call the download interface at the call site. In the dataset class, you can implement the download method to handle downloading the dataset.
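The idea behind this comment can be sketched with a small stand-alone class: the constructor checks for missing raw files and calls `download()` itself, so callers never trigger the download explicitly. `LazyDownloadDataset` and `ToyAttackDataset` are hypothetical classes written for illustration and are not gammagl's `InMemoryDataset`; the "download" here only creates placeholder files.

```python
import os
import tempfile


class LazyDownloadDataset:
    """Hypothetical base class: fetch raw files on construction if any are missing."""

    raw_file_names = []

    def __init__(self, root):
        self.raw_dir = os.path.join(root, 'raw')
        os.makedirs(self.raw_dir, exist_ok=True)
        missing = [name for name in self.raw_file_names
                   if not os.path.exists(os.path.join(self.raw_dir, name))]
        if missing:
            self.download()

    def download(self):
        raise NotImplementedError


class ToyAttackDataset(LazyDownloadDataset):
    raw_file_names = ['adv_acm_pap_pa_1.pkl', 'acm_r_target0.pkl']
    download_calls = 0  # counts how often download() actually ran

    def download(self):
        # Stand-in for a real HTTP download: create empty placeholder files.
        type(self).download_calls += 1
        for name in self.raw_file_names:
            open(os.path.join(self.raw_dir, name), 'wb').close()


root = tempfile.mkdtemp()
first = ToyAttackDataset(root)   # raw files missing, so download() runs
second = ToyAttackDataset(root)  # raw files present, so download() is skipped
```

With this structure the trainer only constructs the dataset; the attack files and target-node files would simply be listed among the raw files.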

Comment on lines 58 to 60
train_idx = np.where(train_mask)[0]
val_idx = np.where(val_mask)[0]
test_idx = np.where(test_mask)[0]
Contributor

You can use the mask_to_index method from gammagl.utils for this conversion.

Comment on lines 127 to 128
if not os.path.exists(args.best_model_path):
    os.makedirs(args.best_model_path)
Contributor

The best model weights are normally saved to ./ by default, and that path always exists, so these two lines are unnecessary.

Comment on lines 162 to 163
with open(f'data/preprocess/target_nodes/{dataname}_r_target{i}.pkl', 'rb') as f:
    tar_tmp = np.sort(pkl.load(f))
Contributor

This path should depend on where your dataset is saved; it must not be hard-coded.

Comment on lines 176 to 178
adv_filename = f'data/generated_attacks/adv_acm_pap_pa_{n_perturbation}.pkl'
with open(adv_filename, 'rb') as f:
    modified_opt = pkl.load(f)
Contributor

Same as above.
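Both path comments can be addressed by building file paths from the directory where the dataset stores its raw files. The sketch below assumes that directory is passed in as `raw_dir` and that the subfolder layout mirrors the snippets under review; the helper names are hypothetical.

```python
import os


def attack_file_path(raw_dir, n_perturbation):
    """Path of the pickled adversarial edges for a given perturbation budget."""
    return os.path.join(raw_dir, 'generated_attacks',
                        f'adv_acm_pap_pa_{n_perturbation}.pkl')


def target_file_path(raw_dir, dataname, index):
    """Path of the pickled target-node list for a given dataset and file index."""
    return os.path.join(raw_dir, 'preprocess', 'target_nodes',
                        f'{dataname}_r_target{index}.pkl')


# Usage: the same helpers work for any dataset root.
example_attack = attack_file_path(os.path.join('my_root', 'raw'), 3)
example_target = target_file_path(os.path.join('my_root', 'raw'), 'acm', 0)
```

Using `os.path.join` also keeps the paths valid across operating systems, which a hard-coded string with forward slashes does not guarantee.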

3 participants