Skip to content

Commit

Permalink
add weight in dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Oct 24, 2024
1 parent 0347905 commit 3a7c2a6
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion mqboost/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, Optional

import numpy as np
import pandas as pd

from mqboost.base import (
Expand All @@ -9,6 +10,7 @@
FittingException,
ModelName,
TypeName,
WeightLike,
XdataLike,
YdataLike,
)
Expand All @@ -32,6 +34,7 @@ class MQDataset:
Must be in ascending order and contain no duplicates.
data (pd.DataFrame | pd.Series | np.ndarray): The input features.
label (pd.Series | np.ndarray): The target labels (if provided).
weight (list[float] | list[int] | np.ndarray | pd.Series): Weight for each instance (if provided).
model (str): The model type (LightGBM or XGBoost).
reference (MQBoost | None): Reference dataset for label encoding and label mean.
Expand All @@ -52,6 +55,7 @@ def __init__(
alphas: AlphaLike,
data: XdataLike,
label: YdataLike | None = None,
weight: WeightLike | None = None,
model: str = ModelName.lightgbm.value,
reference: Optional["MQDataset"] = None,
) -> None:
Expand Down Expand Up @@ -85,6 +89,10 @@ def __init__(
self._label = prepare_y(y=label - self._label_mean, alphas=self._alphas)
self._is_none_label = False

if weight is not None:
_weight = np.array(weight) if not isinstance(weight, np.ndarray) else weight
self._weight = prepare_y(y=_weight, alphas=self._alphas)

@property
def train_dtype(self) -> Callable:
"""Get the data type function for training data."""
Expand Down Expand Up @@ -123,14 +131,20 @@ def label(self) -> pd.DataFrame:

@property
def label_mean(self) -> float:
"""Get the label mean."""
self.__label_available()
return self._label_mean

@property
def weight(self) -> WeightLike | None:
"""Get the weights."""
return getattr(self, "_weight", None)

@property
def dtrain(self) -> DtrainLike:
"""Get the training data in the required format for the model."""
self.__label_available()
return self._train_dtype(data=self._data, label=self._label)
return self._train_dtype(data=self._data, label=self._label, weight=self.weight)

@property
def dpredict(self) -> DtrainLike | Callable:
Expand Down

0 comments on commit 3a7c2a6

Please sign in to comment.