Skip to content

Commit

Permalink
Add example
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Dec 6, 2024
1 parent ed290d0 commit 9f07b1a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 2 deletions.
3 changes: 3 additions & 0 deletions examples/sparse/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Sparse data

This example trains a youtube comment spam classifier on a sparse dataset. The comments as raw strings are converted to a sparse matrix of word counts using the `CountVectorizer` from scikit-learn.
47 changes: 47 additions & 0 deletions examples/sparse/sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pandas as pd
from sklearn.datasets import fetch_openml
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split

import legateboost as lb

# Alberto, T. & Lochter, J. (2015). YouTube Spam Collection [Dataset].
# UCI Machine Learning Repository. https://doi.org/10.24432/C58885.
dataset_names = [
"youtube-spam-psy",
"youtube-spam-shakira",
"youtube-spam-lmfao",
"youtube-spam-eminem",
"youtube-spam-katyperry",
]
X = []
for dataset_name in dataset_names:
dataset = fetch_openml(name=dataset_name, as_frame=True)
X.append(dataset.data)

X = pd.concat(X)
y = X["CLASS"]
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)
vectorizer = CountVectorizer()
X_train_vectorized = vectorizer.fit_transform(X_train["CONTENT"])
X_test_vectorized = vectorizer.transform(X_test["CONTENT"])

model = lb.LBClassifier().fit(
X_train_vectorized, y_train, eval_set=[(X_test_vectorized, y_test)]
)


def evaluate_comment(comment):
print("Comment: {}".format(comment))
print(
"Probability of spam: {}".format(
model.predict_proba(vectorizer.transform([comment]))[0, 1]
)
)


evaluate_comment(X_test.iloc[15]["CONTENT"])
evaluate_comment(X_test.iloc[3]["CONTENT"])
evaluate_comment("Your text here")
4 changes: 2 additions & 2 deletions src/models/tree/build_tree.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1339,8 +1339,8 @@ struct build_tree_csr_fn {
auto [h, h_shape, h_accessor] = GetInputStore<double, 3>(context.input(4).data());

auto num_rows = std::max<int64_t>(X_offsets_shape.hi[0] - X_offsets_shape.lo[0] + 1, 0);
auto num_outputs = g_shape.hi[1] - g_shape.lo[1] + 1;
EXPECT(g_shape.lo[1] == 0, "Outputs should not be split between workers.");
auto num_outputs = g_shape.hi[2] - g_shape.lo[2] + 1;
EXPECT(g_shape.lo[2] == 0, "Outputs should not be split between workers.");

// Scalars
auto max_depth = context.scalars().at(0).value<int>();
Expand Down

0 comments on commit 9f07b1a

Please sign in to comment.