From d2c6d1faff0b30ac3c2bfe625685c19443a4f84b Mon Sep 17 00:00:00 2001
From: RektPunk <110188257+RektPunk@users.noreply.github.com>
Date: Sat, 2 Nov 2024 10:11:27 +0900
Subject: [PATCH] [Feature] add length of data dividing term to gradient,
 hessian (#47)

Related to https://github.com/RektPunk/MQBoost/issues/44
---
 mqboost/objective.py    | 15 ++++++++-------
 tests/test_objective.py | 20 ++++++++++----------
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/mqboost/objective.py b/mqboost/objective.py
index aa42bf2..51ce780 100644
--- a/mqboost/objective.py
+++ b/mqboost/objective.py
@@ -7,8 +7,8 @@ from mqboost.utils import delta_validate, epsilon_validate
 
 CHECK_LOSS: str = "check_loss"
-GradFnLike = Callable[[Any], np.ndarray]
-HessFnLike = Callable[[Any], np.ndarray]
+GradFnLike = Callable[[np.ndarray, float, Any], np.ndarray]
+HessFnLike = Callable[[np.ndarray, float, Any], np.ndarray]
 ObjLike = Callable[
     [np.ndarray, DtrainLike, list[float], Any], tuple[np.ndarray, np.ndarray]
 ]
 
@@ -79,14 +79,15 @@ def _compute_grads_hess(
     _y_train, _y_pred = _train_pred_reshape(
         y_pred=y_pred, dtrain=dtrain, len_alpha=_len_alpha
     )
-    grads = []
-    hess = []
+    grads: list[np.ndarray] = []
+    hess: list[np.ndarray] = []
+    _len_y = len(_y_train[0])
     for alpha_inx in range(len(alphas)):
-        _err_for_alpha = _y_train[alpha_inx] - _y_pred[alpha_inx]
+        _err_for_alpha: np.ndarray = _y_train[alpha_inx] - _y_pred[alpha_inx]
         _grad = grad_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
         _hess = hess_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
-        grads.append(_grad)
-        hess.append(_hess)
+        grads.append(_grad / _len_y)
+        hess.append(_hess / _len_y)
 
     return np.concatenate(grads), np.concatenate(hess)
 
diff --git a/tests/test_objective.py b/tests/test_objective.py
index 4123348..7fd70c4 100644
--- a/tests/test_objective.py
+++ b/tests/test_objective.py
@@ -81,7 +81,7 @@ def test_check_loss_grad_hess(dummy_data):
     dtrain = dummy_data(y_true)
     grads, hess = check_loss_grad_hess(y_pred=y_pred, dtrain=dtrain, alphas=alphas)
     # fmt: off
-    expected_grads = [-0.1, -0.1, 0.9, -0.1, -0.1, -0.5, -0.5, 0.5, -0.5, -0.5, -0.9, -0.9, 0.1, -0.9, -0.9]
+    expected_grads = [-0.02, -0.02, 0.18, -0.02, -0.02, -0.1, -0.1, 0.1, -0.1, -0.1, -0.18, -0.18, 0.02, -0.18, -0.18]
     # fmt: on
     np.testing.assert_almost_equal(grads, np.array(expected_grads))
     assert grads.shape == hess.shape
@@ -92,9 +92,9 @@ def test_check_loss_grad_hess(dummy_data):
 @pytest.mark.parametrize(
     "delta, expected_grads",
     [
-        (0.01, [-0.1, -0.1, 0.9, -0.1, -0.1, -0.5, -0.5, 0.5, -0.5, -0.5, -0.9, -0.9, 0.1, -0.9, -0.9]),
-        (0.02, [0.001, 0.001, 0.9, -0.1, -0.1, 0.005, 0.005, 0.5, -0.5, -0.5, 0.009, 0.009, 0.1, -0.9, -0.9]),
-        (0.05, [0.001, 0.001, 0.018, 0.003, 0.005, 0.005, 0.005, 0.01, 0.015, 0.025, 0.009, 0.009, 0.002, 0.027, 0.045]),
+        (0.01, [-0.02, -0.02, 0.18, -0.02, -0.02, -0.1, -0.1, 0.1, -0.1, -0.1, -0.18, -0.18, 0.02, -0.18, -0.18]),
+        (0.02, [0.0002, 0.0002, 0.18, -0.02, -0.02, 0.001, 0.001, 0.1, -0.1, -0.1, 0.0018, 0.0018, 0.02, -0.18, -0.18]),
+        (0.05, [0.0002, 0.0002, 0.0036, 0.0006, 0.001, 0.001, 0.001, 0.002, 0.003, 0.005, 0.0018, 0.0018, 0.0004, 0.0054, 0.009]),
     ],
 )
 # fmt: on
@@ -116,18 +116,18 @@ def test_huber_loss_grad_hess(dummy_data, delta, expected_grads):
     [
         (
             0.01,
-            [0.15, 0.15, 0.7333, 0.025, -0.0167, -0.25, -0.25, 0.3333, -0.375, -0.4167, -0.65, -0.65, -0.0667, -0.775, -0.8167],
-            [25.0, 25.0, 16.6666, 12.5, 8.3333, 25.0, 25.0, 16.6666, 12.5, 8.33333333, 25.0, 25.0, 16.6666, 12.5, 8.33333333],
+            [0.03, 0.03, 0.1467, 0.005, -0.0033, -0.05, -0.05, 0.0667, -0.075, -0.0833, -0.13, -0.13, -0.0133, -0.155, -0.1633],
+            [5.0, 5.0, 3.3333, 2.5, 1.6667, 5.0, 5.0, 3.3333, 2.5, 1.66666667, 5.0, 5.0, 3.3333, 2.5, 1.66666667],
         ),
         (
             0.005,
-            [0.0667, 0.0667, 0.8, -0.02857, -0.0545, -0.3333, -0.3333, 0.4, -0.4286, -0.4545, -0.7333, -0.73333, 0.0, -0.8286, -0.8545],
-            [33.3333, 33.3333, 20.0, 14.2857, 9.0909, 33.3333, 33.3333, 20.0, 14.2857, 9.0909, 33.3333, 33.3333, 20.0, 14.2857, 9.0909],
+            [0.0133, 0.0133, 0.16, -0.00571, -0.0109, -0.0667, -0.0667, 0.08, -0.0857, -0.0909, -0.1467, -0.14667, 0.0, -0.1657, -0.1709],
+            [6.6667, 6.6667, 4.0, 2.8571, 1.8182, 6.6667, 6.6667, 4.0, 2.8571, 1.8182, 6.6667, 6.6667, 4.0, 2.8571, 1.8182],
         ),
         (
             0.001,
-            [-0.0545, -0.0545, 0.8762, -0.0838, -0.0902, -0.4545, -0.4545, 0.4762, -0.4839, -0.4902, -0.8545, -0.8545, 0.0762, -0.8839, -0.8902],
-            [45.4545, 45.4545, 23.8095, 16.1290, 9.8039, 45.4545, 45.4545, 23.8095, 16.1290, 9.8039, 45.4545, 45.4545, 23.8095, 16.1290, 9.8039],
+            [-0.0109, -0.0109, 0.1752, -0.01676, -0.01804, -0.0909, -0.0909, 0.0952, -0.09678, -0.09804, -0.1709, -0.1709, 0.01524, -0.17678, -0.17804],
+            [9.0909, 9.0909, 4.7619, 3.2258, 1.9608, 9.0909, 9.0909, 4.7619, 3.2258, 1.9608, 9.0909, 9.0909, 4.7619, 3.2258, 1.9608],
        ),
     ],
 )
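
Reviewer note: the change above rescales every per-quantile gradient and Hessian
by the training-set size. Below is a minimal, self-contained sketch of that
scaling, assuming nothing about MQBoost internals: check_loss_grad is a
hypothetical stand-in for the repository's check-loss grad_fn, and the toy
arrays are chosen so the output reproduces the alpha = 0.1 block of the updated
expected_grads in test_check_loss_grad_hess.

    import numpy as np

    def check_loss_grad(error: np.ndarray, alpha: float) -> np.ndarray:
        # Subgradient of the check (pinball) loss w.r.t. the prediction:
        # -alpha where y_true > y_pred, and (1 - alpha) otherwise.
        return np.where(error > 0, -alpha, 1 - alpha)

    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    y_pred = np.array([0.5, 1.5, 3.5, 3.5, 4.5])  # one overshoot at index 2
    alpha = 0.1

    error = y_true - y_pred
    raw_grad = check_loss_grad(error=error, alpha=alpha)

    # The new dividing term: scale by the number of samples so the
    # objective behaves like a mean rather than a sum over the data.
    scaled_grad = raw_grad / len(y_true)

    print(raw_grad)     # [-0.1 -0.1  0.9 -0.1 -0.1]       (old expected values)
    print(scaled_grad)  # [-0.02 -0.02  0.18 -0.02 -0.02]  (new expected values)

Dividing by the sample count leaves the optimum unchanged but makes gradient
magnitudes, and hence the effective step size, independent of dataset size;
this is why every expected value in the updated tests is exactly the old value
divided by the length of the five-sample dummy data.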