Skip to content

Commit

Permalink
Merge pull request #143 from AI-SDC/addressing-issue-142
Browse files Browse the repository at this point in the history
Addressing issue 142
  • Loading branch information
rpreen authored May 5, 2023
2 parents d43a988 + dd891c1 commit eb5d716
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions aisdc/safemodel/classifiers/saferandomforestclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..safemodel import SafeModel
from .safedecisiontreeclassifier import decision_trees_are_equal

# pylint: disable=too-many-ancestors,too-many-instance-attributes
# pylint: disable=too-many-ancestors,too-many-instance-attributes, unidiomatic-typecheck


class SafeRandomForestClassifier(SafeModel, RandomForestClassifier):
Expand Down Expand Up @@ -72,7 +72,7 @@ def additional_checks( # pylint: disable=too-many-nested-blocks,too-many-branch
for item in self.examine_seperately_items:
# template for class of things that make up forest
if item == "base_estimator":
if curr_separate[item] != saved_separate[item]:
if type(curr_separate[item]) != type(saved_separate[item]):
# msg += get_reporting_string(name="basic_params_differ",length=1)
msg += get_reporting_string(
name="param_changed_from_to",
Expand Down
36 changes: 18 additions & 18 deletions tests/test_saferandomforestclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ def test_randomforest_hacked_postfit():
part2 = get_reporting_string(
name="param_changed_from_to", key="bootstrap", val=False, cur_val=True
)
part3 = get_reporting_string(
name="param_changed_from_to",
key="base_estimator",
val="DecisionTreeClassifier()",
cur_val="DecisionTreeClassifier()",
)
part3 = "" # get_reporting_string(
# name="param_changed_from_to",
# key="base_estimator",
# val="DecisionTreeClassifier()",
# cur_val="DecisionTreeClassifier()",
# )
correct_msg2 = part1 + part2 + part3
# print(f'Correct: {correct_msg2}\n Actual: {msg2}')

Expand Down Expand Up @@ -267,12 +267,12 @@ def test_randomforest_hacked_postfit_trees_swapped():
name="param_changed_from_to", key="max_depth", val="None", cur_val="2"
)
part3 = get_reporting_string(name="forest_estimators_differ", idx=5)
part4 = get_reporting_string(
name="param_changed_from_to",
key="base_estimator",
val="DecisionTreeClassifier()",
cur_val="DecisionTreeClassifier()",
)
part4 = "" # get_reporting_string(
# name="param_changed_from_to",
# key="base_estimator",
# val="DecisionTreeClassifier()",
# cur_val="DecisionTreeClassifier()",
# )
correct_msg = part1 + part2 + part3 + part4
# print(f'Correct:\n{correct_msg} Actual:\n{msg}')
assert msg == correct_msg, f"{msg}\n should be {correct_msg}"
Expand All @@ -298,12 +298,12 @@ def test_randomforest_hacked_postfit_moretrees():
name="param_changed_from_to", key="n_estimators", val="5", cur_val="10"
)
part3 = get_reporting_string(name="different_num_estimators", num1=10, num2=5)
part4 = get_reporting_string(
name="param_changed_from_to",
key="base_estimator",
val="DecisionTreeClassifier()",
cur_val="DecisionTreeClassifier()",
)
part4 = "" # get_reporting_string(
# name="param_changed_from_to",
# key="base_estimator",
# val="DecisionTreeClassifier()",
# cur_val="DecisionTreeClassifier()",
# )
correct_msg = part1 + part2 + part3 + part4
# print(f'Correct:\n{correct_msg} Actual:\n{msg}')
assert msg == correct_msg, f"{msg}\n should be {correct_msg}"
Expand Down

0 comments on commit eb5d716

Please sign in to comment.