-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytorchfromcpptestwiththreadedloop.cpp
168 lines (141 loc) · 5.61 KB
/
pytorchfromcpptestwiththreadedloop.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
// Copyright 2023 Tom Vercauteren. All rights reserved.
//
// This software is licensed under the Apache 2 License.
// See the LICENSE file for details.
#include <pybind11/embed.h>
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <chrono>
#include <future>
#include <iostream>
#include <thread>
namespace py = pybind11;
/// Pick the torch device to run on.
///
/// Starts from the CPU and switches to CUDA when torch reports it as
/// available. When built against libtorch >= 2, the MPS backend is then
/// probed and, if available, selected; MPS therefore takes precedence over
/// CUDA should both ever be reported available.
///
/// @return The selected device (kCPU, kCUDA or kMPS).
torch::Device getbesttorchdevice() {
  torch::Device device = torch::kCPU;
  if (torch::cuda::is_available()) {
    std::cout << "CUDA is available. Running on GPU." << std::endl;
    device = torch::kCUDA;
  }
#if TORCH_VERSION_MAJOR >= 2
  // torch::mps::is_available() is only usable after 2.0; 2.0 itself needs the
  // lower-level at::hasMPS(). See https://github.com/pytorch/pytorch/issues/96425
  // (major, minor) is compared lexicographically so that x.0 releases of later
  // major versions (e.g. 3.0) also take the public API.
#if TORCH_VERSION_MAJOR > 2 || \
    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR > 0)
  if (torch::mps::is_available()) {
#else
  if (at::hasMPS()) {
#endif
    std::cout << "MPS is available. Running on GPU." << std::endl;
    device = torch::kMPS;
  }
#endif
  return device;
}
/// Make the custom Python module importable, then import it.
///
/// Inserts, at index 1 of sys.path (i.e. right after sys.path[0]), first the
/// directory given by CUSTOM_MODULE_SYS_PATH and then the directory given by
/// CUSTOM_TORCH_SYS_PATH (the torch Python package shipped with libtorch).
/// Because both go to index 1, the torch directory ends up ahead of the
/// custom-module directory in the search order. Both macros are presumably
/// provided by the build system — TODO confirm.
///
/// Must be called while an embedded interpreter is alive.
///
/// @return The imported "pycustomtorchmodule" module.
/// @throws py::error_already_set if an import or a sys.path update fails.
py::module setupandloadpymodule() {
  py::object searchpath = py::module::import("sys").attr("path");
  // Directory containing pycustomtorchmodule.
  searchpath.attr("insert")(1, CUSTOM_MODULE_SYS_PATH);
  // Directory containing the torch Python package from libtorch.
  searchpath.attr("insert")(1, CUSTOM_TORCH_SYS_PATH);

  py::module loaded = py::module::import("pycustomtorchmodule");
  std::cout << "Custom python module loaded from " << CUSTOM_MODULE_SYS_PATH
            << std::endl;
  return loaded;
}
/// Run a fast C++ loop that periodically hands work to Python on a thread.
///
/// Selects a torch device and creates a small tensor. It then starts an
/// embedded Python interpreter, imports "pycustomtorchmodule" and stores the
/// tensor as that module's `globalval` attribute. With the GIL released, it
/// runs 100 iterations. Each iteration builds a local tensor and sleeps 5 ms.
/// If no Python call is in flight, it launches
/// `pycustomtorchmodule.opwithglobal(<clone of the local tensor>)` on a new
/// std::async thread; otherwise it reports that the thread is busy. Finally it
/// waits for the last Python call to complete.
///
/// Exceptions raised inside the worker thread are reported there and are not
/// propagated.
///
/// @return EXIT_SUCCESS.
/// @throws std::runtime_error when a py::error_already_set escapes the setup
///         or the loop. It is converted while the interpreter is still alive.
int hybridcall() {
  using std::chrono_literals::operator""ms;
  std::cout << "Starting test from c++" << std::endl;
  std::cout << "PyTorch version: " << TORCH_VERSION << std::endl;
  torch::Device device = getbesttorchdevice();
  torch::Tensor globaltensor = torch::arange(0, 3, device);
  std::cout << "Example globaltensor (c++ side):" << std::endl
            << globaltensor << std::endl;
  py::scoped_interpreter guard{};
  // The Python interpreter has to be alive to properly catch and inspect
  // exceptions stemming from Python, so the try block lives inside the
  // guard's scope. See
  // https://github.com/pybind/pybind11/discussions/4673#discussioncomment-5939343
  // and
  // https://pybind11.readthedocs.io/en/stable/reference.html#_CPPv4NK17error_already_set4whatEv
  try {
    py::module pycustomtorchmodule = setupandloadpymodule();
    // Expose the tensor to the Python side as a module-level global.
    pycustomtorchmodule.attr("globalval") = globaltensor;
    // Release the GIL so the worker thread can take it. One dedicated
    // std::async thread per call is used instead of a thread pool to keep GIL
    // handling simple.
    // Declaration order matters on unwinding: pyfuture (whose destructor
    // waits for an async task) is destroyed before no_gil re-acquires the GIL,
    // and pycustomtorchmodule is released only once the GIL is held again.
    py::gil_scoped_release no_gil;
    std::future<void> pyfuture;
    // Worker body: calls custommodule.opwithglobal(inputtensor) under the GIL
    // and prints the result. The tensor is deliberately taken by value: taking
    // it by reference led to crashes, and copying a torch::Tensor is shallow
    // anyway.
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    auto wrappedop = [](py::module& custommodule, torch::Tensor inputtensor) {
      try {
        py::gil_scoped_acquire gil;
        py::function pyop = custommodule.attr("opwithglobal");
        torch::Tensor pyretval = pyop(inputtensor).cast<torch::Tensor>();
        // Simulate a time-consuming task.
        std::this_thread::sleep_for(50ms);
        std::cout << "Python return value (in c++) " << std::endl
                  << pyretval << std::endl;
      } catch (const py::error_already_set& e) {
        std::cout << "Caught py::error_already_set exception in wrappedop: "
                  << e.what() << std::endl;
      } catch (const std::exception& e) {
        // Standard exceptions.
        std::cout << "Caught std::exception in wrappedop: " << e.what()
                  << std::endl;
      } catch (...) {
        // Everything else.
        std::cout << "Caught unknown exception in wrappedop" << std::endl;
      }
    };
    // Simulate a fast loop running on the C++ side.
    for (int i = 0; i < 100; ++i) {
      torch::Tensor localtensor = torch::arange(i, i + 3, device);
      // Simulate fast (but not immediate) processing.
      std::this_thread::sleep_for(5ms);
      // A zero-timeout wait_for() polls the worker without blocking.
      const bool threadready =
          !pyfuture.valid() ||
          pyfuture.wait_for(0ms) == std::future_status::ready;
      if (threadready) {
        // Give the worker a deep copy rather than sharing localtensor's
        // storage.
        torch::Tensor localtensorclone = torch::clone(localtensor);
        std::cout << "Launching thread at iter " << i << " with cloned tensor "
                  << std::endl
                  << localtensorclone << std::endl;
        pyfuture = std::async(std::launch::async, wrappedop,
                              std::ref(pycustomtorchmodule), localtensorclone);
      } else {
        std::cout << "Thread is busy at iter " << i << std::endl;
      }
    }
    // Wait for the last Python call; get() on an invalid future would throw.
    if (pyfuture.valid()) {
      pyfuture.get();
    }
  } catch (const py::error_already_set& e) {
    // Convert to a plain C++ exception while the interpreter is still alive so
    // callers can handle it after the interpreter is finalized.
    std::cout << "Rethrowing py::error_already_set exception" << std::endl;
    throw std::runtime_error(e.what());
  }
  return EXIT_SUCCESS;
}
/// Program entry point.
///
/// Runs hybridcall(). Any exception escaping it is reported on stdout and
/// turned into EXIT_FAILURE.
///
/// @return hybridcall()'s result on success, EXIT_FAILURE otherwise.
int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) {
  int status = EXIT_FAILURE;
  try {
    status = hybridcall();
  } catch (const std::exception& err) {
    // Anything derived from std::exception, including the std::runtime_error
    // that hybridcall() raises for Python errors.
    std::cout << "Caught std::exception in main: " << err.what() << std::endl;
  } catch (...) {
    // Any other exception type.
    std::cout << "Caught unknown exception in main" << std::endl;
  }
  return status;
}