Skip to content

Commit

Permalink
simplify target wrapping in examples
Browse files Browse the repository at this point in the history
  • Loading branch information
rpreen committed Jul 10, 2024
1 parent f6be107 commit 07fa47b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
11 changes: 8 additions & 3 deletions examples/train_rf_breast_cancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@
model.fit(X_train, y_train)

logging.info("Wrapping the model and data in a Target object")
target = Target(model=model)
target.dataset_name = "breast cancer"
target.add_processed_data(X_train, y_train, X_test, y_test)
target = Target(
model=model,
dataset_name="breast cancer",
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
)

logging.info("Writing Target object to directory: '%s'", output_dir)
target.save(output_dir)
26 changes: 21 additions & 5 deletions examples/train_rf_nursery.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,22 @@
logging.info("Base model test accuracy: %.4f", acc_test)

logging.info("Wrapping the model and data in a Target object")
target = Target(model=model)
target.dataset_name = "nursery"
target.add_processed_data(X_train, y_train, X_test, y_test)
target.add_raw_data(X, y, X_train_orig, y_train_orig, X_test_orig, y_test_orig)
target = Target(
model=model,
dataset_name="nursery",
# processed data
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
# original unprocessed data
X_orig=X,
y_orig=y,
X_train_orig=X_train_orig,
y_train_orig=y_train_orig,
X_test_orig=X_test_orig,
y_test_orig=y_test_orig,
)

logging.info("Wrapping feature details and encoding for attribute inference")
feature_indices = [
Expand All @@ -59,7 +71,11 @@
[24, 25, 26], # health
]
for i, index in enumerate(feature_indices):
target.add_feature(nursery_data.feature_names[i], index, "onehot")
target.add_feature(
name=nursery_data.feature_names[i],
indices=index,
encoding="onehot",
)

logging.info("Writing Target object to directory: '%s'", output_dir)
target.save(output_dir)

0 comments on commit 07fa47b

Please sign in to comment.