Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address a DML regression caused by the continuous decoding changes #1159

Merged
8 commits merged on Jan 13, 2025
3 changes: 2 additions & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ void Generator::AppendTokens(cpu_span<const int32_t> input_ids) {
if (search_->GetSequenceLength() != 0 &&
std::none_of(devices_supporting_continuous_decoding.begin(), devices_supporting_continuous_decoding.end(),
[this](DeviceType device_type) { return device_type == state_->params_->device_type; }))
throw std::runtime_error("Continuous decoding is not supported on the selected device type: " + to_string(state_->params_->device_type));
throw std::runtime_error("Continuous decoding is not supported on the selected device type (" + to_string(state_->params_->device_type) +
"). Please recreate the generator instance to avoid using continuous decoding.");

if (last_action_ == Action::generated) {
ComputeLogits(search_->GetNextTokens());
Expand Down
75 changes: 38 additions & 37 deletions src/models/debugging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,47 +88,48 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
stream << SGR::Fg_Green << " Location: " << SGR::Reset;

const auto& memory_info = value->GetTensorMemoryInfo();
switch (memory_info.GetDeviceType()) {
case OrtMemoryInfoDeviceType_CPU:
stream << "CPU\r\n";
DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count);
break;
case OrtMemoryInfoDeviceType_GPU: {
stream << "GPU\r\n";
auto device_type = memory_info.GetDeviceType();
if (device_type == OrtMemoryInfoDeviceType_CPU) {
stream << "CPU\r\n";
DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count);
} else if (device_type == OrtMemoryInfoDeviceType_GPU) {
stream << "GPU\r\n";
#if USE_CUDA
auto type = type_info->GetElementType();
size_t element_size = SizeOf(type);
auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);
CudaCheck() == cudaMemcpy(cpu_copy.get(), value->GetTensorRawData(), element_size * element_count, cudaMemcpyDeviceToHost);
DumpValues(stream, type, cpu_copy.get(), element_count);
#elif USE_DML
auto type = type_info->GetElementType();
size_t element_size = SizeOf(type);
auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);

if (value->GetTensorMutableRawData()) {
ComPtr<ID3D12Resource> gpu_resource;
Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
model.allocator_device_,
value->GetTensorMutableRawData(),
&gpu_resource));

model.GetDmlReadbackHeap()->ReadbackFromGpu(
std::span(cpu_copy.get(), element_size * element_count),
gpu_resource.Get(),
0,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
}

DumpValues(stream, type, cpu_copy.get(), element_count);
auto type = type_info->GetElementType();
size_t element_size = SizeOf(type);
auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);
CudaCheck() == cudaMemcpy(cpu_copy.get(), value->GetTensorRawData(), element_size * element_count, cudaMemcpyDeviceToHost);
DumpValues(stream, type, cpu_copy.get(), element_count);
#else
stream << "Unexpected, using GPU memory but not compiled with CUDA or DML?";
throw std::runtime_error("Unexpected error. Trying to access GPU memory but the project is not compiled with CUDA.");
#endif
break;
} else if (static_cast<int>(device_type) == 4) {
stream << "DML\r\n";
#if USE_DML
auto type = type_info->GetElementType();
size_t element_size = SizeOf(type);
auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);

if (value->GetTensorMutableRawData()) {
ComPtr<ID3D12Resource> gpu_resource;
Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
model.allocator_device_,
value->GetTensorMutableRawData(),
&gpu_resource));

model.GetDmlReadbackHeap()->ReadbackFromGpu(
std::span(cpu_copy.get(), element_size * element_count),
gpu_resource.Get(),
0,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
}
default:
stream << "Unhandled device type";
break;

DumpValues(stream, type, cpu_copy.get(), element_count);
#else
throw std::runtime_error("Unexpected error. Trying to access DML memory but the project is not compiled with DML.");
#endif
} else {
stream << "Unhandled device type: " << static_cast<int>(device_type) << "\r\n";
}
}

Expand Down
49 changes: 37 additions & 12 deletions src/models/input_ids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,9 @@ namespace Generators {
DefaultInputIDs::DefaultInputIDs(State& state)
: state_{state} {
name_ = model_.config_->model.decoder.inputs.input_ids.c_str();
shape_ = {state_.params_->BatchBeamSize(), 0};
shape_ = {state_.params_->search.batch_size, 0};
type_ = model_.session_info_->GetInputDataType(name_);

if (state_.GetCapturedGraphInfo()) {
sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get();

#if USE_DML
if (model_.device_type_ == DeviceType::DML) {
sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
}
#endif
}

if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) &&
model_.session_info_->HasInput(model_.config_->model.decoder.inputs.past_sequence_length)) {
if (state_.params_->BatchBeamSize() != 1) {
Expand All @@ -36,7 +26,7 @@ DefaultInputIDs::DefaultInputIDs(State& state)
current_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, current_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.current_sequence_length));
*current_sequence_length_->GetTensorMutableData<int32_t>() = 0;

past_sequence_length_ = OrtValue::CreateTensor(*model_.allocator_device_, past_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length));
past_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, past_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length));
*past_sequence_length_->GetTensorMutableData<int32_t>() = -1;
}
}
Expand All @@ -56,6 +46,41 @@ void DefaultInputIDs::Add() {
}

void DefaultInputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
if (!value_) {
aciddelgado marked this conversation as resolved.
Show resolved Hide resolved
shape_[1] = static_cast<int64_t>(new_tokens.size()) / shape_[0];

// If 64-bit, convert from 32-bit to 64-bit
auto input_ids = new_tokens.CopyDeviceToCpu();
if (type_ == Ort::TypeToTensorType<int64_t>) {
value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_);
auto* p_data = value_->GetTensorMutableData<int64_t>();
for (auto v : input_ids) {
*p_data++ = v;
}
} else {
if (type_ != Ort::TypeToTensorType<int32_t>)
throw std::runtime_error("InputIDs must be int64 or int32");
value_ = OrtValue::CreateTensor<int32_t>(model_.allocator_cpu_.GetInfo(), input_ids, shape_);
}

value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams);
shape_[0] *= state_.params_->search.num_beams;

if (state_.GetCapturedGraphInfo()) {
sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get();

#if USE_DML
if (model_.device_type_ == DeviceType::DML) {
sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
}
#endif
}

is_prompt_ = false;
state_.inputs_[input_index_] = value_.get();
return;
}

const auto get_unpadded_sequence_length = [](std::span<const int32_t> input_ids,
int32_t pad_token_id) {
int32_t seq_length = 0;
Expand Down
21 changes: 12 additions & 9 deletions src/models/logits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,6 @@ Logits::Logits(State& state)
type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);

if (state_.GetCapturedGraphInfo()) {
if (type_ == Ort::TypeToTensorType<float>) {
sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get();
}
if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get();
}
}

#if USE_CUDA
if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
auto& cpu_ids = model_.config_->model.eos_token_ids;
Expand Down Expand Up @@ -215,6 +206,18 @@ void Logits::Update(const DeviceSpan<int32_t>& next_tokens, size_t new_kv_length
StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t> ? sb_logits16_ : sb_logits32_;
output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
: sb_logits->CreateTensorOnStaticBuffer(shape_, type_);

if (state_.GetCapturedGraphInfo()) {
if (!sb_logits16_ && !sb_logits32_) {
if (type_ == Ort::TypeToTensorType<float>) {
sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get();
}
if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get();
}
}
}

state_.outputs_[output_index_] = output_raw_.get();
}

Expand Down
19 changes: 12 additions & 7 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,18 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator
// Copy data to ortvalue_clone
auto element_size = Generators::SizeOf(type_info->GetElementType());
auto data_size = type_info->GetElementCount() * element_size;
if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::CUDA) {
const auto device_type = ortvalue_output->GetTensorMemoryInfo().GetDeviceType();
if (device_type == OrtMemoryInfoDeviceType_CPU) {
std::copy(static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()),
static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()) + data_size,
static_cast<uint8_t*>(ortvalue_clone->GetTensorMutableRawData()));
} else if (device_type == OrtMemoryInfoDeviceType_GPU) {
#if USE_CUDA
cudaMemcpy(ortvalue_clone->GetTensorMutableRawData(), ortvalue_output->GetTensorMutableRawData(), data_size, cudaMemcpyDeviceToHost);
#else
throw std::runtime_error("Unexpected error. Trying to access GPU memory but the project is not compiled with CUDA.");
#endif
} else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::DML) {
} else if (static_cast<int>(device_type) == 4) {
#if USE_DML
ComPtr<ID3D12Resource> gpu_resource;
Ort::ThrowOnError(generator.model_->GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
Expand All @@ -354,13 +361,11 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator
gpu_resource.Get(),
0,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
#else
throw std::runtime_error("Unexpected error. Trying to access DML memory but the project is not compiled with DML.");
#endif
} else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU) {
std::copy(static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()),
static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()) + data_size,
static_cast<uint8_t*>(ortvalue_clone->GetTensorMutableRawData()));
} else {
throw std::runtime_error("Unsupported Device type: " + std::to_string(ortvalue_output->GetTensorMemoryInfo().GetDeviceType()));
throw std::runtime_error("Unsupported device type: " + static_cast<int>(device_type));
}

auto tensor = std::make_shared<Generators::Tensor>(std::move(ortvalue_clone));
Expand Down
Loading