diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3c211cfe..327f447e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,10 +21,10 @@ repos: # Check for spelling - repo: https://github.com/codespell-project/codespell - rev: v2.2.6 + rev: v2.3.0 hooks: - id: codespell - args: ["-L", "fpr, tre, sav, provid"] + args: ["-L", "fpr, tre, sav, provid, assertIn"] exclude: > (?x)^( .*\.svg| @@ -53,7 +53,7 @@ repos: # Black format Python and notebooks - repo: https://github.com/psf/black - rev: 24.4.0 + rev: 24.4.2 hooks: - id: black-jupyter diff --git a/aisdc/attacks/likelihood_attack.py b/aisdc/attacks/likelihood_attack.py index d0bdae17..29af0853 100644 --- a/aisdc/attacks/likelihood_attack.py +++ b/aisdc/attacks/likelihood_attack.py @@ -564,13 +564,15 @@ def attack_from_config(self) -> None: # pylint: disable = too-many-locals logger.info("Loading test predictions form %s", self.test_preds_filename) test_preds = np.loadtxt(self.test_preds_filename, delimiter=",") assert len(test_preds) == len(test_X) - if self.target_model is not None: - clf_module_name, clf_class_name = self.target_model - module = importlib.import_module(clf_module_name) - clf_class = getattr(module, clf_class_name) - if self.target_model_hyp is not None: - clf_params = self.target_model_hyp - clf = clf_class(**clf_params) + if self.target_model is None: + raise ValueError("Target model cannot be None") + if self.target_model_hyp is None: + raise ValueError("Target model hyperparameters cannot be None") + clf_module_name, clf_class_name = self.target_model + module = importlib.import_module(clf_module_name) + clf_class = getattr(module, clf_class_name) + clf_params = self.target_model_hyp + clf = clf_class(**clf_params) logger.info("Created model: %s", str(clf)) self.run_scenario_from_preds( clf, train_X, train_y, train_preds, test_X, test_y, test_preds