[Model] RoHe #220
After adding
examples/rohehan/rohehan_trainer.py
Outdated
    }

    # Training loop
    best_model_saver = BestModelSaver()
This utility class is not needed; just remove this code and the corresponding code in utils.py. In gammagl, loss or val_acc is generally used to decide whether to update the saved weights; see how other files do it.
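A minimal sketch of the pattern suggested above, assuming the trainer tracks validation accuracy itself rather than using a saver class. `train_step`, `evaluate`, and `save_weights` are hypothetical callables standing in for the trainer's own code; they are not gammagl APIs, and the accuracies are made up for illustration.

```python
def train_with_best_checkpoint(num_epochs, train_step, evaluate, save_weights):
    """Run training and checkpoint whenever validation accuracy improves.

    All three callables are placeholders (assumptions) for the trainer's logic.
    """
    best_val_acc = float("-inf")
    best_epoch = -1
    for epoch in range(num_epochs):
        train_step(epoch)
        val_acc = evaluate(epoch)
        # Save only on a strict improvement, so the earliest best epoch wins ties.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            save_weights(epoch)
    return best_epoch, best_val_acc


# Usage with stand-in callables and made-up accuracies.
fake_val_accs = [0.50, 0.72, 0.70, 0.72]
saved_epochs = []
best_epoch, best_acc = train_with_best_checkpoint(
    num_epochs=len(fake_val_accs),
    train_step=lambda epoch: None,
    evaluate=lambda epoch: fake_val_accs[epoch],
    save_weights=saved_epochs.append,
)
print(best_epoch, best_acc, saved_epochs)
```

In a real trainer, `save_weights` would presumably wrap the model's own weight-saving call, and the same comparison could be made on validation loss with the inequality reversed.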
examples/rohehan/utils.py
Outdated
def load_acm_raw():
    data_path = 'acm/ACM.mat'  # Path to the dataset
    data = sio.loadmat(data_path)
    p_vs_f = data['PvsL']  # Paper-field adjacency
    p_vs_a = data['PvsA']  # Paper-author adjacency
    p_vs_t = data['PvsT']  # Paper-term feature matrix
    p_vs_c = data['PvsC']  # Paper-conference labels

    # Assigning classes to specific conferences
    conf_ids = [0, 1, 9, 10, 13]
    label_ids = [0, 1, 2, 2, 1]

    # Filter papers with conference labels
    p_vs_c_filter = p_vs_c[:, conf_ids]
    p_selected = np.nonzero(p_vs_c_filter.sum(1))[0]
    p_vs_f = p_vs_f[p_selected]
    p_vs_a = p_vs_a[p_selected]
    p_vs_t = p_vs_t[p_selected]
    p_vs_c = p_vs_c[p_selected]

    # Construct edge indices
    edge_index_pa = np.vstack(p_vs_a.nonzero())
    edge_index_ap = edge_index_pa[[1, 0]]
    edge_index_pf = np.vstack(p_vs_f.nonzero())
    edge_index_fp = edge_index_pf[[1, 0]]

    # Create node features dictionary
    features = tlx.convert_to_tensor(p_vs_t.toarray(), dtype=tlx.float32)
    features_dict = {'paper': features}

    # Process labels
    pc_p, pc_c = p_vs_c.nonzero()
    labels = np.zeros(len(p_selected), dtype=np.int64)
    for conf_id, label_id in zip(conf_ids, label_ids):
        labels[pc_p[pc_c == conf_id]] = label_id
    labels = tlx.convert_to_tensor(labels, dtype=tlx.int64)

    num_classes = 3

    # Create train, val, and test indices
    float_mask = np.zeros(len(pc_p))
    for conf_id in conf_ids:
        pc_c_mask = (pc_c == conf_id)
        float_mask[pc_p[pc_c_mask]] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum()))
    train_idx = np.where(float_mask <= 0.2)[0]
    val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
    test_idx = np.where(float_mask > 0.3)[0]

    num_nodes = features.shape[0]
    train_mask = np.zeros(num_nodes, dtype=bool)
    train_mask[train_idx] = True
    val_mask = np.zeros(num_nodes, dtype=bool)
    val_mask[val_idx] = True
    test_mask = np.zeros(num_nodes, dtype=bool)
    test_mask[test_idx] = True

    # Create the raw heterogeneous graph
    graph = HeteroGraph()
    graph['paper'].x = features_dict['paper']
    graph['paper'].num_nodes = num_nodes
    graph['author'].num_nodes = p_vs_a.shape[1]
    graph['field'].num_nodes = p_vs_f.shape[1]

    # Add edges to the graph
    graph['paper', 'pa', 'author'].edge_index = edge_index_pa
    graph['author', 'ap', 'paper'].edge_index = edge_index_ap
    graph['paper', 'pf', 'field'].edge_index = edge_index_pf
    graph['field', 'fp', 'paper'].edge_index = edge_index_fp

    # Assign labels and masks to paper nodes
    graph['paper'].y = labels
    graph['paper'].train_mask = train_mask
    graph['paper'].val_mask = val_mask
    graph['paper'].test_mask = test_mask

    # Create meta-path graphs (PAP and PFP)
    pap_adj = p_vs_a.dot(p_vs_a.T)
    pap_edge_index = np.vstack(pap_adj.nonzero())
    pfp_adj = p_vs_f.dot(p_vs_f.T)
    pfp_edge_index = np.vstack(pfp_adj.nonzero())

    # Build the meta-path graph
    meta_graph = HeteroGraph()
    meta_graph['paper'].x = features_dict['paper']
    meta_graph['paper'].num_nodes = num_nodes
    meta_graph['paper', 'author', 'paper'].edge_index = pap_edge_index
    meta_graph['paper', 'field', 'paper'].edge_index = pfp_edge_index

    # Add labels and masks to the meta-path graph
    meta_graph['paper'].y = labels
    meta_graph['paper'].train_mask = train_mask
    meta_graph['paper'].val_mask = val_mask
    meta_graph['paper'].test_mask = test_mask

    return graph, meta_graph, features_dict, labels, num_classes, train_idx, val_idx, test_idx, \
        train_mask, val_mask, test_mask
This code appears to load the ACM dataset. gammagl already includes an ACM dataset, so please check whether it matches the ACM dataset your processing produces. If they match, use the existing gammagl ACM dataset directly. If they differ, create a new Python file named acm4rohe.py under gammagl/datasets/ and write the corresponding code following the usual workflow, so that the dataset can be imported through the interface directly in rohehan_trainer.py.
One more issue: dataset files must not be uploaded with the code. Specify the download path for the files in acm4rohe.py instead of uploading the files directly.
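One way to carry out the consistency check requested above is to compare the edge sets of the two graphs independent of edge order. The sketch below is only an illustration with toy data: `edge_set` and `compare_edge_indices` are hypothetical helper names, and the check assumes both datasets number their nodes identically (features, labels, and masks would need a similar comparison).

```python
def edge_set(edge_index):
    """Turn a 2 x E edge index (rows: source, target) into a set of pairs."""
    src, dst = edge_index
    return {(int(s), int(d)) for s, d in zip(src, dst)}


def compare_edge_indices(edge_index_a, edge_index_b):
    """Return (only_in_a, only_in_b); both empty means the edge sets match."""
    a, b = edge_set(edge_index_a), edge_set(edge_index_b)
    return a - b, b - a


# Toy example: same edges in a different order, plus one extra edge in the second.
ours = [[0, 1, 2], [1, 2, 0]]
theirs = [[2, 0, 1, 2], [0, 1, 2, 1]]
only_ours, only_theirs = compare_edge_indices(ours, theirs)
print(only_ours, only_theirs)
```

The helpers accept plain lists or NumPy arrays, since they only iterate over the two rows.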
examples/rohehan/utils.py
Outdated
def setup_log_dir(args, sampling=False):
    date_postfix = get_date_postfix()
    log_dir = os.path.join(args.log_dir, f'{args.dataset}_{date_postfix}')

    if sampling:
        log_dir += '_sampling'

    mkdir_p(log_dir)
    return log_dir

# Set random seed and device settings based on input arguments
def setup(args):
    tlx.set_seed(args.seed)
    args.dataset = 'ACMRaw'
    args.log_dir = setup_log_dir(args)
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")
    return args

# Create a binary mask from a list of indices
def get_binary_mask(total_size, indices):
    mask = np.zeros(total_size, dtype=bool)
    mask[indices] = True
    return mask
When coding, check whether every function in utils.py is really necessary and whether anything can be trimmed; try to avoid copying the original authors' code wholesale.
examples/rohehan/rohehan_trainer.py
Outdated
for i in range(1):
    with open(f'data/preprocess/target_nodes/{dataname}_r_target{i}.pkl', 'rb') as f:
        tar_tmp = np.sort(pkl.load(f))
    tar_idx.extend(tar_tmp)
This for loop seems to run only once; is the loop really necessary?
examples/rohehan/readme.md
Outdated
## Performance

Reference performance numbers for the ACM dataset:

| Dataset | Clean (no attack) | Attack (1 perturbation) | Attack (3 perturbations) | Attack (5 perturbations) |
| ------- | ----------------- | ----------------------- | ------------------------ | ------------------------ |
| ACM     | 0.930             | 0.915                   | 0.905                    | 0.895                    |
readme.md needs to state the results of our implementation under each backend, and also give the results reported in the authors' paper.
examples/rohehan/utils.py
Outdated
def evaluate(model, data, labels, mask, loss_func):
    model.set_eval()
    logits = model(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
    logits = logits['paper']  # Focus evaluation on 'paper' nodes
    mask_indices = mask  # Assuming mask is an array of indices
    logits_masked = tlx.gather(logits, tlx.convert_to_tensor(mask_indices, dtype=tlx.int64))
    labels_masked = tlx.gather(labels, tlx.convert_to_tensor(mask_indices, dtype=tlx.int64))
    loss = loss_func(logits_masked, labels_masked)

    accuracy, micro_f1, macro_f1 = score(logits_masked, labels_masked)
    return loss, accuracy, micro_f1, macro_f1
The evaluate function can be moved into rohehan_trainer.py.
examples/rohehan/utils.py
Outdated
def get_hg(dataname, given_adj_dict, features_dict, labels=None, train_mask=None, val_mask=None, test_mask=None):
    meta_graph = HeteroGraph()
    meta_graph['paper'].x = features_dict['paper']
    meta_graph['paper'].num_nodes = features_dict['paper'].shape[0]

    # Add meta-path-based edges
    meta_graph['paper', 'author', 'paper'].edge_index = np.array(given_adj_dict['pa'].dot(given_adj_dict['ap']).nonzero())
    meta_graph['paper', 'field', 'paper'].edge_index = np.array(given_adj_dict['pf'].dot(given_adj_dict['fp']).nonzero())

    # Add labels and masks
    meta_graph['paper'].y = labels
    meta_graph['paper'].train_mask = train_mask
    meta_graph['paper'].val_mask = val_mask
    meta_graph['paper'].test_mask = test_mask

    return meta_graph
The logic of this function should go in datasets/acm4rohe.py.
examples/rohehan/utils.py
Outdated
def download_attack_data_files():
    """Download necessary ACM data files from GitHub if they are missing."""
    # Define the base URL for raw files in the repository
    base_url = "https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/main/Code/data"

    # Files to download
    files_to_download = [
        "generated_attacks/adv_acm_pap_pa_1.pkl",
        "generated_attacks/adv_acm_pap_pa_3.pkl",
        "generated_attacks/adv_acm_pap_pa_5.pkl",
        "preprocess/target_nodes/acm_r_target0.pkl",
        "preprocess/target_nodes/acm_r_target1.pkl",
        "preprocess/target_nodes/acm_r_target2.pkl",
        "preprocess/target_nodes/acm_r_target3.pkl",
        "preprocess/target_nodes/acm_r_target4.pkl"
    ]

    # Download each file if it does not exist locally
    for file_path in files_to_download:
        # Construct the full URL and local save path
        file_url = f"{base_url}/{file_path}"
        save_folder = os.path.join("data", os.path.dirname(file_path))

        # Ensure directories exist and download the file
        if not os.path.exists(os.path.join(save_folder, os.path.basename(file_path))):
            download_url(file_url, save_folder)
The logic of this function should go in datasets/acm4rohe.py.
examples/rohehan/utils.py
Outdated
def setup(args):
    tlx.set_seed(args.seed)
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")
    return args
This logic should go in rohehan_trainer.py; see how the other trainers do it.
examples/rohehan/utils.py
Outdated
def to_item(tensor):
    if tlx.BACKEND == 'torch':
        return tensor.item()
    elif tlx.BACKEND == 'tensorflow':
        return tensor.numpy().item()
    else:
        raise NotImplementedError(f"Unsupported backend: {tlx.BACKEND}")
There is actually no need to write this function separately; the tlx.convert2numpy interface provided by the tensorlayerx framework is enough.
examples/rohehan/utils.py
Outdated
def edge_index_to_adj_matrix(edge_index, num_src_nodes, num_dst_nodes):
    src, dst = edge_index
    data = np.ones(src.shape[0])
    adj_matrix = sp.csc_matrix((data, (src, dst)), shape=(num_src_nodes, num_dst_nodes))
    return adj_matrix
This function can go under gammagl/utils/ as a general-purpose utility.
gammagl/datasets/acm4rohe.py
Outdated
class ACM4Rohe(InMemoryDataset):
    url = "https://github.com/Jhy1993/HAN/raw/master/data/acm/ACM.mat"
An rst document needs to be written for this.
examples/rohehan/rohehan_trainer.py
Outdated
dataname = 'acm'
dataset = ACM4Rohe(root = "./")
g = dataset[0]
dataset.download_attack_data_files()
The download interface does not need to be called at the call site; in the dataset class, you can handle downloading the dataset by implementing the download method.
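The idea in this comment can be sketched with a toy base class that calls `download()` on construction only when raw files are missing. `ToyDataset` and `ToyACM` are hypothetical stand-ins written for illustration, not gammagl's `InMemoryDataset`; the real base class may differ in its hooks and directory layout.

```python
import os
import tempfile


class ToyDataset:
    """Toy stand-in for a dataset base class (not gammagl's InMemoryDataset)."""

    raw_file_names = []

    def __init__(self, root):
        self.raw_dir = os.path.join(root, "raw")
        os.makedirs(self.raw_dir, exist_ok=True)
        # Download only when at least one expected raw file is missing.
        if not all(os.path.exists(p) for p in self.raw_paths):
            self.download()

    @property
    def raw_paths(self):
        return [os.path.join(self.raw_dir, name) for name in self.raw_file_names]

    def download(self):
        raise NotImplementedError


class ToyACM(ToyDataset):
    raw_file_names = ["ACM.mat"]
    download_calls = 0

    def download(self):
        # A real implementation would fetch from the dataset URL here;
        # this sketch writes a placeholder file so it runs offline.
        ToyACM.download_calls += 1
        for path in self.raw_paths:
            with open(path, "wb") as f:
                f.write(b"placeholder")


root = tempfile.mkdtemp()
ToyACM(root)  # first construction triggers download()
ToyACM(root)  # files already present, so download() is skipped
print(ToyACM.download_calls)
```

With this structure the trainer only constructs the dataset; the attack files and target-node files could be listed among the raw files so they are fetched by the same method.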
examples/rohehan/rohehan_trainer.py
Outdated
train_idx = np.where(train_mask)[0]
val_idx = np.where(val_mask)[0]
test_idx = np.where(test_mask)[0]
You can use the mask_to_index method under gammagl.utils for this conversion.
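For reference, the conversion being discussed amounts to collecting the positions where a boolean mask is true. The function below is a plain-Python illustration with a hypothetical name; it does not reproduce the signature or return type of the gammagl utility.

```python
def mask_to_index_sketch(mask):
    """Return the positions of the truthy entries in a boolean mask."""
    return [i for i, flag in enumerate(mask) if flag]


train_mask = [True, False, False, True, True]
print(mask_to_index_sketch(train_mask))
```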
examples/rohehan/rohehan_trainer.py
Outdated
if not os.path.exists(args.best_model_path):
    os.makedirs(args.best_model_path)
The best model weights are generally saved to ./ by default. That path certainly exists, so these two lines are not needed.
examples/rohehan/rohehan_trainer.py
Outdated
with open(f'data/preprocess/target_nodes/{dataname}_r_target{i}.pkl', 'rb') as f:
    tar_tmp = np.sort(pkl.load(f))
This path should depend on where your dataset is saved; it must not be hard-coded.
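A small sketch of deriving these paths from the dataset's storage location instead of a fixed `data/` prefix. The helper names and the `root` argument are assumptions made for illustration (in practice `root` would come from the dataset object); the subdirectory and file names are taken from the snippets in this review.

```python
import os


def target_nodes_path(root, dataname, split_id=0):
    """Path of a target-node file under the dataset's storage directory."""
    return os.path.join(root, "preprocess", "target_nodes",
                        f"{dataname}_r_target{split_id}.pkl")


def attack_path(root, n_perturbation):
    """Path of a generated-attack file for the given perturbation count."""
    return os.path.join(root, "generated_attacks",
                        f"adv_acm_pap_pa_{n_perturbation}.pkl")


# `root` is a placeholder here; any directory chosen by the user works.
root = os.path.join("some", "dataset", "dir")
print(target_nodes_path(root, "acm"))
print(attack_path(root, 3))
```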
examples/rohehan/rohehan_trainer.py
Outdated
adv_filename = f'data/generated_attacks/adv_acm_pap_pa_{n_perturbation}.pkl'
with open(adv_filename, 'rb') as f:
    modified_opt = pkl.load(f)
Same as above.
# Robust Heterogeneous Graph Neural Network (RoHeHAN)

This is an implementation of RoHeHAN, a robust heterogeneous graph neural network designed to defend against adversarial attacks on heterogeneous graphs. It uses the tensorlayerx and gammagl libraries.

## Usage

To reproduce the RoHeHAN results on the ACM dataset, run the following command:

## Performance

Reference performance numbers for the ACM dataset:

ACM dataset link: https://github.com/Jhy1993/HAN/raw/master/data/acm/ACM.mat

## Example Commands

You can adjust training settings, such as the number of epochs, learning rate, and dropout rate, with the following commands:

TL_BACKEND="torch" python rohehan_trainer.py --num_epochs 200 --lr 0.005 --dropout 0.6 --gpu 0 --seed 0

## Notes

- Settings in the RoHeGAT layer control the attention purifier mechanism, which ensures robustness against adversarial attacks by pruning unreliable neighbors.
- This implementation builds on the idea of using metapath-based transiting probability and attention purification to improve the robustness of heterogeneous graph neural networks (HGNNs).
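The purifier described in the notes can be pictured roughly as follows: each neighbor's attention score is weighted by a transiting-probability prior, only the top-T neighbors are kept, and the surviving scores are normalized. This is a simplified, single-node illustration under those assumptions, not the code from this PR; `purify_attention`, `priors`, and `top_t` are hypothetical names, and the inputs are assumed to be non-empty with `top_t >= 1`.

```python
import math


def purify_attention(scores, priors, top_t):
    """Keep the top_t neighbors by prior-weighted score and softmax over them.

    scores: raw attention values, one per neighbor
    priors: transiting-probability prior for each neighbor
    Pruned neighbors receive weight 0.
    """
    confidence = [s * p for s, p in zip(scores, priors)]
    order = sorted(range(len(confidence)), key=lambda i: confidence[i], reverse=True)
    keep = set(order[:top_t])
    # Subtract the maximum kept value for a numerically stable softmax.
    max_c = max(confidence[i] for i in keep)
    exp = [math.exp(c - max_c) if i in keep else 0.0
           for i, c in enumerate(confidence)]
    total = sum(exp)
    return [e / total for e in exp]


# The third neighbor has a high raw score but a zero prior, so it is pruned.
weights = purify_attention(scores=[1.0, 2.0, 3.0], priors=[0.5, 0.5, 0.0], top_t=2)
print(weights)
```

The point of the example is that a neighbor injected by an attacker can score highly on raw attention yet be discarded because its prior is low.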
## Original Paper Results

The original paper reports the following performance metrics under clean and adversarial settings: