From ea73f6b99015f758fe0f35ee147cbb42ef0960dc Mon Sep 17 00:00:00 2001 From: Jim-smith Date: Fri, 5 May 2023 11:49:24 +0100 Subject: [PATCH 1/4] fix to line 75 so it compares types not exact objects Signed-off-by: Jim-smith --- aisdc/safemodel/classifiers/saferandomforestclassifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aisdc/safemodel/classifiers/saferandomforestclassifier.py b/aisdc/safemodel/classifiers/saferandomforestclassifier.py index 3315f96f..ee970132 100644 --- a/aisdc/safemodel/classifiers/saferandomforestclassifier.py +++ b/aisdc/safemodel/classifiers/saferandomforestclassifier.py @@ -72,7 +72,7 @@ def additional_checks( # pylint: disable=too-many-nested-blocks,too-many-branch for item in self.examine_seperately_items: # template for class of things that make up forest if item == "base_estimator": - if curr_separate[item] != saved_separate[item]: + if type(curr_separate[item]) != type(saved_separate[item]): # msg += get_reporting_string(name="basic_params_differ",length=1) msg += get_reporting_string( name="param_changed_from_to", From 85e1d67152506a92ece7ed1f56c0b779cbef9c76 Mon Sep 17 00:00:00 2001 From: Jim-smith Date: Fri, 5 May 2023 12:08:23 +0100 Subject: [PATCH 2/4] fix lint warning --- aisdc/safemodel/classifiers/saferandomforestclassifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aisdc/safemodel/classifiers/saferandomforestclassifier.py b/aisdc/safemodel/classifiers/saferandomforestclassifier.py index ee970132..45670be1 100644 --- a/aisdc/safemodel/classifiers/saferandomforestclassifier.py +++ b/aisdc/safemodel/classifiers/saferandomforestclassifier.py @@ -12,7 +12,7 @@ from ..safemodel import SafeModel from .safedecisiontreeclassifier import decision_trees_are_equal -# pylint: disable=too-many-ancestors,too-many-instance-attributes +# pylint: disable=too-many-ancestors,too-many-instance-attributes, unidiomatic-typecheck class SafeRandomForestClassifier(SafeModel, RandomForestClassifier): From df224eaeed8cc5a92fe2a2d3c5367417afc0aea2 Mon Sep 17 00:00:00 2001 From: Jim-smith Date: Fri, 5 May 2023 12:17:05 +0100 Subject: [PATCH 3/4] fix safe random forest tests --- tests/test_saferandomforestclassifier.py | 46 ++++++++++++------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/test_saferandomforestclassifier.py b/tests/test_saferandomforestclassifier.py index f766f76f..736fb004 100644 --- a/tests/test_saferandomforestclassifier.py +++ b/tests/test_saferandomforestclassifier.py @@ -151,12 +151,12 @@ def test_randomforest_hacked_postfit(): part2 = get_reporting_string( name="param_changed_from_to", key="bootstrap", val=False, cur_val=True ) - part3 = get_reporting_string( - name="param_changed_from_to", - key="base_estimator", - val="DecisionTreeClassifier()", - cur_val="DecisionTreeClassifier()", - ) + part3 = ""#get_reporting_string( +# name="param_changed_from_to", +# key="base_estimator", +# val="DecisionTreeClassifier()", +# cur_val="DecisionTreeClassifier()", +# ) correct_msg2 = part1 + part2 + part3 # print(f'Correct: {correct_msg2}\n Actual: {msg2}') @@ -197,11 +197,11 @@ def test_randomforest_modeltype_changed(): # correct_msg += get_reporting_string(name="basic_params_differ",length=1) correct_msg = get_reporting_string(name="forest_estimators_differ", idx=5) correct_msg += get_reporting_string( - name="param_changed_from_to", - key="base_estimator", - val="DecisionTreeClassifier()", - cur_val="DummyClassifier()", - ) + name="param_changed_from_to", + key="base_estimator", + val="DecisionTreeClassifier()", + cur_val="DummyClassifier()", + ) # correct_msg += ("structure base_estimator has 1 differences: [('change', '', " # "(DecisionTreeClassifier(), DecisionTreeClassifier()))]" # ) @@ -267,12 +267,12 @@ def test_randomforest_hacked_postfit_trees_swapped(): name="param_changed_from_to", key="max_depth", val="None", cur_val="2" ) part3 = get_reporting_string(name="forest_estimators_differ", idx=5) - part4 = get_reporting_string( - name="param_changed_from_to", - key="base_estimator", - val="DecisionTreeClassifier()", - cur_val="DecisionTreeClassifier()", - ) + part4 = ""#get_reporting_string( +# name="param_changed_from_to", +# key="base_estimator", +# val="DecisionTreeClassifier()", +# cur_val="DecisionTreeClassifier()", +# ) correct_msg = part1 + part2 + part3 + part4 # print(f'Correct:\n{correct_msg} Actual:\n{msg}') assert msg == correct_msg, f"{msg}\n should be {correct_msg}" @@ -298,12 +298,12 @@ def test_randomforest_hacked_postfit_moretrees(): name="param_changed_from_to", key="n_estimators", val="5", cur_val="10" ) part3 = get_reporting_string(name="different_num_estimators", num1=10, num2=5) - part4 = get_reporting_string( - name="param_changed_from_to", - key="base_estimator", - val="DecisionTreeClassifier()", - cur_val="DecisionTreeClassifier()", - ) + part4 = ""#get_reporting_string( +# name="param_changed_from_to", +# key="base_estimator", +# val="DecisionTreeClassifier()", +# cur_val="DecisionTreeClassifier()", +# ) correct_msg = part1 + part2 + part3 + part4 # print(f'Correct:\n{correct_msg} Actual:\n{msg}') assert msg == correct_msg, f"{msg}\n should be {correct_msg}" From dd891c118402c5b9928831c3c3e30fb69626feb9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 May 2023 13:09:42 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_saferandomforestclassifier.py | 46 ++++++++++++------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/test_saferandomforestclassifier.py b/tests/test_saferandomforestclassifier.py index 736fb004..79bf4021 100644 --- a/tests/test_saferandomforestclassifier.py +++ b/tests/test_saferandomforestclassifier.py @@ -151,12 +151,12 @@ def test_randomforest_hacked_postfit(): part2 = get_reporting_string( name="param_changed_from_to", key="bootstrap", val=False, cur_val=True ) - part3 = ""#get_reporting_string( -# name="param_changed_from_to", -# key="base_estimator", -# val="DecisionTreeClassifier()", -# cur_val="DecisionTreeClassifier()", -# ) + part3 = "" # get_reporting_string( + # name="param_changed_from_to", + # key="base_estimator", + # val="DecisionTreeClassifier()", + # cur_val="DecisionTreeClassifier()", + # ) correct_msg2 = part1 + part2 + part3 # print(f'Correct: {correct_msg2}\n Actual: {msg2}') @@ -197,11 +197,11 @@ def test_randomforest_modeltype_changed(): # correct_msg += get_reporting_string(name="basic_params_differ",length=1) correct_msg = get_reporting_string(name="forest_estimators_differ", idx=5) correct_msg += get_reporting_string( - name="param_changed_from_to", - key="base_estimator", - val="DecisionTreeClassifier()", - cur_val="DummyClassifier()", - ) + name="param_changed_from_to", + key="base_estimator", + val="DecisionTreeClassifier()", + cur_val="DummyClassifier()", + ) # correct_msg += ("structure base_estimator has 1 differences: [('change', '', " # "(DecisionTreeClassifier(), DecisionTreeClassifier()))]" # ) @@ -267,12 +267,12 @@ def test_randomforest_hacked_postfit_trees_swapped(): name="param_changed_from_to", key="max_depth", val="None", cur_val="2" ) part3 = get_reporting_string(name="forest_estimators_differ", idx=5) - part4 = ""#get_reporting_string( -# name="param_changed_from_to", -# key="base_estimator", -# val="DecisionTreeClassifier()", -# cur_val="DecisionTreeClassifier()", -# ) + part4 = "" # get_reporting_string( + # name="param_changed_from_to", + # key="base_estimator", + # val="DecisionTreeClassifier()", + # cur_val="DecisionTreeClassifier()", + # ) correct_msg = part1 + part2 + part3 + part4 # print(f'Correct:\n{correct_msg} Actual:\n{msg}') assert msg == correct_msg, f"{msg}\n should be {correct_msg}" @@ -298,12 +298,12 @@ def test_randomforest_hacked_postfit_moretrees(): name="param_changed_from_to", key="n_estimators", val="5", cur_val="10" ) part3 = get_reporting_string(name="different_num_estimators", num1=10, num2=5) - part4 = ""#get_reporting_string( -# name="param_changed_from_to", -# key="base_estimator", -# val="DecisionTreeClassifier()", -# cur_val="DecisionTreeClassifier()", -# ) + part4 = "" # get_reporting_string( + # name="param_changed_from_to", + # key="base_estimator", + # val="DecisionTreeClassifier()", + # cur_val="DecisionTreeClassifier()", + # ) correct_msg = part1 + part2 + part3 + part4 # print(f'Correct:\n{correct_msg} Actual:\n{msg}') assert msg == correct_msg, f"{msg}\n should be {correct_msg}"