diff --git a/examples/train_rf_breast_cancer.py b/examples/train_rf_breast_cancer.py index f5afad5d..48184586 100644 --- a/examples/train_rf_breast_cancer.py +++ b/examples/train_rf_breast_cancer.py @@ -24,9 +24,14 @@ model.fit(X_train, y_train) logging.info("Wrapping the model and data in a Target object") - target = Target(model=model) - target.dataset_name = "breast cancer" - target.add_processed_data(X_train, y_train, X_test, y_test) + target = Target( + model=model, + dataset_name="breast cancer", + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) logging.info("Writing Target object to directory: '%s'", output_dir) target.save(output_dir) diff --git a/examples/train_rf_nursery.py b/examples/train_rf_nursery.py index 980ea876..4d604872 100644 --- a/examples/train_rf_nursery.py +++ b/examples/train_rf_nursery.py @@ -42,10 +42,22 @@ logging.info("Base model test accuracy: %.4f", acc_test) logging.info("Wrapping the model and data in a Target object") - target = Target(model=model) - target.dataset_name = "nursery" - target.add_processed_data(X_train, y_train, X_test, y_test) - target.add_raw_data(X, y, X_train_orig, y_train_orig, X_test_orig, y_test_orig) + target = Target( + model=model, + dataset_name="nursery", + # processed data + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + # original unprocessed data + X_orig=X, + y_orig=y, + X_train_orig=X_train_orig, + y_train_orig=y_train_orig, + X_test_orig=X_test_orig, + y_test_orig=y_test_orig, + ) logging.info("Wrapping feature details and encoding for attribute inference") feature_indices = [ @@ -59,7 +71,11 @@ [24, 25, 26], # health ] for i, index in enumerate(feature_indices): - target.add_feature(nursery_data.feature_names[i], index, "onehot") + target.add_feature( + name=nursery_data.feature_names[i], + indices=index, + encoding="onehot", + ) logging.info("Writing Target object to directory: '%s'", output_dir) target.save(output_dir)