Skip to content

Commit

Permalink
Expose the global openmp thread for the dask interface. (#11175)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jan 21, 2025
1 parent b57840f commit 0d7821f
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 8 deletions.
5 changes: 5 additions & 0 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ The following parameters can be set in the global scope, using :py:func:`xgboost
(compiled) with the RMM plugin enabled. Valid values are ``true`` and ``false``. See
:doc:`/python/rmm-examples/index` for details.

* ``nthread``: Set the global number of threads for OpenMP. Use this only when you need to
override some OpenMP-related environment variables like ``OMP_NUM_THREADS``. Otherwise,
the ``nthread`` parameter from the Booster and the DMatrix should be preferred as the
former sets the global variable and might cause conflicts with other libraries.

******************
General Parameters
******************
Expand Down
2 changes: 2 additions & 0 deletions include/xgboost/global_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ namespace xgboost {
struct GlobalConfiguration : public XGBoostParameter<GlobalConfiguration> {
std::int32_t verbosity{1};
bool use_rmm{false};
// This is not a dmlc parameter to avoid conflict with the context class.
std::int32_t nthread{0};
DMLC_DECLARE_PARAMETER(GlobalConfiguration) {
DMLC_DECLARE_FIELD(verbosity)
.set_range(0, 3)
Expand Down
1 change: 1 addition & 0 deletions python-package/xgboost/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,7 @@ def do_train( # pylint: disable=too-many-positional-arguments
local_param.update({"nthread": n_threads, "n_jobs": n_threads})

local_history: TrainingCallback.EvalsLog = {}
global_config.update({"nthread": n_threads})

with CommunicatorContext(**coll_args), config.config_context(**global_config):
Xy, evals = _get_dmatrices(
Expand Down
18 changes: 16 additions & 2 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2014-2024, XGBoost Contributors
* Copyright 2014-2025, XGBoost Contributors
*/
#include "xgboost/c_api.h"

Expand Down Expand Up @@ -143,7 +143,19 @@ XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
xgboost_CHECK_C_ARG_PTR(json_str);
Json config{Json::Load(StringView{json_str})};

for (auto& items : get<Object>(config)) {
// handle nthread, it's not a dmlc parameter.
auto& obj = get<Object>(config);
auto it = obj.find("nthread");
if (it != obj.cend()) {
auto nthread = OptionalArg<Integer>(config, "nthread", Integer::Int{0});
if (nthread > 0) {
omp_set_num_threads(nthread);
GlobalConfigThreadLocalStore::Get()->nthread = nthread;
}
get<Object>(config).erase("nthread");
}

for (auto &items : obj) {
switch (items.second.GetValue().Type()) {
case xgboost::Value::ValueKind::kInteger: {
items.second = String{std::to_string(get<Integer const>(items.second))};
Expand Down Expand Up @@ -183,6 +195,7 @@ XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
}
LOG(FATAL) << ss.str() << " }";
}

API_END();
}

Expand Down Expand Up @@ -216,6 +229,7 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
}
}

config["nthread"] = GlobalConfigThreadLocalStore::Get()->nthread;
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
Json::Dump(config, &local.ret_str);

Expand Down
10 changes: 8 additions & 2 deletions src/global_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@
* \author Hyunsu Cho
*/

#include <dmlc/thread_local.h>
#include "xgboost/global_config.h"

#include <dmlc/thread_local.h>

namespace xgboost {
DMLC_REGISTER_PARAMETER(GlobalConfiguration);

void InitNewThread::operator()() const { *GlobalConfigThreadLocalStore::Get() = config; }
void InitNewThread::operator()() const {
*GlobalConfigThreadLocalStore::Get() = config;
if (config.nthread > 0) {
omp_set_num_threads(config.nthread);
}
}
} // namespace xgboost
21 changes: 17 additions & 4 deletions tests/cpp/test_global_config.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
/**
* Copyright 2020-2025, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/c_api.h>
#include <xgboost/global_config.h>
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <xgboost/global_config.h>

namespace xgboost {

TEST(GlobalConfiguration, Verbosity) {
// Configure verbosity via global configuration
Json config{JsonObject()};
Expand All @@ -15,7 +18,7 @@ TEST(GlobalConfiguration, Verbosity) {
EXPECT_EQ(ConsoleLogger::GlobalVerbosity(), ConsoleLogger::LogVerbosity::kSilent);
EXPECT_NE(ConsoleLogger::LogVerbosity::kSilent, ConsoleLogger::DefaultVerbosity());
// GetConfig() should also return updated verbosity
Json current_config { ToJson(*GlobalConfigThreadLocalStore::Get()) };
Json current_config{ToJson(*GlobalConfigThreadLocalStore::Get())};
EXPECT_EQ(get<String>(current_config["verbosity"]), "0");
}

Expand All @@ -25,8 +28,18 @@ TEST(GlobalConfiguration, UseRMM) {
auto& global_config = *GlobalConfigThreadLocalStore::Get();
FromJson(config, &global_config);
// GetConfig() should return updated use_rmm flag
Json current_config { ToJson(*GlobalConfigThreadLocalStore::Get()) };
Json current_config{ToJson(*GlobalConfigThreadLocalStore::Get())};
EXPECT_EQ(get<String>(current_config["use_rmm"]), "1");
}

TEST(GlobalConfiguration, Threads) {
char const* config;
ASSERT_EQ(XGBGetGlobalConfig(&config), 0);
auto jconfig = Json::Load(config);
auto nthread = get<Integer const>(jconfig["nthread"]);
ASSERT_LE(nthread, 0);
auto n_omp = omp_get_num_threads();
ASSERT_EQ(XGBSetGlobalConfig(config), 0);
ASSERT_EQ(n_omp, omp_get_num_threads());
}
} // namespace xgboost
5 changes: 5 additions & 0 deletions tests/python/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,8 @@ def test_thread_safety():

for f in futures:
f.result()


def test_nthread() -> None:
config = xgb.get_config()
assert config["nthread"] == 0

0 comments on commit 0d7821f

Please sign in to comment.