Skip to content

Commit

Permalink
[Feature] add length of data dividing term to gradient, hessian (#47)
Browse files — browse the repository at this point in the history
Related to #44
  • Branch information:
RektPunk authored Nov 2, 2024
1 parent 64e2681 commit d2c6d1f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
15 changes: 8 additions & 7 deletions mqboost/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from mqboost.utils import delta_validate, epsilon_validate

CHECK_LOSS: str = "check_loss"
GradFnLike = Callable[[Any], np.ndarray]
HessFnLike = Callable[[Any], np.ndarray]
GradFnLike = Callable[[np.ndarray, float, Any], np.ndarray]
HessFnLike = Callable[[np.ndarray, float, Any], np.ndarray]
ObjLike = Callable[
[np.ndarray, DtrainLike, list[float], Any], tuple[np.ndarray, np.ndarray]
]
Expand Down Expand Up @@ -79,14 +79,15 @@ def _compute_grads_hess(
_y_train, _y_pred = _train_pred_reshape(
y_pred=y_pred, dtrain=dtrain, len_alpha=_len_alpha
)
grads = []
hess = []
grads: list[np.ndarray] = []
hess: list[np.ndarray] = []
_len_y = len(_y_train[0])
for alpha_inx in range(len(alphas)):
_err_for_alpha = _y_train[alpha_inx] - _y_pred[alpha_inx]
_err_for_alpha: np.ndarray = _y_train[alpha_inx] - _y_pred[alpha_inx]
_grad = grad_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
_hess = hess_fn(error=_err_for_alpha, alpha=alphas[alpha_inx], **kwargs)
grads.append(_grad)
hess.append(_hess)
grads.append(_grad / _len_y)
hess.append(_hess / _len_y)

return np.concatenate(grads), np.concatenate(hess)

Expand Down
20 changes: 10 additions & 10 deletions tests/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_check_loss_grad_hess(dummy_data):
dtrain = dummy_data(y_true)
grads, hess = check_loss_grad_hess(y_pred=y_pred, dtrain=dtrain, alphas=alphas)
# fmt: off
expected_grads = [-0.1, -0.1, 0.9, -0.1, -0.1, -0.5, -0.5, 0.5, -0.5, -0.5, -0.9, -0.9, 0.1, -0.9, -0.9]
expected_grads = [-0.02, -0.02, 0.18, -0.02, -0.02, -0.1, -0.1, 0.1, -0.1, -0.1, -0.18, -0.18, 0.02, -0.18, -0.18]
# fmt: on
np.testing.assert_almost_equal(grads, np.array(expected_grads))
assert grads.shape == hess.shape
Expand All @@ -92,9 +92,9 @@ def test_check_loss_grad_hess(dummy_data):
@pytest.mark.parametrize(
"delta, expected_grads",
[
(0.01, [-0.1, -0.1, 0.9, -0.1, -0.1, -0.5, -0.5, 0.5, -0.5, -0.5, -0.9, -0.9, 0.1, -0.9, -0.9]),
(0.02, [0.001, 0.001, 0.9, -0.1, -0.1, 0.005, 0.005, 0.5, -0.5, -0.5, 0.009, 0.009, 0.1, -0.9, -0.9]),
(0.05, [0.001, 0.001, 0.018, 0.003, 0.005, 0.005, 0.005, 0.01, 0.015, 0.025, 0.009, 0.009, 0.002, 0.027, 0.045]),
(0.01, [-0.02, -0.02, 0.18, -0.02, -0.02, -0.1, -0.1, 0.1, -0.1, -0.1, -0.18, -0.18, 0.02, -0.18, -0.18]),
(0.02, [0.0002, 0.0002, 0.18, -0.02, -0.02, 0.001, 0.001, 0.1, -0.1, -0.1, 0.0018, 0.0018, 0.02, -0.18, -0.18]),
(0.05, [0.0002, 0.0002, 0.0036, 0.0006, 0.001, 0.001, 0.001, 0.002, 0.003, 0.005, 0.0018, 0.0018, 0.0004, 0.0054, 0.009]),
],
)
# fmt: on
Expand All @@ -116,18 +116,18 @@ def test_huber_loss_grad_hess(dummy_data, delta, expected_grads):
[
(
0.01,
[0.15, 0.15, 0.7333, 0.025, -0.0167, -0.25, -0.25, 0.3333, -0.375, -0.4167, -0.65, -0.65, -0.0667, -0.775, -0.8167],
[25.0, 25.0, 16.6666, 12.5, 8.3333, 25.0, 25.0, 16.6666, 12.5, 8.33333333, 25.0, 25.0, 16.6666, 12.5, 8.33333333],
[0.03, 0.03, 0.1467, 0.005, -0.0033, -0.05, -0.05, 0.0667, -0.075, -0.0833, -0.13, -0.13, -0.0133, -0.155, -0.1633],
[5.0, 5.0, 3.3333, 2.5, 1.6667, 5.0, 5.0, 3.3333, 2.5, 1.66666667, 5.0, 5.0, 3.3333, 2.5, 1.66666667],
),
(
0.005,
[0.0667, 0.0667, 0.8, -0.02857, -0.0545, -0.3333, -0.3333, 0.4, -0.4286, -0.4545, -0.7333, -0.73333, 0.0, -0.8286, -0.8545],
[33.3333, 33.3333, 20.0, 14.2857, 9.0909, 33.3333, 33.3333, 20.0, 14.2857, 9.0909, 33.3333, 33.3333, 20.0, 14.2857, 9.0909],
[0.0133, 0.0133, 0.16, -0.00571, -0.0109, -0.0667, -0.0667, 0.08, -0.0857, -0.0909, -0.1467, -0.14667, 0.0, -0.1657, -0.1709],
[6.6667, 6.6667, 4.0, 2.8571, 1.8182, 6.6667, 6.6667, 4.0, 2.8571, 1.8182, 6.6667, 6.6667, 4.0, 2.8571, 1.8182],
),
(
0.001,
[-0.0545, -0.0545, 0.8762, -0.0838, -0.0902, -0.4545, -0.4545, 0.4762, -0.4839, -0.4902, -0.8545, -0.8545, 0.0762, -0.8839, -0.8902],
[45.4545, 45.4545, 23.8095, 16.1290, 9.8039, 45.4545, 45.4545, 23.8095, 16.1290, 9.8039, 45.4545, 45.4545, 23.8095, 16.1290, 9.8039],
[-0.0109, -0.0109, 0.1752, -0.01676, -0.01804, -0.0909, -0.0909, 0.0952, -0.09678, -0.09804, -0.1709, -0.1709, 0.01524, -0.17678, -0.17804],
[9.0909, 9.0909, 4.7619, 3.2258, 1.9608, 9.0909, 9.0909, 4.7619, 3.2258, 1.9608, 9.0909, 9.0909, 4.7619, 3.2258, 1.9608],
),
],
)
Expand Down

1 comment on commit d2c6d1f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests Skipped Failures Errors Time
91 0 💤 0 ❌ 0 🔥 6.012s ⏱️

Please sign in to comment.