Skip to content

Commit

Permalink
pre-commit autoupdate (#273)
Browse files Browse the repository at this point in the history
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/codespell-project/codespell: v2.2.6 → v2.3.0](codespell-project/codespell@v2.2.6...v2.3.0)
- [github.com/psf/black: 24.4.0 → 24.4.2](psf/black@24.4.0...24.4.2)

* fix lira attack from config error checking

* ignore assertIn spelling

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Richard Preen <[email protected]>
  • Loading branch information
pre-commit-ci[bot] and rpreen authored May 30, 2024
1 parent a0b0a90 commit feb7e4d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ repos:

# Check for spelling
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
rev: v2.3.0
hooks:
- id: codespell
args: ["-L", "fpr, tre, sav, provid"]
args: ["-L", "fpr, tre, sav, provid, assertIn"]
exclude: >
(?x)^(
.*\.svg|
Expand Down Expand Up @@ -53,7 +53,7 @@ repos:

# Black format Python and notebooks
- repo: https://github.com/psf/black
rev: 24.4.0
rev: 24.4.2
hooks:
- id: black-jupyter

Expand Down
16 changes: 9 additions & 7 deletions aisdc/attacks/likelihood_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,13 +564,15 @@ def attack_from_config(self) -> None: # pylint: disable = too-many-locals
logger.info("Loading test predictions form %s", self.test_preds_filename)
test_preds = np.loadtxt(self.test_preds_filename, delimiter=",")
assert len(test_preds) == len(test_X)
if self.target_model is not None:
clf_module_name, clf_class_name = self.target_model
module = importlib.import_module(clf_module_name)
clf_class = getattr(module, clf_class_name)
if self.target_model_hyp is not None:
clf_params = self.target_model_hyp
clf = clf_class(**clf_params)
if self.target_model is None:
raise ValueError("Target model cannot be None")
if self.target_model_hyp is None:
raise ValueError("Target model hyperparameters cannot be None")
clf_module_name, clf_class_name = self.target_model
module = importlib.import_module(clf_module_name)
clf_class = getattr(module, clf_class_name)
clf_params = self.target_model_hyp
clf = clf_class(**clf_params)
logger.info("Created model: %s", str(clf))
self.run_scenario_from_preds(
clf, train_X, train_y, train_preds, test_X, test_y, test_preds
Expand Down

0 comments on commit feb7e4d

Please sign in to comment.