diff --git a/.clang-tidy b/.clang-tidy index ac0619f02e..a815b783f2 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -15,7 +15,9 @@ Checks: > WarningsAsErrors: > *, - -clang-diagnostic-unused-command-line-argument + -clang-diagnostic-unused-command-line-argument, + -Wno-ignored-optimization-argument, + -Qunused-arguments #WarningsAsErrors: '*' HeaderFilterRegex: '.*\/include\/morpheus\/.*' diff --git a/examples/log_parsing/inference.py b/examples/log_parsing/inference.py index 97d162324c..486e5aada2 100644 --- a/examples/log_parsing/inference.py +++ b/examples/log_parsing/inference.py @@ -180,4 +180,12 @@ def _convert_one_response(output: MultiResponseMessage, inf: MultiInferenceNLPMe return MultiResponseMessage.from_message(inf, memory=memory, offset=inf.offset, count=inf.mess_count) def _get_inference_worker(self, inf_queue: ProducerConsumerQueue) -> TritonInferenceLogParsing: - return TritonInferenceLogParsing(inf_queue=inf_queue, c=self._config, **self._kwargs) + return TritonInferenceLogParsing(inf_queue=inf_queue, + c=self._config, + server_url=self._server_url, + model_name=self._model_name, + force_convert_inputs=self._force_convert_inputs, + use_shared_memory=self._use_shared_memory, + input_mapping=self._input_mapping, + output_mapping=self._output_mapping, + needs_logits=self._needs_logits) diff --git a/morpheus/_lib/cmake/libmorpheus.cmake b/morpheus/_lib/cmake/libmorpheus.cmake index 4a8915f621..388337cadd 100644 --- a/morpheus/_lib/cmake/libmorpheus.cmake +++ b/morpheus/_lib/cmake/libmorpheus.cmake @@ -71,6 +71,7 @@ add_library(morpheus src/stages/file_source.cpp src/stages/filter_detection.cpp src/stages/http_server_source_stage.cpp + src/stages/inference_client_stage.cpp src/stages/kafka_source.cpp src/stages/preprocess_fil.cpp src/stages/preprocess_nlp.cpp diff --git a/morpheus/_lib/include/morpheus/stages/inference_client_stage.hpp b/morpheus/_lib/include/morpheus/stages/inference_client_stage.hpp new file mode 100644 index 0000000000..24d142184d --- /dev/null +++ b/morpheus/_lib/include/morpheus/stages/inference_client_stage.hpp @@ -0,0 +1,168 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "morpheus/export.h" +#include "morpheus/messages/multi_inference.hpp" +#include "morpheus/messages/multi_response.hpp" +#include "morpheus/types.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace morpheus { + +struct MORPHEUS_EXPORT TensorModelMapping +{ + /** + * @brief The field name to/from the model used for mapping + */ + std::string model_field_name; + + /** + * @brief The field name to/from the tensor used for mapping + */ + std::string tensor_field_name; +}; + +class MORPHEUS_EXPORT IInferenceClientSession +{ + public: + virtual ~IInferenceClientSession() = default; + /** + @brief Gets the inference input mappings + */ + virtual std::vector get_input_mappings(std::vector input_map_overrides) = 0; + + /** + @brief Gets the inference output mappings + */ + virtual std::vector get_output_mappings( + std::vector output_map_overrides) = 0; + + /** + @brief Invokes a single tensor inference + */ + virtual mrc::coroutines::Task infer(TensorMap&& inputs) = 0; +}; + +class MORPHEUS_EXPORT IInferenceClient +{ + public: + virtual ~IInferenceClient() = default; + /** + @brief Creates an inference session. + */ + virtual std::unique_ptr create_session() = 0; +}; + +/** + * @addtogroup stages + * @{ + * @file + */ + +/** + * @brief Perform inference with Triton Inference Server. + * This class specifies which inference implementation category (Ex: NLP/FIL) is needed for inferencing. + */ +class MORPHEUS_EXPORT InferenceClientStage + : public mrc::pymrc::AsyncioRunnable, std::shared_ptr> +{ + public: + using sink_type_t = std::shared_ptr; + using source_type_t = std::shared_ptr; + + /** + * @brief Construct a new Inference Client Stage object + * + * @param client : Inference client instance. + * @param model_name : Name of the model specifies which model can handle the inference requests that are sent to + * Triton inference + * @param needs_logits : Determines if logits are required. + * @param inout_mapping : Dictionary used to map pipeline input/output names to Triton input/output names. Use this + * if the Morpheus names do not match the model. + */ + InferenceClientStage(std::unique_ptr&& client, + std::string model_name, + bool needs_logits, + std::vector input_mapping, + std::vector output_mapping); + + /** + * Process a single MultiInferenceMessage by running the constructor-provided inference client against it's Tensor, + * and yields the result as a MultiResponseMessage + */ + mrc::coroutines::AsyncGenerator> on_data( + std::shared_ptr&& data, std::shared_ptr on) override; + + private: + std::string m_model_name; + std::shared_ptr m_client; + std::shared_ptr m_session; + bool m_needs_logits{true}; + std::vector m_input_mapping; + std::vector m_output_mapping; + std::mutex m_session_mutex; + + int32_t m_retry_max = 10; +}; + +/****** InferenceClientStageInferenceProxy******************/ +/** + * @brief Interface proxy, used to insulate python bindings. + */ +struct MORPHEUS_EXPORT InferenceClientStageInterfaceProxy +{ + /** + * @brief Create and initialize a InferenceClientStage, and return the result + * + * @param builder : Pipeline context object reference + * @param name : Name of a stage reference + * @param model_name : Name of the model specifies which model can handle the inference requests that are sent to + * Triton inference + * @param server_url : Triton server URL. + * @param needs_logits : Determines if logits are required. 
+ * @param inout_mapping : Dictionary used to map pipeline input/output names to Triton input/output names. Use this + * if the Morpheus names do not match the model. + * @return std::shared_ptr> + */ + static std::shared_ptr> init( + mrc::segment::Builder& builder, + const std::string& name, + std::string model_name, + std::string server_url, + bool needs_logits, + std::map input_mapping, + std::map output_mapping); +}; +/** @} */ // end of group + +} // namespace morpheus diff --git a/morpheus/_lib/include/morpheus/stages/triton_inference.hpp b/morpheus/_lib/include/morpheus/stages/triton_inference.hpp index 98c78a0910..923a75e2b7 100644 --- a/morpheus/_lib/include/morpheus/stages/triton_inference.hpp +++ b/morpheus/_lib/include/morpheus/stages/triton_inference.hpp @@ -17,141 +17,175 @@ #pragma once -#include "morpheus/messages/multi_inference.hpp" -#include "morpheus/messages/multi_response.hpp" // for MultiResponseMessage +#include "morpheus/export.h" #include "morpheus/objects/triton_in_out.hpp" +#include "morpheus/stages/inference_client_stage.hpp" #include "morpheus/types.hpp" -#include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include // for apply, make_subscriber, observable_member, is_on_error<>::not_void, is_on_next_of<>::not_void, from +#include + +#include // IWYU pragma: no_include "rxcpp/sources/rx-iterate.hpp" -#include #include #include -#include #include namespace morpheus { /****** Component public implementations *******************/ /****** InferenceClientStage********************************/ -/** - * @addtogroup stages - * @{ - * @file - */ +struct MORPHEUS_EXPORT TritonInferInput +{ + /** + * @brief The name of the triton inference input + */ + std::string name; -#pragma GCC visibility push(default) -/** - * @brief Perform inference with Triton Inference Server. - * This class specifies which inference implementation category (Ex: NLP/FIL) is needed for inferencing. - */ -class InferenceClientStage - : public mrc::pymrc::PythonNode, std::shared_ptr> + /** + * @brief The shape of the triton inference input + */ + std::vector shape; + + /** + * @brief The type of the triton inference input + */ + std::string type; + + /** + * @brief The triton inference input data + */ + std::vector data; +}; + +struct MORPHEUS_EXPORT TritonInferRequestedOutput +{ + std::string name; +}; + +class MORPHEUS_EXPORT ITritonClient { public: - using base_t = - mrc::pymrc::PythonNode, std::shared_ptr>; - using typename base_t::sink_type_t; - using typename base_t::source_type_t; - using typename base_t::subscribe_fn_t; - - /** - * @brief Construct a new Inference Client Stage object - * - * @param model_name : Name of the model specifies which model can handle the inference requests that are sent to - * Triton inference - * @param server_url : Triton server URL. - * @param force_convert_inputs : Instructs the stage to convert the incoming data to the same format that Triton is - * expecting. If set to False, data will only be converted if it would not result in the loss of data. - * @param use_shared_memory : Whether or not to use CUDA Shared IPC Memory for transferring data to Triton. Using - * CUDA IPC reduces network transfer time but requires that Morpheus and Triton are located on the same machine. - * @param needs_logits : Determines if logits are required. - * @param inout_mapping : Dictionary used to map pipeline input/output names to Triton input/output names. Use this - * if the Morpheus names do not match the model. 
- */ - InferenceClientStage(std::string model_name, - std::string server_url, - bool force_convert_inputs, - bool use_shared_memory, - bool needs_logits, - std::map inout_mapping = {}); + virtual ~ITritonClient() = default; + /** + * @brief Checks if Triton Server is live + */ + virtual triton::client::Error is_server_live(bool* live) = 0; + + /** + * @brief Checks if Triton Server is ready + */ + virtual triton::client::Error is_server_ready(bool* ready) = 0; + + /** + * @brief Checks if the given model is ready + */ + virtual triton::client::Error is_model_ready(bool* ready, std::string& model_name) = 0; + + /** + * @brief Gets metadata for the given model + */ + virtual triton::client::Error model_metadata(std::string* model_metadata, std::string& model_name) = 0; + + /** + * @brief Gets the config for the given model + */ + virtual triton::client::Error model_config(std::string* model_config, std::string& model_name) = 0; + + /** + * @brief Runs Triton Server inference given the model options, inputs, and outputs + */ + virtual triton::client::Error async_infer(triton::client::InferenceServerHttpClient::OnCompleteFn callback, + const triton::client::InferOptions& options, + const std::vector& inputs, + const std::vector& outputs) = 0; +}; + +class MORPHEUS_EXPORT HttpTritonClient : public ITritonClient +{ private: + std::unique_ptr m_client; + + public: + HttpTritonClient(std::string server_url); + /** - * TODO(Documentation) + * @brief Checks if Triton Server is live using HTTP protocal */ - bool is_default_grpc_port(std::string& server_url); + triton::client::Error is_server_live(bool* live) override; /** - * TODO(Documentation) + * @brief Checks if Triton Server is ready using HTTP protocal */ - void connect_with_server(); + triton::client::Error is_server_ready(bool* ready) override; /** - * TODO(Documentation) + * @brief Checks if the given model is ready using HTTP protocal */ - subscribe_fn_t build_operator(); + triton::client::Error is_model_ready(bool* ready, std::string& model_name) override; + + /** + * @brief Gets the config for the given model using HTTP protocal + */ + triton::client::Error model_config(std::string* model_config, std::string& model_name) override; + + /** + * @brief Gets metadata for the given model using HTTP protocal + */ + triton::client::Error model_metadata(std::string* model_metadata, std::string& model_name) override; + + /** + * @brief Runs Triton Server inference given the model options, inputs, and outputs, using HTTP protocal + */ + triton::client::Error async_infer(triton::client::InferenceServerHttpClient::OnCompleteFn callback, + const triton::client::InferOptions& options, + const std::vector& inputs, + const std::vector& outputs) override; +}; +class MORPHEUS_EXPORT TritonInferenceClientSession : public IInferenceClientSession +{ + private: std::string m_model_name; - std::string m_server_url; - bool m_force_convert_inputs; - bool m_use_shared_memory; - bool m_needs_logits{true}; - std::map m_inout_mapping; - - // Below are settings created during handshake with server - // std::shared_ptr m_client; + TensorIndex m_max_batch_size = -1; std::vector m_model_inputs; std::vector m_model_outputs; - triton::client::InferOptions m_options; - TensorIndex m_max_batch_size{-1}; + std::shared_ptr m_client; + + public: + TritonInferenceClientSession(std::shared_ptr client, std::string model_name); + + /** + @brief Gets the inference input mappings for Triton + */ + std::vector get_input_mappings(std::vector input_map_overrides) override; + + /** + 
@brief Gets the inference output mappings for Triton + */ + std::vector get_output_mappings(std::vector output_map_overrides) override; + + /** + @brief Invokes a single tensor inference using the constructor-provided ITritonClient + */ + mrc::coroutines::Task infer(TensorMap&& inputs) override; }; -/****** InferenceClientStageInferenceProxy******************/ -/** - * @brief Interface proxy, used to insulate python bindings. - */ -struct InferenceClientStageInterfaceProxy +class MORPHEUS_EXPORT TritonInferenceClient : public IInferenceClient { + private: + std::shared_ptr m_client; + std::string m_model_name; + + public: + TritonInferenceClient(std::unique_ptr&& client, std::string model_name); + /** - * @brief Create and initialize a InferenceClientStage, and return the result - * - * @param builder : Pipeline context object reference - * @param name : Name of a stage reference - * @param model_name : Name of the model specifies which model can handle the inference requests that are sent to - * Triton inference - * @param server_url : Triton server URL. - * @param force_convert_inputs : Instructs the stage to convert the incoming data to the same format that Triton is - * expecting. If set to False, data will only be converted if it would not result in the loss of data. - * @param use_shared_memory : Whether or not to use CUDA Shared IPC Memory for transferring data to Triton. Using - * CUDA IPC reduces network transfer time but requires that Morpheus and Triton are located on the same machine. - * @param needs_logits : Determines if logits are required. - * @param inout_mapping : Dictionary used to map pipeline input/output names to Triton input/output names. Use this - * if the Morpheus names do not match the model. - * @return std::shared_ptr> - */ - static std::shared_ptr> init( - mrc::segment::Builder& builder, - const std::string& name, - std::string model_name, - std::string server_url, - bool force_convert_inputs, - bool use_shared_memory, - bool needs_logits, - std::map inout_mapping); + @brief Creates a TritonInferenceClientSession + */ + std::unique_ptr create_session() override; }; -#pragma GCC visibility pop -/** @} */ // end of group + } // namespace morpheus diff --git a/morpheus/_lib/llm/include/py_llm_engine.hpp b/morpheus/_lib/llm/include/py_llm_engine.hpp index 143990d45c..dcf3f18a67 100644 --- a/morpheus/_lib/llm/include/py_llm_engine.hpp +++ b/morpheus/_lib/llm/include/py_llm_engine.hpp @@ -19,9 +19,9 @@ #include "py_llm_node.hpp" +#include "morpheus/llm/fwd.hpp" #include "morpheus/llm/input_map.hpp" #include "morpheus/llm/llm_engine.hpp" -#include "morpheus/llm/llm_task_handler.hpp" #include diff --git a/morpheus/_lib/llm/src/py_llm_engine.cpp b/morpheus/_lib/llm/src/py_llm_engine.cpp index ff82fdec86..5e6267456d 100644 --- a/morpheus/_lib/llm/src/py_llm_engine.cpp +++ b/morpheus/_lib/llm/src/py_llm_engine.cpp @@ -19,6 +19,8 @@ #include "py_llm_task_handler.hpp" +#include "morpheus/llm/llm_task_handler.hpp" + #include #include diff --git a/morpheus/_lib/messages/module.cpp b/morpheus/_lib/messages/module.cpp index b5b84ee071..453d691082 100644 --- a/morpheus/_lib/messages/module.cpp +++ b/morpheus/_lib/messages/module.cpp @@ -101,6 +101,14 @@ PYBIND11_MODULE(messages, _module) mrc::edge::EdgeConnector, mrc::pymrc::PyObjectHolder>::register_converter(); mrc::edge::EdgeConnector>::register_converter(); + mrc::edge::EdgeConnector, mrc::pymrc::PyObjectHolder>::register_converter(); + mrc::edge::EdgeConnector>::register_converter(); + + mrc::edge::EdgeConnector, + 
mrc::pymrc::PyObjectHolder>::register_converter(); + mrc::edge::EdgeConnector>::register_converter(); + // EdgeConnectors for derived classes of MultiMessage to MultiMessage mrc::edge::EdgeConnector, std::shared_ptr>::register_converter(); diff --git a/morpheus/_lib/src/stages/inference_client_stage.cpp b/morpheus/_lib/src/stages/inference_client_stage.cpp new file mode 100644 index 0000000000..069ccd557e --- /dev/null +++ b/morpheus/_lib/src/stages/inference_client_stage.cpp @@ -0,0 +1,293 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "morpheus/stages/inference_client_stage.hpp" + +#include "morpheus/messages/memory/response_memory.hpp" +#include "morpheus/messages/memory/tensor_memory.hpp" +#include "morpheus/objects/dev_mem_info.hpp" +#include "morpheus/objects/dtype.hpp" +#include "morpheus/objects/tensor.hpp" +#include "morpheus/objects/tensor_object.hpp" +#include "morpheus/stages/triton_inference.hpp" +#include "morpheus/utilities/matx_util.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +static morpheus::ShapeType get_seq_ids(const morpheus::InferenceClientStage::sink_type_t& message) +{ + // Take a copy of the sequence Ids allowing us to map rows in the response to rows in the dataframe + // The output tensors we store in `reponse_memory` will all be of the same length as the the + // dataframe. seq_ids has three columns, but we are only interested in the first column. + auto seq_ids = message->get_input("seq_ids"); + const auto item_size = seq_ids.dtype().item_size(); + + morpheus::ShapeType host_seq_ids(message->count); + MRC_CHECK_CUDA(cudaMemcpy2D(host_seq_ids.data(), + item_size, + seq_ids.data(), + seq_ids.stride(0) * item_size, + item_size, + host_seq_ids.size(), + cudaMemcpyDeviceToHost)); + + return host_seq_ids; +} + +static void reduce_outputs(const morpheus::InferenceClientStage::sink_type_t& x, morpheus::TensorMap& output_tensors) +{ + // When our tensor lengths are longer than our dataframe we will need to use the seq_ids array to + // lookup how the values should map back into the dataframe. 
+ auto host_seq_ids = get_seq_ids(x); + + for (auto& mapping : output_tensors) + { + auto& output_tensor = mapping.second; + + morpheus::ShapeType shape = output_tensor.get_shape(); + morpheus::ShapeType stride = output_tensor.get_stride(); + + morpheus::ShapeType reduced_shape{shape}; + reduced_shape[0] = x->mess_count; + + auto reduced_buffer = morpheus::MatxUtil::reduce_max( + morpheus::DevMemInfo{ + output_tensor.data(), output_tensor.dtype(), output_tensor.get_memory(), shape, stride}, + host_seq_ids, + 0, + reduced_shape); + + output_tensor.swap( + morpheus::Tensor::create(std::move(reduced_buffer), output_tensor.dtype(), reduced_shape, stride, 0)); + } +} + +static void apply_logits(morpheus::TensorMap& output_tensors) +{ + for (auto& mapping : output_tensors) + { + auto& output_tensor = mapping.second; + + auto shape = output_tensor.get_shape(); + auto stride = output_tensor.get_stride(); + + auto output_buffer = morpheus::MatxUtil::logits(morpheus::DevMemInfo{ + output_tensor.data(), output_tensor.dtype(), output_tensor.get_memory(), shape, stride}); + + // For logits the input and output shapes will be the same + output_tensor.swap(morpheus::Tensor::create(std::move(output_buffer), output_tensor.dtype(), shape, stride, 0)); + } +} + +} // namespace + +namespace morpheus { + +InferenceClientStage::InferenceClientStage(std::unique_ptr&& client, + std::string model_name, + bool needs_logits, + std::vector input_mapping, + std::vector output_mapping) : + m_model_name(std::move(model_name)), + m_client(std::move(client)), + m_needs_logits(needs_logits), + m_input_mapping(std::move(input_mapping)), + m_output_mapping(std::move(output_mapping)) +{} + +struct ExponentialBackoff +{ + std::shared_ptr m_on; + std::chrono::milliseconds m_delay; + std::chrono::milliseconds m_delay_max; + + ExponentialBackoff(std::shared_ptr on, + std::chrono::milliseconds delay_initial, + std::chrono::milliseconds delay_max) : + m_on(std::move(on)), + m_delay(delay_initial), + m_delay_max(delay_max) + {} + + mrc::coroutines::Task<> yield() + { + if (m_delay > m_delay_max) + { + m_delay = m_delay_max; + } + + co_await m_on->yield_for(m_delay); + + m_delay *= 2; + } +}; + +mrc::coroutines::AsyncGenerator> InferenceClientStage::on_data( + std::shared_ptr&& x, std::shared_ptr on) +{ + int32_t retry_count = 0; + + using namespace std::chrono_literals; + + auto backoff = ExponentialBackoff(on, 100ms, 4000ms); + + while (true) + { + auto message_session = m_session; + + try + { + // Using the `count` which is the number of rows in the inference tensors. We will check later if this + // doesn't match the number of rows in the dataframe (`mess_count`). This happens when the size of the + // input is too large and needs to be broken up in chunks in the pre-process stage. When this is the + // case we will reduce the rows in the response outputs such that we have a single response for each + // row int he dataframe. + // TensorMap output_tensors; + // buffer_map_t output_buffers; + + if (message_session == nullptr) + { + auto lock = std::unique_lock(m_session_mutex); + + if (m_session == nullptr) + { + m_session = m_client->create_session(); + } + + message_session = m_session; + } + + // We want to prevent entering this section of code if the session is being reset, but we also want this + // section of code to be entered simultanously by multiple coroutines. To accomplish this, we use a shared + // lock instead of a unique lock. 
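The retry path above paces itself with the small ExponentialBackoff helper: the delay starts at the value passed to the constructor (100 ms in on_data), every yield doubles it, and it is clamped to the configured maximum (4000 ms). Below is a minimal, blocking sketch of that policy; the real helper suspends on the AsyncioRunnable scheduler via yield_for instead of sleeping a thread, and BlockingBackoff is purely an illustrative name.

#include <algorithm>
#include <chrono>
#include <iostream>
#include <thread>

// Illustrative stand-in for the coroutine-based ExponentialBackoff above:
// clamp the delay to the maximum, wait, then double the delay for next time.
class BlockingBackoff
{
  public:
    BlockingBackoff(std::chrono::milliseconds initial, std::chrono::milliseconds max) :
      m_delay(initial),
      m_delay_max(max)
    {}

    void wait()
    {
        m_delay = std::min(m_delay, m_delay_max);
        std::this_thread::sleep_for(m_delay);
        m_delay *= 2;
    }

  private:
    std::chrono::milliseconds m_delay;
    std::chrono::milliseconds m_delay_max;
};

int main()
{
    using namespace std::chrono_literals;

    BlockingBackoff backoff(100ms, 4000ms);

    // Successive waits: 100 ms, 200 ms, 400 ms, ... capped at 4000 ms.
    for (int attempt = 0; attempt < 8; ++attempt)
    {
        std::cout << "attempt " << attempt << " failed, backing off\n";
        backoff.wait();
    }
}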
+ + TensorMap model_input_tensors; + + for (auto mapping : message_session->get_input_mappings(m_input_mapping)) + { + if (x->memory->has_tensor(mapping.tensor_field_name)) + { + model_input_tensors[mapping.model_field_name].swap(x->get_input(mapping.tensor_field_name)); + } + } + + auto model_output_tensors = co_await message_session->infer(std::move(model_input_tensors)); + + co_await on->yield(); + + if (x->mess_count != x->count) + { + reduce_outputs(x, model_output_tensors); + } + + // If we need to do logits, do that here + if (m_needs_logits) + { + apply_logits(model_output_tensors); + } + + TensorMap output_tensor_map; + + for (auto mapping : message_session->get_output_mappings(m_output_mapping)) + { + auto pos = model_output_tensors.find(mapping.model_field_name); + + if (pos != model_output_tensors.end()) + { + output_tensor_map[mapping.tensor_field_name].swap( + std::move(model_output_tensors[mapping.model_field_name])); + + model_output_tensors.erase(pos); + } + } + + // Final output of all mini-batches + auto response_mem = std::make_shared(x->mess_count, std::move(output_tensor_map)); + + auto response = std::make_shared( + x->meta, x->mess_offset, x->mess_count, std::move(response_mem), 0, response_mem->count); + + co_yield std::move(response); + + co_return; + + } catch (...) + { + auto lock = std::unique_lock(m_session_mutex); + + if (m_session == message_session) + { + m_session.reset(); + } + + if (m_retry_max >= 0 and ++retry_count > m_retry_max) + { + throw; + } + + LOG(WARNING) << "Exception while processing message for InferenceClientStage, attempting retry."; + } + + co_await backoff.yield(); + } +} + +// ************ InferenceClientStageInterfaceProxy********* // +std::shared_ptr> InferenceClientStageInterfaceProxy::init( + mrc::segment::Builder& builder, + const std::string& name, + std::string server_url, + std::string model_name, + bool needs_logits, + std::map input_mappings, + std::map output_mappings) +{ + std::vector input_mappings_{}; + std::vector output_mappings_{}; + + for (auto& mapping : input_mappings) + { + input_mappings_.emplace_back(TensorModelMapping{mapping.first, mapping.second}); + } + + for (auto& mapping : output_mappings) + { + output_mappings_.emplace_back(TensorModelMapping{mapping.first, mapping.second}); + } + + auto triton_client = std::make_unique(server_url); + auto triton_inference_client = std::make_unique(std::move(triton_client), model_name); + auto stage = builder.construct_object( + name, std::move(triton_inference_client), model_name, needs_logits, input_mappings_, output_mappings_); + + return stage; +} + +} // namespace morpheus diff --git a/morpheus/_lib/src/stages/triton_inference.cpp b/morpheus/_lib/src/stages/triton_inference.cpp index 8b8b807ca7..6464c3be5d 100644 --- a/morpheus/_lib/src/stages/triton_inference.cpp +++ b/morpheus/_lib/src/stages/triton_inference.cpp @@ -17,45 +17,33 @@ #include "morpheus/stages/triton_inference.hpp" -#include "mrc/node/rx_sink_base.hpp" -#include "mrc/node/rx_source_base.hpp" -#include "mrc/node/sink_properties.hpp" -#include "mrc/node/source_properties.hpp" -#include "mrc/segment/builder.hpp" -#include "mrc/segment/object.hpp" -#include "mrc/types.hpp" - -#include "morpheus/messages/memory/response_memory.hpp" -#include "morpheus/messages/memory/tensor_memory.hpp" // for TensorMemory -#include "morpheus/objects/dev_mem_info.hpp" // for DevMemInfo -#include "morpheus/objects/dtype.hpp" // for DType -#include "morpheus/objects/tensor.hpp" // for Tensor::create -#include 
"morpheus/objects/tensor_object.hpp" // for TensorObject -#include "morpheus/objects/triton_in_out.hpp" // for TritonInOut -#include "morpheus/types.hpp" // for TensorIndex, TensorMap -#include "morpheus/utilities/matx_util.hpp" // for MatxUtil::logits, MatxUtil::reduce_max -#include "morpheus/utilities/stage_util.hpp" // for foreach_map -#include "morpheus/utilities/string_util.hpp" // for MORPHEUS_CONCAT_STR -#include "morpheus/utilities/tensor_util.hpp" // for get_elem_count +#include "morpheus/objects/dtype.hpp" // for DType +#include "morpheus/objects/tensor.hpp" // for Tensor::create +#include "morpheus/objects/tensor_object.hpp" // for TensorObject +#include "morpheus/objects/triton_in_out.hpp" // for TritonInOut +#include "morpheus/types.hpp" // for TensorIndex, TensorMap +#include "morpheus/utilities/string_util.hpp" // for MORPHEUS_CONCAT_STR +#include "morpheus/utilities/tensor_util.hpp" // for get_elem_count #include // for cudaMemcpy, cudaMemcpy2D, cudaMemcpyDeviceToHost, cudaMemcpyHostToDevice #include #include #include // for MRC_CHECK_CUDA #include -#include #include // for cuda_stream_per_thread #include // for device_buffer #include // for min +#include #include -#include -#include #include +#include #include #include #include // for runtime_error, out_of_range +#include #include +#include // IWYU pragma: no_include /** @@ -64,17 +52,14 @@ * @file */ +namespace { + /** * @brief Checks the status object returned by a Triton client call logging any potential errors. * */ #define CHECK_TRITON(method) ::InferenceClientStage__check_triton_errors(method, #method, __FILE__, __LINE__); -namespace { - -using namespace morpheus; -using buffer_map_t = std::map>; - // Component-private free functions. void InferenceClientStage__check_triton_errors(triton::client::Error status, const std::string& methodName, @@ -91,298 +76,89 @@ void InferenceClientStage__check_triton_errors(triton::client::Error status, } } -void build_output_tensors(TensorIndex count, - const std::vector& model_outputs, - buffer_map_t& output_buffers, - TensorMap& output_tensors) -{ - // Create the output memory blocks - for (auto& model_output : model_outputs) - { - ShapeType total_shape = model_output.shape; - - // First dimension will always end up being the number of rows in the dataframe - total_shape[0] = count; - auto elem_count = TensorUtils::get_elem_count(total_shape); - - // Create the output memory - auto output_buffer = std::make_shared(elem_count * model_output.datatype.item_size(), - rmm::cuda_stream_per_thread); - - output_buffers[model_output.mapped_name] = output_buffer; - - // Triton results are always in row-major as required by the KServe protocol - // https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/required_api.md#tensor-data - ShapeType stride{total_shape[1], 1}; - output_tensors[model_output.mapped_name].swap( - Tensor::create(std::move(output_buffer), model_output.datatype, total_shape, stride, 0)); - } -} - -ShapeType get_seq_ids(const InferenceClientStage::sink_type_t& message) -{ - // Take a copy of the sequence Ids allowing us to map rows in the response to rows in the dataframe - // The output tensors we store in `reponse_memory` will all be of the same length as the the - // dataframe. seq_ids has three columns, but we are only interested in the first column. 
- auto seq_ids = message->get_input("seq_ids"); - const auto item_size = seq_ids.dtype().item_size(); - - ShapeType host_seq_ids(message->count); - MRC_CHECK_CUDA(cudaMemcpy2D(host_seq_ids.data(), - item_size, - seq_ids.data(), - seq_ids.stride(0) * item_size, - item_size, - host_seq_ids.size(), - cudaMemcpyDeviceToHost)); - - return host_seq_ids; -} +using namespace morpheus; +using buffer_map_t = std::map>; -std::pair, std::vector> build_input( - const InferenceClientStage::sink_type_t& msg_slice, const TritonInOut& model_input) +static bool is_default_grpc_port(std::string& server_url) { - DCHECK(msg_slice->memory->has_tensor(model_input.mapped_name)) - << "Model input '" << model_input.mapped_name << "' not found in InferenceMemory"; - - auto const& inp_tensor = msg_slice->get_input(model_input.mapped_name); - - // Convert to the right type. Make shallow if necessary - auto final_tensor = inp_tensor.as_type(model_input.datatype); - - std::vector inp_data = final_tensor.get_host_data(); - - // Test - triton::client::InferInput* inp_ptr; - - triton::client::InferInput::Create( - &inp_ptr, model_input.name, {inp_tensor.shape(0), inp_tensor.shape(1)}, model_input.datatype.triton_str()); - - std::shared_ptr inp_shared; - inp_shared.reset(inp_ptr); - - inp_ptr->AppendRaw(inp_data); + // Check if we are the default gRPC port of 8001 and try 8000 for http client instead + size_t colon_loc = server_url.find_last_of(':'); - return std::make_pair(inp_shared, std::move(inp_data)); -} + if (colon_loc == -1) + { + return false; + } -std::shared_ptr build_output(const TritonInOut& model_output) -{ - triton::client::InferRequestedOutput* out_ptr; + // Check if the port matches 8001 + if (server_url.size() < colon_loc + 1 || server_url.substr(colon_loc + 1) != "8001") + { + return false; + } - triton::client::InferRequestedOutput::Create(&out_ptr, model_output.name); - std::shared_ptr out_shared; - out_shared.reset(out_ptr); + // It matches, change to 8000 + server_url = server_url.substr(0, colon_loc) + ":8000"; - return out_shared; + return true; } -void reduce_outputs(const InferenceClientStage::sink_type_t& x, buffer_map_t& output_buffers, TensorMap& output_tensors) +struct TritonInferOperation { - // When our tensor lengths are longer than our dataframe we will need to use the seq_ids array to - // lookup how the values should map back into the dataframe. 
- auto host_seq_ids = get_seq_ids(x); - - TensorMap reduced_outputs; - - for (const auto& output : output_tensors) + bool await_ready() const noexcept { - auto& tensor = output.second; - - ShapeType shape = tensor.get_shape(); - ShapeType stride = tensor.get_stride(); - - ShapeType reduced_shape{shape}; - reduced_shape[0] = x->mess_count; - - auto& buffer = output_buffers[output.first]; - auto reduced_buffer = - MatxUtil::reduce_max(DevMemInfo{buffer, tensor.dtype(), shape, stride}, host_seq_ids, 0, reduced_shape); - - output_buffers[output.first] = reduced_buffer; - - reduced_outputs[output.first].swap( - Tensor::create(std::move(reduced_buffer), tensor.dtype(), reduced_shape, stride, 0)); + return false; } - output_tensors = std::move(reduced_outputs); -} - -void apply_logits(buffer_map_t& output_buffers, TensorMap& output_tensors) -{ - TensorMap logit_outputs; - - for (const auto& output : output_tensors) + void await_suspend(std::coroutine_handle<> handle) { - auto& input_tensor = output.second; - - auto shape = input_tensor.get_shape(); - auto stride = input_tensor.get_stride(); - - auto& buffer = output_buffers[output.first]; - - auto output_buffer = MatxUtil::logits(DevMemInfo{buffer, input_tensor.dtype(), shape, stride}); - - output_buffers[output.first] = output_buffer; + CHECK_TRITON(m_client.async_infer( + [this, handle](triton::client::InferResult* result) { + m_result.reset(result); + handle(); + }, + m_options, + m_inputs, + m_outputs)); + } - // For logits the input and output shapes will be the same - logit_outputs[output.first].swap( - Tensor::create(std::move(output_buffer), input_tensor.dtype(), shape, stride, 0)); + std::unique_ptr await_resume() + { + return std::move(m_result); } - output_tensors = std::move(logit_outputs); -} + ITritonClient& m_client; + triton::client::InferOptions const& m_options; + std::vector const& m_inputs; + std::vector const& m_outputs; + std::unique_ptr m_result; +}; } // namespace namespace morpheus { -// Component public implementations -// ************ InferenceClientStage ************************* // -InferenceClientStage::InferenceClientStage(std::string model_name, - std::string server_url, - bool force_convert_inputs, - bool use_shared_memory, - bool needs_logits, - std::map inout_mapping) : - PythonNode(base_t::op_factory_from_sub_fn(build_operator())), - m_model_name(std::move(model_name)), - m_server_url(std::move(server_url)), - m_force_convert_inputs(force_convert_inputs), - m_use_shared_memory(use_shared_memory), - m_needs_logits(needs_logits), - m_inout_mapping(std::move(inout_mapping)), - m_options(m_model_name) -{ - // Connect with the server to setup the inputs/outputs - this->connect_with_server(); // TODO(Devin) -} - -InferenceClientStage::subscribe_fn_t InferenceClientStage::build_operator() -{ - return [this](rxcpp::observable input, rxcpp::subscriber output) { - std::unique_ptr client; - - CHECK_TRITON(triton::client::InferenceServerHttpClient::Create(&client, m_server_url, false)); - - return input.subscribe(rxcpp::make_observer( - [this, &output, &client](sink_type_t x) { - // Using the `count` which is the number of rows in the inference tensors. We will check later if this - // doesn't match the number of rows in the dataframe (`mess_count`). This happens when the size of the - // input is too large and needs to be broken up in chunks in the pre-process stage. When this is the - // case we will reduce the rows in the response outputs such that we have a single response for each - // row int he dataframe. 
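The TritonInferOperation awaiter introduced above is what lets the session co_await a callback-based client: await_suspend issues async_infer and hands the completion callback the coroutine handle to resume, and await_resume returns the captured result. The sketch below shows the same callback-to-co_await bridge using only standard C++20; FakeClient, SquareOperation, and DetachedTask are hypothetical stand-ins, not Morpheus or Triton types (the real code awaits from an mrc::coroutines::Task).

#include <coroutine>
#include <exception>
#include <functional>
#include <iostream>
#include <utility>

// Hypothetical callback-style client standing in for ITritonClient::async_infer: the
// request is recorded here and completed later by whatever drives the event loop.
struct FakeClient
{
    std::function<void()> pending;

    void async_square(int value, std::function<void(int)> on_complete)
    {
        pending = [value, cb = std::move(on_complete)]() { cb(value * value); };
    }

    void complete_pending()
    {
        auto cb = std::move(pending);
        pending = nullptr;
        cb();
    }
};

// Awaiter bridging the callback API into co_await, mirroring TritonInferOperation:
// await_suspend issues the request and lets the completion callback resume the
// coroutine; await_resume hands the captured result back to the co_await expression.
struct SquareOperation
{
    FakeClient& client;
    int value;
    int result = 0;

    bool await_ready() const noexcept { return false; }

    void await_suspend(std::coroutine_handle<> handle)
    {
        client.async_square(value, [this, handle](int r) {
            result = r;
            handle.resume();
        });
    }

    int await_resume() const noexcept { return result; }
};

// Minimal fire-and-forget coroutine type so the example is self-contained.
struct DetachedTask
{
    struct promise_type
    {
        DetachedTask get_return_object() { return {}; }
        std::suspend_never initial_suspend() noexcept { return {}; }
        std::suspend_never final_suspend() noexcept { return {}; }
        void return_void() {}
        void unhandled_exception() { std::terminate(); }
    };
};

DetachedTask run(FakeClient& client)
{
    int squared = co_await SquareOperation{client, 7};
    std::cout << "7 squared is " << squared << "\n";
}

int main()
{
    FakeClient client;

    run(client);                // suspends at co_await after issuing the request
    client.complete_pending();  // completion callback resumes the coroutine
}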
- TensorMap output_tensors; - buffer_map_t output_buffers; - build_output_tensors(x->count, m_model_outputs, output_buffers, output_tensors); - - for (TensorIndex start = 0; start < x->count; start += m_max_batch_size) - { - triton::client::InferInput* input1; - - TensorIndex stop = std::min(start + m_max_batch_size, x->count); - - sink_type_t mini_batch_input = x->get_slice(start, stop); - - // Iterate on the model inputs in case the model takes less than what tensors are available - std::vector, std::vector>> - saved_inputs = foreach_map(m_model_inputs, [&mini_batch_input](auto const& model_input) { - return (build_input(mini_batch_input, model_input)); - }); - - std::vector> saved_outputs = - foreach_map(m_model_outputs, [](auto const& model_output) { - // Generate the outputs to be requested. - return build_output(model_output); - }); - - std::vector inputs = - foreach_map(saved_inputs, [](auto x) { return x.first.get(); }); - - std::vector outputs = - foreach_map(saved_outputs, [](auto x) { return x.get(); }); - - auto results = std::unique_ptr([&]() { - triton::client::InferResult* results; - CHECK_TRITON(client->Infer(&results, m_options, inputs, outputs)); - return results; - }()); - - for (auto& model_output : m_model_outputs) - { - std::vector output_shape; - - CHECK_TRITON(results->Shape(model_output.name, &output_shape)); - - // Make sure we have at least 2 dims - while (output_shape.size() < 2) - { - output_shape.push_back(1); - } - - const uint8_t* output_ptr = nullptr; - size_t output_ptr_size = 0; - CHECK_TRITON(results->RawData(model_output.name, &output_ptr, &output_ptr_size)); - - auto output_tensor = output_tensors[model_output.mapped_name].slice({start, 0}, {stop, -1}); - - DCHECK_EQ(stop - start, output_shape[0]); - DCHECK_EQ(output_tensor.bytes(), output_ptr_size); - DCHECK_NOTNULL(output_ptr); - DCHECK_NOTNULL(output_tensor.data()); - - MRC_CHECK_CUDA( - cudaMemcpy(output_tensor.data(), output_ptr, output_ptr_size, cudaMemcpyHostToDevice)); - } - } - - if (x->mess_count != x->count) - { - reduce_outputs(x, output_buffers, output_tensors); - } - - // If we need to do logits, do that here - if (m_needs_logits) - { - apply_logits(output_buffers, output_tensors); - } - - // Final output of all mini-batches - auto response_mem = std::make_shared(x->mess_count, std::move(output_tensors)); - auto response = std::make_shared( - x->meta, x->mess_offset, x->mess_count, std::move(response_mem), 0, response_mem->count); - - output.on_next(std::move(response)); - }, - [&](std::exception_ptr error_ptr) { output.on_error(error_ptr); }, - [&]() { output.on_completed(); })); - }; -} -void InferenceClientStage::connect_with_server() +HttpTritonClient::HttpTritonClient(std::string server_url) { - std::string server_url = m_server_url; - std::unique_ptr client; - auto result = triton::client::InferenceServerHttpClient::Create(&client, server_url, false); + CHECK_TRITON(triton::client::InferenceServerHttpClient::Create(&client, server_url, false)); - // Now load the input/outputs for the model - bool is_server_live = false; + bool is_server_live; - triton::client::Error status = client->IsServerLive(&is_server_live); + auto status = client->IsServerLive(&is_server_live); - if (!status.IsOk()) + if (not status.IsOk()) { - if (this->is_default_grpc_port(server_url)) + std::string new_server_url = server_url; + if (is_default_grpc_port(new_server_url)) { - LOG(WARNING) << "Failed to connect to Triton at '" << m_server_url + LOG(WARNING) << "Failed to connect to Triton at '" << server_url 
<< "'. Default gRPC port of (8001) was detected but C++ " "InferenceClientStage uses HTTP protocol. Retrying with default HTTP port (8000)"; // We are using the default gRPC port, try the default HTTP std::unique_ptr unique_client; - auto result = triton::client::InferenceServerHttpClient::Create(&unique_client, server_url, false); + CHECK_TRITON(triton::client::InferenceServerHttpClient::Create(&unique_client, new_server_url, false)); client = std::move(unique_client); @@ -392,44 +168,130 @@ void InferenceClientStage::connect_with_server() { throw std::runtime_error(MORPHEUS_CONCAT_STR( "Failed to connect to Triton at '" - << m_server_url + << server_url << "'. Received 'Unsupported Protocol' error. Are you using the right port? The C++ " "InferenceClientStage uses Triton's HTTP protocol instead of gRPC. Ensure you have " "specified the HTTP port (Default 8000).")); } - if (!status.IsOk()) + if (not status.IsOk()) throw std::runtime_error( MORPHEUS_CONCAT_STR("Unable to connect to Triton at '" - << m_server_url << "'. Check the URL and port and ensure the server is running.")); + << server_url << "'. Check the URL and port and ensure the server is running.")); + } + + m_client = std::move(client); +} + +triton::client::Error HttpTritonClient::is_server_live(bool* live) +{ + return m_client->IsServerLive(live); +} + +triton::client::Error HttpTritonClient::is_server_ready(bool* ready) +{ + return m_client->IsServerReady(ready); +} + +triton::client::Error HttpTritonClient::is_model_ready(bool* ready, std::string& model_name) +{ + return m_client->IsModelReady(ready, model_name); +} + +triton::client::Error HttpTritonClient::model_config(std::string* model_config, std::string& model_name) +{ + return m_client->ModelConfig(model_config, model_name); +} + +triton::client::Error HttpTritonClient::model_metadata(std::string* model_metadata, std::string& model_name) +{ + return m_client->ModelMetadata(model_metadata, model_name); +} + +triton::client::Error HttpTritonClient::async_infer(triton::client::InferenceServerHttpClient::OnCompleteFn callback, + const triton::client::InferOptions& options, + const std::vector& inputs, + const std::vector& outputs) +{ + std::vector> inference_inputs; + std::vector inference_input_ptrs; + + for (auto& input : inputs) + { + triton::client::InferInput* inference_input_ptr; + triton::client::InferInput::Create(&inference_input_ptr, input.name, input.shape, input.type); + + inference_input_ptr->AppendRaw(input.data); + + inference_input_ptrs.emplace_back(inference_input_ptr); + inference_inputs.emplace_back(inference_input_ptr); + } + + std::vector> inference_outputs; + std::vector inference_output_ptrs; + + for (auto& output : outputs) + { + triton::client::InferRequestedOutput* inference_output_ptr; + triton::client::InferRequestedOutput::Create(&inference_output_ptr, output.name); + inference_output_ptrs.emplace_back(inference_output_ptr); + inference_outputs.emplace_back(inference_output_ptr); } - // Save this for new clients - m_server_url = server_url; + triton::client::InferResult* result; + + auto status = m_client->Infer(&result, options, inference_input_ptrs, inference_output_ptrs); + + callback(result); + + return status; + + // TODO(cwharris): either fix tests or make this ENV-flagged, as AsyncInfer gives different results. 
+ + // return m_client->AsyncInfer( + // [callback](triton::client::InferResult* result) { + // callback(result); + // }, + // options, + // inference_input_ptrs, + // inference_output_ptrs); +} + +TritonInferenceClientSession::TritonInferenceClientSession(std::shared_ptr client, + std::string model_name) : + m_client(std::move(client)), + m_model_name(std::move(model_name)) +{ + // Now load the input/outputs for the model - if (!is_server_live) + bool is_server_live = false; + CHECK_TRITON(m_client->is_server_live(&is_server_live)); + if (not is_server_live) + { throw std::runtime_error("Server is not live"); + } bool is_server_ready = false; - CHECK_TRITON(client->IsServerReady(&is_server_ready)); - - if (!is_server_ready) + CHECK_TRITON(m_client->is_server_ready(&is_server_ready)); + if (not is_server_ready) + { throw std::runtime_error("Server is not ready"); + } bool is_model_ready = false; - CHECK_TRITON(client->IsModelReady(&is_model_ready, this->m_model_name)); - - if (!is_model_ready) + CHECK_TRITON(m_client->is_model_ready(&is_model_ready, this->m_model_name)); + if (not is_model_ready) + { throw std::runtime_error("Model is not ready"); + } std::string model_metadata_json; - CHECK_TRITON(client->ModelMetadata(&model_metadata_json, this->m_model_name)); + CHECK_TRITON(m_client->model_metadata(&model_metadata_json, this->m_model_name)); auto model_metadata = nlohmann::json::parse(model_metadata_json); std::string model_config_json; - CHECK_TRITON(client->ModelConfig(&model_config_json, this->m_model_name)); - + CHECK_TRITON(m_client->model_config(&model_config_json, this->m_model_name)); auto model_config = nlohmann::json::parse(model_config_json); if (model_config.contains("max_batch_size")) @@ -455,18 +317,11 @@ void InferenceClientStage::connect_with_server() bytes *= y; } - auto mapped_name = input.at("name").get(); - - if (m_inout_mapping.find(mapped_name) != m_inout_mapping.end()) - { - mapped_name = m_inout_mapping[mapped_name]; - } - m_model_inputs.push_back(TritonInOut{input.at("name").get(), bytes, DType::from_triton(input.at("datatype").get()), shape, - mapped_name, + "", 0}); } @@ -488,54 +343,161 @@ void InferenceClientStage::connect_with_server() bytes *= y; } - auto mapped_name = output.at("name").get(); + m_model_outputs.push_back(TritonInOut{output.at("name").get(), bytes, dtype, shape, "", 0}); + } +} + +std::vector TritonInferenceClientSession::get_input_mappings( + std::vector input_map_overrides) +{ + auto mappings = std::vector(); + + for (auto map : m_model_inputs) + { + mappings.emplace_back(TensorModelMapping(map.name, map.name)); + } + + for (auto override : input_map_overrides) + { + mappings.emplace_back(override); + } + + return mappings; +}; + +std::vector TritonInferenceClientSession::get_output_mappings( + std::vector output_map_overrides) +{ + auto mappings = std::vector(); + + for (auto map : m_model_outputs) + { + mappings.emplace_back(TensorModelMapping(map.name, map.name)); + } + + for (auto override : output_map_overrides) + { + auto pos = std::find_if(mappings.begin(), mappings.end(), [override](TensorModelMapping m) { + return m.model_field_name == override.model_field_name; + }); - if (m_inout_mapping.find(mapped_name) != m_inout_mapping.end()) + if (pos != mappings.end()) { - mapped_name = m_inout_mapping[mapped_name]; + mappings.erase(pos); } - m_model_outputs.push_back( - TritonInOut{output.at("name").get(), bytes, dtype, shape, mapped_name, 0}); + mappings.emplace_back(override); } + + return mappings; } -bool 
InferenceClientStage::is_default_grpc_port(std::string& server_url) +mrc::coroutines::Task TritonInferenceClientSession::infer(TensorMap&& inputs) { - // Check if we are the default gRPC port of 8001 and try 8000 for http client instead - size_t colon_loc = server_url.find_last_of(':'); + CHECK_EQ(inputs.size(), m_model_inputs.size()) << "Input tensor count does not match model input count"; - if (colon_loc == -1) + auto element_count = inputs.begin()->second.shape(0); + + for (auto& input : inputs) { - return false; + CHECK_EQ(element_count, input.second.shape(0)) << "Input tensors are different sizes"; } - // Check if the port matches 8001 - if (server_url.size() < colon_loc + 1 || server_url.substr(colon_loc + 1) != "8001") + TensorMap model_output_tensors; + + // create full inference output + for (auto& model_output : m_model_outputs) { - return false; + ShapeType full_output_shape = model_output.shape; + full_output_shape[0] = element_count; + auto full_output_element_count = TensorUtils::get_elem_count(full_output_shape); + + auto full_output_buffer = std::make_shared( + full_output_element_count * model_output.datatype.item_size(), rmm::cuda_stream_per_thread); + + ShapeType stride{full_output_shape[1], 1}; + + model_output_tensors[model_output.name].swap( + Tensor::create(std::move(full_output_buffer), model_output.datatype, full_output_shape, stride, 0)); } - // It matches, change to 8000 - server_url = server_url.substr(0, colon_loc) + ":8000"; + // process all batches - return true; -} + for (TensorIndex start = 0; start < element_count; start += m_max_batch_size) + { + TensorIndex stop = std::min(start + m_max_batch_size, static_cast(element_count)); -// ************ InferenceClientStageInterfaceProxy********* // -std::shared_ptr> InferenceClientStageInterfaceProxy::init( - mrc::segment::Builder& builder, - const std::string& name, - std::string model_name, - std::string server_url, - bool force_convert_inputs, - bool use_shared_memory, - bool needs_logits, - std::map inout_mapping) -{ - auto stage = builder.construct_object( - name, model_name, server_url, force_convert_inputs, use_shared_memory, needs_logits, inout_mapping); + // create batch inputs + + std::vector inference_inputs; + + for (auto model_input : m_model_inputs) + { + auto inference_input_slice = + inputs[model_input.name].slice({start, 0}, {stop, -1}).as_type(model_input.datatype); + + inference_inputs.emplace_back( + TritonInferInput{model_input.name, + {inference_input_slice.shape(0), inference_input_slice.shape(1)}, + model_input.datatype.triton_str(), + inference_input_slice.get_host_data()}); + } + + // create batch outputs + + std::vector outputs; - return stage; + for (auto model_output : m_model_outputs) + { + outputs.emplace_back(TritonInferRequestedOutput{model_output.name}); + } + + // infer batch results + + auto options = triton::client::InferOptions(m_model_name); + + auto results = co_await TritonInferOperation(*m_client, options, inference_inputs, outputs); + + // verify batch results and copy to full output tensors + + for (auto model_output : m_model_outputs) + { + auto output_tensor = model_output_tensors[model_output.name].slice({start, 0}, {stop, -1}); + + std::vector output_shape; + + CHECK_TRITON(results->Shape(model_output.name, &output_shape)); + + // Make sure we have at least 2 dims + while (output_shape.size() < 2) + { + output_shape.push_back(1); + } + + const uint8_t* output_ptr = nullptr; + size_t output_ptr_size = 0; + CHECK_TRITON(results->RawData(model_output.name, 
&output_ptr, &output_ptr_size)); + + DCHECK_EQ(stop - start, output_shape[0]); + DCHECK_EQ(output_tensor.bytes(), output_ptr_size); + DCHECK_NOTNULL(output_ptr); // NOLINT + DCHECK_NOTNULL(output_tensor.data()); // NOLINT + + MRC_CHECK_CUDA(cudaMemcpy(output_tensor.data(), output_ptr, output_ptr_size, cudaMemcpyHostToDevice)); + } + } + + co_return model_output_tensors; +}; + +TritonInferenceClient::TritonInferenceClient(std::unique_ptr&& client, std::string model_name) : + m_client(std::move(client)), + m_model_name(std::move(model_name)) +{} + +std::unique_ptr TritonInferenceClient::create_session() +{ + return std::make_unique(m_client, m_model_name); } + } // namespace morpheus diff --git a/morpheus/_lib/stages/__init__.pyi b/morpheus/_lib/stages/__init__.pyi index 2b40565087..eb70382722 100644 --- a/morpheus/_lib/stages/__init__.pyi +++ b/morpheus/_lib/stages/__init__.pyi @@ -58,7 +58,7 @@ class HttpServerSourceStage(mrc.core.segment.SegmentObject): def __init__(self, builder: mrc.core.segment.Builder, name: str, bind_address: str = '127.0.0.1', port: int = 8080, endpoint: str = '/message', method: str = 'POST', accept_status: int = 201, sleep_time: float = 0.10000000149011612, queue_timeout: int = 5, max_queue_size: int = 1024, num_server_threads: int = 1, max_payload_size: int = 10485760, request_timeout: int = 30, lines: bool = False, stop_after: int = 0) -> None: ... pass class InferenceClientStage(mrc.core.segment.SegmentObject): - def __init__(self, builder: mrc.core.segment.Builder, name: str, model_name: str, server_url: str, force_convert_inputs: bool, use_shared_memory: bool, needs_logits: bool, inout_mapping: typing.Dict[str, str] = {}) -> None: ... + def __init__(self, builder: mrc.core.segment.Builder, name: str, server_url: str, model_name: str, needs_logits: bool, input_mapping: typing.Dict[str, str] = {}, output_mapping: typing.Dict[str, str] = {}) -> None: ... 
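The regenerated binding signature above reflects the split of the old inout_mapping dictionary into separate input_mapping and output_mapping arguments. On the C++ side these become TensorModelMapping overrides that TritonInferenceClientSession::get_output_mappings (shown earlier) merges with identity mappings derived from the model metadata, an override replacing any default that targets the same model field. A simplified sketch of that merge, using a stand-in Mapping struct rather than the real types.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for morpheus::TensorModelMapping.
struct Mapping
{
    std::string model_field_name;
    std::string tensor_field_name;
};

// Start from identity mappings for every model field, then let user-supplied overrides
// replace any default that refers to the same model field.
std::vector<Mapping> merge_mappings(const std::vector<std::string>& model_fields,
                                    const std::vector<Mapping>& overrides)
{
    std::vector<Mapping> mappings;

    for (const auto& field : model_fields)
    {
        mappings.push_back(Mapping{field, field});
    }

    for (const auto& override_mapping : overrides)
    {
        auto pos = std::find_if(mappings.begin(), mappings.end(), [&](const Mapping& m) {
            return m.model_field_name == override_mapping.model_field_name;
        });

        if (pos != mappings.end())
        {
            mappings.erase(pos);
        }

        mappings.push_back(override_mapping);
    }

    return mappings;
}

int main()
{
    // "output__0" keeps only the user-provided mapping; "seq_ids" keeps its identity default.
    auto mappings = merge_mappings({"seq_ids", "output__0"}, {{"output__0", "probs"}});

    for (const auto& m : mappings)
    {
        std::cout << m.model_field_name << " -> " << m.tensor_field_name << "\n";
    }
}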
pass class KafkaSourceStage(mrc.core.segment.SegmentObject): @typing.overload diff --git a/morpheus/_lib/stages/module.cpp b/morpheus/_lib/stages/module.cpp index 0fc47034d6..738e534e9a 100644 --- a/morpheus/_lib/stages/module.cpp +++ b/morpheus/_lib/stages/module.cpp @@ -25,12 +25,12 @@ #include "morpheus/stages/file_source.hpp" #include "morpheus/stages/filter_detection.hpp" #include "morpheus/stages/http_server_source_stage.hpp" +#include "morpheus/stages/inference_client_stage.hpp" #include "morpheus/stages/kafka_source.hpp" #include "morpheus/stages/preallocate.hpp" #include "morpheus/stages/preprocess_fil.hpp" #include "morpheus/stages/preprocess_nlp.hpp" #include "morpheus/stages/serialize.hpp" -#include "morpheus/stages/triton_inference.hpp" #include "morpheus/stages/write_to_file.hpp" #include "morpheus/utilities/cudf_util.hpp" #include "morpheus/utilities/http_server.hpp" // for DefaultMaxPayloadSize @@ -39,10 +39,9 @@ #include // for Builder #include #include -#include // for multiple_inheritance -#include // for arg, init, class_, module_, str_attr_accessor, PYBIND11_MODULE, pybind11 -#include // for dict, sequence -// for pathlib.Path -> std::filesystem::path conversions +#include // for multiple_inheritance +#include // for arg, init, class_, module_, str_attr_accessor, PYBIND11_MODULE, pybind11 +#include // for dict, sequence #include // IWYU pragma: keep #include // for pymrc::import #include @@ -151,12 +150,11 @@ PYBIND11_MODULE(stages, _module) .def(py::init<>(&InferenceClientStageInterfaceProxy::init), py::arg("builder"), py::arg("name"), - py::arg("model_name"), py::arg("server_url"), - py::arg("force_convert_inputs"), - py::arg("use_shared_memory"), + py::arg("model_name"), py::arg("needs_logits"), - py::arg("inout_mapping") = py::dict()); + py::arg("input_mapping") = py::dict(), + py::arg("output_mapping") = py::dict()); py::class_, mrc::segment::ObjectProperties, diff --git a/morpheus/_lib/tests/CMakeLists.txt b/morpheus/_lib/tests/CMakeLists.txt index 32ed8379df..b8330fb8bc 100644 --- a/morpheus/_lib/tests/CMakeLists.txt +++ b/morpheus/_lib/tests/CMakeLists.txt @@ -154,6 +154,12 @@ add_morpheus_test( test_tensor.cpp ) +add_morpheus_test( + NAME triton_inference_stage + FILES + stages/test_triton_inference_stage.cpp +) + add_morpheus_test( NAME type_util FILES diff --git a/morpheus/_lib/tests/stages/test_triton_inference_stage.cpp b/morpheus/_lib/tests/stages/test_triton_inference_stage.cpp new file mode 100644 index 0000000000..c7a566b011 --- /dev/null +++ b/morpheus/_lib/tests/stages/test_triton_inference_stage.cpp @@ -0,0 +1,343 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils/common.hpp" // IWYU pragma: associated + +#include "morpheus/messages/memory/tensor_memory.hpp" +#include "morpheus/messages/meta.hpp" +#include "morpheus/messages/multi_inference.hpp" +#include "morpheus/messages/multi_response.hpp" +#include "morpheus/objects/dtype.hpp" +#include "morpheus/objects/tensor.hpp" +#include "morpheus/objects/tensor_object.hpp" +#include "morpheus/stages/inference_client_stage.hpp" +#include "morpheus/stages/triton_inference.hpp" +#include "morpheus/types.hpp" +#include "morpheus/utilities/cudf_util.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +class FakeInferResult : public triton::client::InferResult +{ + private: + std::map> m_output_values; + + public: + FakeInferResult(std::map> output_values) : + m_output_values(std::move(output_values)) + {} + + triton::client::Error RequestStatus() const override + { + throw std::runtime_error("RequestStatus not implemented"); + } + + std::string DebugString() const override + { + throw std::runtime_error("DebugString not implemented"); + } + + triton::client::Error Id(std::string* id) const override + { + throw std::runtime_error("Id not implemented"); + } + + triton::client::Error ModelName(std::string* name) const override + { + throw std::runtime_error("ModelName not implemented"); + } + + triton::client::Error ModelVersion(std::string* version) const override + { + throw std::runtime_error("ModelVersion not implemented"); + } + + triton::client::Error Shape(const std::string& output_name, std::vector* shape) const override + { + shape = new std::vector({0, 0}); // this is technically a leak + + return triton::client::Error::Success; + } + + triton::client::Error Datatype(const std::string& output_name, std::string* datatype) const override + { + throw std::runtime_error("Datatype not implemented"); + } + + triton::client::Error StringData(const std::string& output_name, + std::vector* string_result) const override + { + throw std::runtime_error("StringData not implemented"); + } + + triton::client::Error RawData(const std::string& output_name, const uint8_t** buf, size_t* byte_size) const override + { + auto& output = m_output_values.at(output_name); + *byte_size = output.size() * sizeof(int32_t); + *buf = reinterpret_cast(const_cast(output.data())); + return triton::client::Error::Success; + } +}; + +class FakeTritonClient : public morpheus::ITritonClient +{ + private: + bool m_is_server_live_has_errored = false; + bool m_is_server_live = false; + bool m_is_server_ready_has_errored = false; + bool m_is_server_ready = false; + bool m_is_model_ready_has_errored = false; + bool m_is_model_ready = false; + bool m_model_config_has_errored = false; + bool m_model_metadata_has_errored = false; + bool m_async_infer_has_errored = false; + + public: + triton::client::Error is_server_live(bool* live) override + { + if (not m_is_server_live_has_errored) + { + m_is_server_live_has_errored = true; + return triton::client::Error("is_server_live error"); + } + + *live = m_is_server_live; + + if (not m_is_server_live) + { + m_is_server_live = true; + } + + return triton::client::Error::Success; + } + + triton::client::Error is_server_ready(bool* ready) override + { + if (not m_is_server_ready_has_errored) + { + m_is_server_ready_has_errored = true; + return triton::client::Error("is_server_ready error"); + 
+        }
+
+        *ready = m_is_server_ready;
+
+        if (not m_is_server_ready)
+        {
+            m_is_server_ready = true;
+        }
+
+        return triton::client::Error::Success;
+    }
+
+    triton::client::Error is_model_ready(bool* ready, std::string& model_name) override
+    {
+        if (not m_is_model_ready_has_errored)
+        {
+            m_is_model_ready_has_errored = true;
+            return triton::client::Error("is_model_ready error");
+        }
+
+        *ready = m_is_model_ready;
+
+        if (not m_is_model_ready)
+        {
+            m_is_model_ready = true;
+        }
+
+        return triton::client::Error::Success;
+    }
+
+    triton::client::Error model_config(std::string* model_config, std::string& model_name) override
+    {
+        if (not m_model_config_has_errored)
+        {
+            m_model_config_has_errored = true;
+            return triton::client::Error("model_config error");
+        }
+
+        *model_config = R"({
+            "max_batch_size": 100
+        })";
+
+        return triton::client::Error::Success;
+    }
+
+    triton::client::Error model_metadata(std::string* model_metadata, std::string& model_name) override
+    {
+        if (not m_model_metadata_has_errored)
+        {
+            m_model_metadata_has_errored = true;
+            return triton::client::Error("model_metadata error");
+        }
+
+        *model_metadata = R"({
+            "inputs":[
+                {
+                    "name":"seq_ids",
+                    "shape": [0, 1],
+                    "datatype":"INT32"
+                }
+            ],
+            "outputs":[
+                {
+                    "name":"seq_ids",
+                    "shape": [0, 1],
+                    "datatype":"INT32"
+                }
+            ]})";
+
+        return triton::client::Error::Success;
+    }
+
+    triton::client::Error async_infer(triton::client::InferenceServerHttpClient::OnCompleteFn callback,
+                                      const triton::client::InferOptions& options,
+                                      const std::vector<morpheus::TritonInferInput>& inputs,
+                                      const std::vector<morpheus::TritonInferRequestedOutput>& outputs) override
+    {
+        if (not m_async_infer_has_errored)
+        {
+            m_async_infer_has_errored = true;
+            return triton::client::Error("async_infer error");
+        }
+
+        callback(new FakeInferResult({{"seq_ids", std::vector<int32_t>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}}));
+
+        return triton::client::Error::Success;
+    }
+};
+
+class TestTritonInferenceStage : public morpheus::test::TestWithPythonInterpreter
+{
+  protected:
+    void SetUp() override
+    {
+        morpheus::test::TestWithPythonInterpreter::SetUp();
+        {
+            pybind11::gil_scoped_acquire gil;
+
+            // Initially I ran into an issue bootstrapping cudf, I was able to work-around the issue, details in:
+            // https://github.com/rapidsai/cudf/issues/12862
+            morpheus::CudfHelper::load();
+        }
+    }
+};
+
+cudf::io::table_with_metadata create_test_table_with_metadata(uint32_t rows)
+{
+    cudf::data_type cudf_data_type{cudf::type_to_id<int>()};
+
+    auto column = cudf::make_fixed_width_column(cudf_data_type, rows);
+
+    std::vector<int> data(rows);
+    std::iota(data.begin(), data.end(), 0);
+
+    cudaMemcpy(column->mutable_view().data<int>(),
+               data.data(),
+               data.size() * sizeof(int),
+               cudaMemcpyKind::cudaMemcpyHostToDevice);
+
+    std::vector<std::unique_ptr<cudf::column>> columns;
+
+    columns.emplace_back(std::move(column));
+
+    auto table = std::make_unique<cudf::table>(std::move(columns));
+
+    auto index_info   = cudf::io::column_name_info{""};
+    auto column_names = std::vector<cudf::io::column_name_info>({{index_info}});
+    auto metadata     = cudf::io::table_metadata{std::move(column_names), {}, {}};
+
+    return cudf::io::table_with_metadata{std::move(table), metadata};
+}
+
+TEST_F(TestTritonInferenceStage, SingleRow)
+{
+    cudf::data_type cudf_data_type{cudf::type_to_id<int>()};
+
+    const std::size_t count = 10;
+    const auto dtype        = morpheus::DType::create<int>();
+
+    // Create a 10-number sequence id vector and store them in the tensor.
+    auto buffer = std::make_shared<rmm::device_buffer>(count * dtype.item_size(), rmm::cuda_stream_per_thread);
+    std::vector<int> seq_ids({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+    cudaMemcpy(buffer->data(), seq_ids.data(), count * sizeof(int), cudaMemcpyKind::cudaMemcpyHostToDevice);
+    auto tensors = morpheus::TensorMap();
+    tensors["seq_ids"].swap(morpheus::Tensor::create(buffer, dtype, {count, 1}, {}));
+
+    // create the MultiInferenceMessage using the sequence id tensor.
+    auto memory  = std::make_shared<morpheus::TensorMemory>(count, std::move(tensors));
+    auto table   = create_test_table_with_metadata(count);
+    auto meta    = morpheus::MessageMeta::create_from_cpp(std::move(table), 1);
+    auto message = std::make_shared<morpheus::MultiInferenceMessage>(meta, 0, count, memory);
+
+    // create the fake triton client used for testing.
+    auto triton_client           = std::make_unique<FakeTritonClient>();
+    auto triton_inference_client = std::make_unique<morpheus::TritonInferenceClient>(std::move(triton_client), "");
+    auto stage = morpheus::InferenceClientStage(std::move(triton_inference_client), "", false, {}, {});
+
+    // manually invoke the stage and iterate through the inference responses
+    auto on           = std::make_shared<mrc::coroutines::TestScheduler>();
+    auto results_task = [](auto& stage, auto message, auto on)
+        -> mrc::coroutines::Task<std::vector<std::shared_ptr<morpheus::MultiResponseMessage>>> {
+        std::vector<std::shared_ptr<morpheus::MultiResponseMessage>> results;
+
+        auto responses_generator = stage.on_data(std::move(message), on);
+
+        auto iter = co_await responses_generator.begin();
+
+        while (iter != responses_generator.end())
+        {
+            results.emplace_back(std::move(*iter));
+
+            co_await ++iter;
+        }
+
+        co_return results;
+    }(stage, message, on);
+
+    results_task.resume();
+
+    while (on->resume_next()) {}
+
+    ASSERT_NO_THROW(results_task.promise().result());
+
+    auto results = results_task.promise().result();
+
+    ASSERT_EQ(results.size(), 1);
+}
diff --git a/morpheus/stages/inference/triton_inference_stage.py b/morpheus/stages/inference/triton_inference_stage.py
index 707b2f8f43..e5901363f9 100644
--- a/morpheus/stages/inference/triton_inference_stage.py
+++ b/morpheus/stages/inference/triton_inference_stage.py
@@ -439,14 +439,16 @@ def __init__(self,
                  model_name: str,
                  server_url: str,
                  force_convert_inputs: bool,
-                 inout_mapping: dict[str, str] = None,
+                 input_mapping: dict[str, str] = None,
+                 output_mapping: dict[str, str] = None,
                  use_shared_memory: bool = False,
                  needs_logits: bool = False):
         super().__init__(inf_queue)
 
         self._model_name = model_name
         self._server_url = server_url
-        self._inout_mapping = inout_mapping or {}
+        self._input_mapping = input_mapping or {}
+        self._output_mapping = output_mapping or {}
         self._use_shared_memory = use_shared_memory
 
         self._max_batch_size = c.model_max_batch_size
@@ -515,7 +517,7 @@ def init(self):
 
         shm_config = {}
 
-        def build_inout(x: dict):
+        def build_inout(x: dict, mapping: dict[str, str]):
             num_bytes = np.dtype(triton_to_np_dtype(x["datatype"])).itemsize
 
             shape = []
@@ -530,7 +532,7 @@ def build_inout(x: dict):
 
                 num_bytes *= y_int
 
-            mapped_name = x["name"] if x["name"] not in self._inout_mapping else self._inout_mapping[x["name"]]
+            mapped_name = x["name"] if x["name"] not in mapping else mapping[x["name"]]
 
             return TritonInOut(name=x["name"],
                                bytes=num_bytes,
@@ -539,12 +541,12 @@ def build_inout(x: dict):
                                mapped_name=mapped_name)
 
         for x in model_meta["inputs"]:
-            self._inputs[x["name"]] = build_inout(x)
+            self._inputs[x["name"]] = build_inout(x, self._input_mapping)
 
         for x in model_meta["outputs"]:
             assert x["name"] not in self._inputs, "Input/Output names must be unique from eachother"
-            self._outputs[x["name"]] = build_inout(x)
+            self._outputs[x["name"]] = build_inout(x, self._output_mapping)
 
         # Combine the inputs/outputs for the shared memory
         shm_config = {**self._inputs, **self._outputs}
@@ -687,11 +689,16 @@ class TritonInferenceStage(InferenceStage):
 
     _INFERENCE_WORKER_DEFAULT_INOUT_MAPPING = {
         PipelineModes.FIL: {
-            "output__0": "probs",
+            "outputs": {
+                "output__0": "probs",
+            }
         },
         PipelineModes.NLP: {
-            "attention_mask": "input_mask",
-            "output": "probs",
+            "inputs": {
+                "attention_mask": "input_mask",
+            }, "outputs": {
+                "output": "probs",
+            }
         }
     }
@@ -702,7 +709,9 @@ def __init__(self,
                  force_convert_inputs: bool = False,
                  use_shared_memory: bool = False,
                  needs_logits: bool = None,
-                 inout_mapping: dict[str, str] = None):
+                 inout_mapping: dict[str, str] = None,
+                 input_mapping: dict[str, str] = None,
+                 output_mapping: dict[str, str] = None):
         super().__init__(c)
 
         self._config = c
@@ -710,19 +719,39 @@ def __init__(self,
         if needs_logits is None:
             needs_logits = c.mode == PipelineModes.NLP
 
-        # Combine the pipeline mode defaults with any user supplied ones
-        inout_mapping_ = self._INFERENCE_WORKER_DEFAULT_INOUT_MAPPING.get(c.mode, {})
-        if inout_mapping is not None:
-            inout_mapping_.update(inout_mapping)
-
-        self._kwargs = {
-            "model_name": model_name,
-            "server_url": server_url,
-            "force_convert_inputs": force_convert_inputs,
-            "use_shared_memory": use_shared_memory,
-            "inout_mapping": inout_mapping_,
-            "needs_logits": needs_logits
-        }
+        input_mapping_ = self._INFERENCE_WORKER_DEFAULT_INOUT_MAPPING.get(c.mode, {}).get("inputs", {})
+        output_mapping_ = self._INFERENCE_WORKER_DEFAULT_INOUT_MAPPING.get(c.mode, {}).get("outputs", {})
+
+        if inout_mapping:
+
+            if input_mapping:
+                raise RuntimeError(
+                    "TritonInferenceStage's `inout_mapping` and `input_mapping` arguments cannot be used together")
+
+            if output_mapping:
+                raise RuntimeError(
+                    "TritonInferenceStage's `inout_mapping` and `output_mapping` arguments cannot be used together")
+
+            warnings.warn(("TritonInferenceStage's `inout_mapping` argument has been deprecated. "
+                           "Please use `input_mapping` and/or `output_mapping` instead"),
+                          DeprecationWarning)
+
+            input_mapping_.update(inout_mapping)
+            output_mapping_.update(inout_mapping)
+
+        if input_mapping is not None:
+            input_mapping_.update(input_mapping)
+
+        if output_mapping is not None:
+            output_mapping_.update(output_mapping)
+
+        self._server_url = server_url
+        self._model_name = model_name
+        self._force_convert_inputs = force_convert_inputs
+        self._use_shared_memory = use_shared_memory
+        self._input_mapping = input_mapping_
+        self._output_mapping = output_mapping_
+        self._needs_logits = needs_logits
 
     def supports_cpp_node(self) -> bool:
         # Get the value from the worker class
@@ -734,7 +763,21 @@ def _get_inference_worker(self, inf_queue: ProducerConsumerQueue) -> TritonInferenceWorker:
         worker.
""" - return TritonInferenceWorker(inf_queue=inf_queue, c=self._config, **self._kwargs) + return TritonInferenceWorker(inf_queue=inf_queue, + c=self._config, + server_url=self._server_url, + model_name=self._model_name, + force_convert_inputs=self._force_convert_inputs, + use_shared_memory=self._use_shared_memory, + input_mapping=self._input_mapping, + output_mapping=self._output_mapping, + needs_logits=self._needs_logits) def _get_cpp_inference_node(self, builder: mrc.Builder) -> mrc.SegmentObject: - return _stages.InferenceClientStage(builder, name=self.unique_name, **self._kwargs) + return _stages.InferenceClientStage(builder, + self.unique_name, + self._server_url, + self._model_name, + self._needs_logits, + self._input_mapping, + self._output_mapping) diff --git a/tests/examples/log_parsing/test_inference.py b/tests/examples/log_parsing/test_inference.py index b9109792a2..bd917df2e6 100644 --- a/tests/examples/log_parsing/test_inference.py +++ b/tests/examples/log_parsing/test_inference.py @@ -24,7 +24,6 @@ from _utils import TEST_DIRS from morpheus.config import Config -from morpheus.config import PipelineModes from morpheus.messages import InferenceMemoryNLP from morpheus.messages import MessageMeta from morpheus.messages import MultiInferenceNLPMessage @@ -96,7 +95,7 @@ def build_inf_message(df: DataFrameType, count=count) -def _check_worker(inference_mod: types.ModuleType, worker: TritonInferenceWorker, expected_mapping: dict[str, str]): +def _check_worker(inference_mod: types.ModuleType, worker: TritonInferenceWorker): assert isinstance(worker, TritonInferenceWorker) assert isinstance(worker, inference_mod.TritonInferenceLogParsing) assert worker._model_name == 'test_model' @@ -104,7 +103,6 @@ def _check_worker(inference_mod: types.ModuleType, worker: TritonInferenceWorker assert not worker._force_convert_inputs assert not worker._use_shared_memory assert worker.needs_logits - assert worker._inout_mapping == expected_mapping @pytest.mark.import_mod([os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py')]) @@ -117,10 +115,10 @@ def test_log_parsing_triton_inference_log_parsing_constructor(config: Config, server_url='test_server', force_convert_inputs=False, use_shared_memory=False, - inout_mapping={'test': 'this'}, + input_mapping={'test': 'this'}, needs_logits=True) - _check_worker(inference_mod, worker, {'test': 'this'}) + _check_worker(inference_mod, worker) @pytest.mark.import_mod([os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py')]) @@ -164,37 +162,6 @@ def test_log_parsing_triton_inference_log_parsing_build_output_message(config: C assert msg.get_tensor('seq_ids').shape == (count, 3) -@pytest.mark.import_mod([os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py')]) -def test_log_parsing_inference_stage_constructor(config: Config, import_mod: typing.List[types.ModuleType]): - inference_mod = import_mod[0] - - expected_kwargs = { - "model_name": - 'test_model', - "server_url": - 'test_server', - "force_convert_inputs": - False, - "use_shared_memory": - False, - "needs_logits": - True, - "inout_mapping": - inference_mod.LogParsingInferenceStage._INFERENCE_WORKER_DEFAULT_INOUT_MAPPING.get(PipelineModes.NLP, {}), - } - - stage = inference_mod.LogParsingInferenceStage( - config, - model_name='test_model', - server_url='test_server', - force_convert_inputs=False, - use_shared_memory=False, - ) - - assert stage._config is config - assert stage._kwargs == expected_kwargs - - @pytest.mark.import_mod([os.path.join(TEST_DIRS.examples_dir, 
 def test_log_parsing_inference_stage_get_inference_worker(config: Config, import_mod: typing.List[types.ModuleType]):
     inference_mod = import_mod[0]
@@ -206,12 +173,8 @@ def test_log_parsing_inference_stage_get_inference_worker(config: Config, import_mod
                                                       use_shared_memory=False,
                                                       inout_mapping={'test': 'this'})
 
-    expected_mapping = inference_mod.LogParsingInferenceStage._INFERENCE_WORKER_DEFAULT_INOUT_MAPPING.get(
-        PipelineModes.NLP, {})
-    expected_mapping.update({'test': 'this'})
-
     worker = stage._get_inference_worker(inf_queue=ProducerConsumerQueue())
-    _check_worker(inference_mod, worker, expected_mapping)
+    _check_worker(inference_mod, worker)
 
 
 @pytest.mark.use_cudf
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 51f4df3e41..db9566ad6e 100755
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -406,9 +406,9 @@ def test_pipeline_fil(self, config, callback_values):
         assert isinstance(process_fil, PreprocessFILStage)
 
         assert isinstance(triton_inf, TritonInferenceStage)
-        assert triton_inf._kwargs['model_name'] == 'test-model'
-        assert triton_inf._kwargs['server_url'] == 'test:123'
-        assert triton_inf._kwargs['force_convert_inputs']
+        assert triton_inf._model_name == 'test-model'
+        assert triton_inf._server_url == 'test:123'
+        assert triton_inf._force_convert_inputs
 
         assert isinstance(monitor, MonitorStage)
         assert monitor._mc._description == 'Unittest'
@@ -529,9 +529,9 @@ def test_pipeline_fil_all(self, config, callback_values, tmp_path, mlflow_uri):
         assert mlflow_drift._tracking_uri == mlflow_uri
 
         assert isinstance(triton_inf, TritonInferenceStage)
-        assert triton_inf._kwargs['model_name'] == 'test-model'
-        assert triton_inf._kwargs['server_url'] == 'test:123'
-        assert triton_inf._kwargs['force_convert_inputs']
+        assert triton_inf._model_name == 'test-model'
+        assert triton_inf._server_url == 'test:123'
+        assert triton_inf._force_convert_inputs
 
         assert isinstance(monitor, MonitorStage)
         assert monitor._mc._description == 'Unittest'
@@ -663,9 +663,9 @@ def test_enum_parsing(self, config, callback_values, tmp_path, mlflow_uri):
         assert mlflow_drift._tracking_uri == mlflow_uri
 
         assert isinstance(triton_inf, TritonInferenceStage)
-        assert triton_inf._kwargs['model_name'] == 'test-model'
-        assert triton_inf._kwargs['server_url'] == 'test:123'
-        assert triton_inf._kwargs['force_convert_inputs']
+        assert triton_inf._model_name == 'test-model'
+        assert triton_inf._server_url == 'test:123'
+        assert triton_inf._force_convert_inputs
 
         assert isinstance(monitor, MonitorStage)
         assert monitor._mc._description == 'Unittest'
@@ -744,9 +744,9 @@ def test_pipeline_nlp(self, config, callback_values):
         assert not process_nlp._add_special_tokens
 
         assert isinstance(triton_inf, TritonInferenceStage)
-        assert triton_inf._kwargs['model_name'] == 'test-model'
-        assert triton_inf._kwargs['server_url'] == 'test:123'
-        assert triton_inf._kwargs['force_convert_inputs']
+        assert triton_inf._model_name == 'test-model'
+        assert triton_inf._server_url == 'test:123'
+        assert triton_inf._force_convert_inputs
 
         assert isinstance(monitor, MonitorStage)
         assert monitor._mc._description == 'Unittest'
@@ -877,9 +877,9 @@ def test_pipeline_nlp_all(self, config, callback_values, tmp_path, mlflow_uri):
         assert mlflow_drift._tracking_uri == mlflow_uri
 
         assert isinstance(triton_inf, TritonInferenceStage)
-        assert triton_inf._kwargs['model_name'] == 'test-model'
-        assert triton_inf._kwargs['server_url'] == 'test:123'
-        assert triton_inf._kwargs['force_convert_inputs']
+        assert triton_inf._model_name == 'test-model'
+        assert triton_inf._server_url == 'test:123'
+        assert triton_inf._force_convert_inputs
 
         assert isinstance(monitor, MonitorStage)
         assert monitor._mc._description == 'Unittest'
diff --git a/tests/test_triton_inference_stage.py b/tests/test_triton_inference_stage.py
index 02430909d4..a361c712a1 100644
--- a/tests/test_triton_inference_stage.py
+++ b/tests/test_triton_inference_stage.py
@@ -122,45 +122,6 @@ def test_resource_pool_create_raises_error():
     assert pool.borrow_obj() == 20
 
 
-@pytest.mark.parametrize("pipeline_mode", list(PipelineModes))
-@pytest.mark.parametrize("force_convert_inputs", [True, False])
-@pytest.mark.parametrize("use_shared_memory", [True, False])
-@pytest.mark.parametrize("needs_logits", [True, False, None])
-@pytest.mark.parametrize("inout_mapping", [None, {'unit': 'test'}])
-def test_stage_constructor(config: Config,
-                           pipeline_mode: PipelineModes,
-                           force_convert_inputs: bool,
-                           use_shared_memory: bool,
-                           needs_logits: bool | None,
-                           inout_mapping: dict[str, str] | None):
-    if needs_logits is None:
-        expexted_needs_logits = (pipeline_mode == PipelineModes.NLP)
-    else:
-        expexted_needs_logits = needs_logits
-
-    expected_inout_mapping = TritonInferenceStage._INFERENCE_WORKER_DEFAULT_INOUT_MAPPING.get(pipeline_mode, {})
-    expected_inout_mapping.update(inout_mapping or {})
-
-    config.mode = pipeline_mode
-
-    stage = TritonInferenceStage(config,
-                                 model_name='test',
-                                 server_url='test:0000',
-                                 force_convert_inputs=force_convert_inputs,
-                                 use_shared_memory=use_shared_memory,
-                                 needs_logits=needs_logits,
-                                 inout_mapping=inout_mapping)
-
-    assert stage._kwargs == {
-        "model_name": "test",
-        "server_url": "test:0000",
-        "force_convert_inputs": force_convert_inputs,
-        "use_shared_memory": use_shared_memory,
-        "needs_logits": expexted_needs_logits,
-        'inout_mapping': expected_inout_mapping
-    }
-
-
 @pytest.mark.use_python
 @pytest.mark.parametrize("pipeline_mode", list(PipelineModes))
 def test_stage_constructor_worker_class(config: Config, pipeline_mode: PipelineModes):
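
Example usage of the `input_mapping`/`output_mapping` arguments that this change adds to `TritonInferenceStage` (a minimal sketch; the model name and server URL below are placeholders, not values taken from this change):

```python
from morpheus.config import Config, PipelineModes
from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage

config = Config()
config.mode = PipelineModes.NLP

# New style: separate input and output mappings, merged on top of the NLP defaults.
stage = TritonInferenceStage(config,
                             model_name="my-nlp-model",    # placeholder
                             server_url="localhost:8001",  # placeholder
                             input_mapping={"attention_mask": "input_mask"},
                             output_mapping={"output": "probs"})

# Deprecated style: a single mapping applied to both inputs and outputs.
# This still works but emits a DeprecationWarning.
stage = TritonInferenceStage(config,
                             model_name="my-nlp-model",    # placeholder
                             server_url="localhost:8001",  # placeholder
                             inout_mapping={"output": "probs"})
```

The deprecated `inout_mapping` path applies the one dictionary to both directions, while the new arguments let the pipeline rename model inputs and outputs independently.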
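
A sketch of a test that could exercise the deprecation behavior introduced above; it relies only on what is visible in this change (the `_input_mapping`/`_output_mapping` attributes, the `RuntimeError` raised when old and new arguments are mixed, and the `DeprecationWarning`):

```python
import pytest

from morpheus.config import Config, PipelineModes
from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage


def test_inout_mapping_deprecation():
    config = Config()
    config.mode = PipelineModes.FIL

    # The deprecated argument is accepted but warns, and is applied to both sides
    # on top of the FIL defaults ({"output__0": "probs"} for outputs).
    with pytest.warns(DeprecationWarning):
        stage = TritonInferenceStage(config,
                                     model_name="test",
                                     server_url="test:0000",
                                     inout_mapping={"unit": "test"})

    assert stage._input_mapping["unit"] == "test"
    assert stage._output_mapping["unit"] == "test"
    assert stage._output_mapping["output__0"] == "probs"

    # Mixing the deprecated argument with the new ones is rejected outright.
    with pytest.raises(RuntimeError):
        TritonInferenceStage(config,
                             model_name="test",
                             server_url="test:0000",
                             inout_mapping={"unit": "test"},
                             input_mapping={"unit": "test"})
```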