From feb7e4d952d52b74b4c81ebe837c5f4aa836437b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 May 2024 22:09:45 +0100 Subject: [PATCH] pre-commit autoupdate (#273) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit autoupdate updates: - [github.com/codespell-project/codespell: v2.2.6 → v2.3.0](https://github.com/codespell-project/codespell/compare/v2.2.6...v2.3.0) - [github.com/psf/black: 24.4.0 → 24.4.2](https://github.com/psf/black/compare/24.4.0...24.4.2) * fix lira attack from config error checking * ignore assertIn spelling --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Richard Preen --- .pre-commit-config.yaml | 6 +++--- aisdc/attacks/likelihood_attack.py | 16 +++++++++------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3c211cfe..327f447e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,10 +21,10 @@ repos: # Check for spelling - repo: https://github.com/codespell-project/codespell - rev: v2.2.6 + rev: v2.3.0 hooks: - id: codespell - args: ["-L", "fpr, tre, sav, provid"] + args: ["-L", "fpr, tre, sav, provid, assertIn"] exclude: > (?x)^( .*\.svg| @@ -53,7 +53,7 @@ repos: # Black format Python and notebooks - repo: https://github.com/psf/black - rev: 24.4.0 + rev: 24.4.2 hooks: - id: black-jupyter diff --git a/aisdc/attacks/likelihood_attack.py b/aisdc/attacks/likelihood_attack.py index d0bdae17..29af0853 100644 --- a/aisdc/attacks/likelihood_attack.py +++ b/aisdc/attacks/likelihood_attack.py @@ -564,13 +564,15 @@ def attack_from_config(self) -> None: # pylint: disable = too-many-locals logger.info("Loading test predictions form %s", self.test_preds_filename) test_preds = np.loadtxt(self.test_preds_filename, delimiter=",") assert len(test_preds) == len(test_X) - if self.target_model is not None: - clf_module_name, clf_class_name = self.target_model - module = importlib.import_module(clf_module_name) - clf_class = getattr(module, clf_class_name) - if self.target_model_hyp is not None: - clf_params = self.target_model_hyp - clf = clf_class(**clf_params) + if self.target_model is None: + raise ValueError("Target model cannot be None") + if self.target_model_hyp is None: + raise ValueError("Target model hyperparameters cannot be None") + clf_module_name, clf_class_name = self.target_model + module = importlib.import_module(clf_module_name) + clf_class = getattr(module, clf_class_name) + clf_params = self.target_model_hyp + clf = clf_class(**clf_params) logger.info("Created model: %s", str(clf)) self.run_scenario_from_preds( clf, train_X, train_y, train_preds, test_X, test_y, test_preds