Skip to content

Commit

Permalink
expose Zig / C API for compilation
Browse files Browse the repository at this point in the history
Signed-off-by: Stephen Gutekanst <[email protected]>
  • Loading branch information
emidoots committed Dec 15, 2023
1 parent e210522 commit d9e236d
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 6 deletions.
131 changes: 127 additions & 4 deletions src/mach_dxc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,141 @@
// Avoid __declspec(dllimport) since dxcompiler is static.
#define DXC_API_IMPORT
#include <dxcapi.h>
#include <cassert>
#include <stddef.h>

#include "mach_dxc.h"

#ifdef __cplusplus
extern "C" {
#endif

MACH_EXPORT void machDxcFoo() {
CComPtr<IDxcCompiler> dxcInstance;
// Mach change start: static dxcompiler/dxil
BOOL MachDxcompilerInvokeDllMain();
void MachDxcompilerInvokeDllShutdown();

//----------------
// MachDxcCompiler
//----------------
MACH_EXPORT MachDxcCompiler machDxcInit() {
MachDxcompilerInvokeDllMain();
CComPtr<IDxcCompiler3> dxcInstance;
HRESULT hr = DxcCreateInstance(CLSID_DxcCompiler, IID_PPV_ARGS(&dxcInstance));
// TODO: check success
return;
assert(SUCCEEDED(hr));
return reinterpret_cast<MachDxcCompiler>(dxcInstance.Detach());
}

MACH_EXPORT void machDxcDeinit(MachDxcCompiler compiler) {
CComPtr<IDxcCompiler3> dxcInstance = CComPtr(reinterpret_cast<IDxcCompiler3*>(compiler));
dxcInstance.Release();
MachDxcompilerInvokeDllShutdown();
}

//---------------------
// MachDxcCompileResult
//---------------------
MACH_EXPORT MachDxcCompileResult machDxcCompile(
MachDxcCompiler compiler,
char const* code,
size_t code_len,
char const* const* args,
size_t args_len
) {
CComPtr<IDxcCompiler3> dxcInstance = CComPtr(reinterpret_cast<IDxcCompiler3*>(compiler));

CComPtr<IDxcUtils> pUtils;
DxcCreateInstance(CLSID_DxcUtils, IID_PPV_ARGS(&pUtils));
CComPtr<IDxcBlobEncoding> pSource;
pUtils->CreateBlob(code, code_len, CP_UTF8, &pSource);

DxcBuffer sourceBuffer;
sourceBuffer.Ptr = pSource->GetBufferPointer();
sourceBuffer.Size = pSource->GetBufferSize();
sourceBuffer.Encoding = 0;

// We have args in char form, but dxcInstance->Compile expects wchar_t form.
std::vector<std::wstring> arguments;
for (int i=0; i < args_len; i++) {
wchar_t wtext_buf[200];
std::mbstowcs(wtext_buf, args[i], strlen(args[i])+1);
arguments.push_back(std::wstring(wtext_buf));
}
std::vector<LPCWSTR> w_arguments_list;
for (int i=0; i < args_len; i++) {
w_arguments_list.push_back(arguments[i].data());
}

CComPtr<IDxcResult> pCompileResult;
HRESULT hr = dxcInstance->Compile(
&sourceBuffer,
w_arguments_list.data(),
(uint32_t)w_arguments_list.size(),
nullptr,
IID_PPV_ARGS(&pCompileResult)
);
assert(SUCCEEDED(hr));
return reinterpret_cast<MachDxcCompileResult>(pCompileResult.Detach());
}

MACH_EXPORT MachDxcCompileError machDxcCompileResultGetError(MachDxcCompileResult err) {
CComPtr<IDxcResult> pCompileResult = CComPtr(reinterpret_cast<IDxcResult*>(err));
CComPtr<IDxcBlobUtf8> pErrors;
pCompileResult->GetOutput(DXC_OUT_ERRORS, IID_PPV_ARGS(&pErrors), nullptr);
if (pErrors && pErrors->GetStringLength() > 0) {
return reinterpret_cast<MachDxcCompileError>(pErrors.Detach());
}
return nullptr;
}

MACH_EXPORT MachDxcCompileObject machDxcCompileResultGetObject(MachDxcCompileResult err) {
CComPtr<IDxcResult> pCompileResult = CComPtr(reinterpret_cast<IDxcResult*>(err));
CComPtr<IDxcBlob> pObject;
pCompileResult->GetOutput(DXC_OUT_OBJECT, IID_PPV_ARGS(&pObject), nullptr);
if (pObject && pObject->GetBufferSize() > 0) {
return reinterpret_cast<MachDxcCompileObject>(pObject.Detach());
}
return nullptr;
}

MACH_EXPORT void machDxcCompileResultDeinit(MachDxcCompileResult err) {
CComPtr<IDxcResult> pCompileResult = CComPtr(reinterpret_cast<IDxcResult*>(err));
pCompileResult.Release();
}

//---------------------
// MachDxcCompileObject
//---------------------
MACH_EXPORT char const* machDxcCompileObjectGetBytes(MachDxcCompileObject err) {
CComPtr<IDxcBlob> pObject = CComPtr(reinterpret_cast<IDxcBlob*>(err));
return (char const*)(pObject->GetBufferPointer());
}

MACH_EXPORT size_t machDxcCompileObjectGetBytesLength(MachDxcCompileObject err) {
CComPtr<IDxcBlob> pObject = CComPtr(reinterpret_cast<IDxcBlob*>(err));
return pObject->GetBufferSize();
}

MACH_EXPORT void machDxcCompileObjectDeinit(MachDxcCompileObject err) {
CComPtr<IDxcBlob> pObject = CComPtr(reinterpret_cast<IDxcBlob*>(err));
pObject.Release();
}

//--------------------
// MachDxcCompileError
//--------------------
MACH_EXPORT char const* machDxcCompileErrorGetString(MachDxcCompileError err) {
CComPtr<IDxcBlobUtf8> pErrors = CComPtr(reinterpret_cast<IDxcBlobUtf8*>(err));
return (char const*)(pErrors->GetBufferPointer());
}

MACH_EXPORT size_t machDxcCompileErrorGetStringLength(MachDxcCompileError err) {
CComPtr<IDxcBlobUtf8> pErrors = CComPtr(reinterpret_cast<IDxcBlobUtf8*>(err));
return pErrors->GetStringLength();
}

MACH_EXPORT void machDxcCompileErrorDeinit(MachDxcCompileError err) {
CComPtr<IDxcBlobUtf8> pErrors = CComPtr(reinterpret_cast<IDxcBlobUtf8*>(err));
pErrors.Release();
}

#ifdef __cplusplus
Expand Down
76 changes: 75 additions & 1 deletion src/mach_dxc.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,81 @@ extern "C" {
# define MACH_EXPORT
#endif // defined(MACH_DXC_C_SHARED_LIBRARY)

MACH_EXPORT void machDxcFoo();
#if !defined(MACH_OBJECT_ATTRIBUTE)
#define MACH_OBJECT_ATTRIBUTE
#endif

#include <stddef.h>

typedef struct MachDxcCompilerImpl* MachDxcCompiler MACH_OBJECT_ATTRIBUTE;
typedef struct MachDxcCompileResultImpl* MachDxcCompileResult MACH_OBJECT_ATTRIBUTE;
typedef struct MachDxcCompileErrorImpl* MachDxcCompileError MACH_OBJECT_ATTRIBUTE;
typedef struct MachDxcCompileObjectImpl* MachDxcCompileObject MACH_OBJECT_ATTRIBUTE;

//----------------
// MachDxcCompiler
//----------------

/// Initializes a DXC compiler
///
/// Invoke machDxcDeinit when done with the compiler.
MACH_EXPORT MachDxcCompiler machDxcInit();

/// Deinitializes the DXC compiler.
MACH_EXPORT void machDxcDeinit(MachDxcCompiler compiler);

//---------------------
// MachDxcCompileResult
//---------------------

/// Compiles the given code with the given dxc.exe CLI arguments
///
/// Invoke machDxcCompileResultDeinit when done with the result.
MACH_EXPORT MachDxcCompileResult machDxcCompile(
MachDxcCompiler compiler,
char const* code,
size_t code_len,
char const* const* args,
size_t args_len
);

/// Returns an error object, or null in the case of success.
///
/// Invoke machDxcCompileErrorDeinit when done with the error, iff it was non-null.
MACH_EXPORT MachDxcCompileError machDxcCompileResultGetError(MachDxcCompileResult err);

/// Returns the compiled object code, or null if an error occurred.
MACH_EXPORT MachDxcCompileObject machDxcCompileResultGetObject(MachDxcCompileResult err);

/// Deinitializes the DXC compiler.
MACH_EXPORT void machDxcCompileResultDeinit(MachDxcCompileResult err);

//---------------------
// MachDxcCompileObject
//---------------------

/// Returns a pointer to the raw bytes of the compiled object file.
MACH_EXPORT char const* machDxcCompileObjectGetBytes(MachDxcCompileObject err);

/// Returns the length of the compiled object file.
MACH_EXPORT size_t machDxcCompileObjectGetBytesLength(MachDxcCompileObject err);

/// Deinitializes the compiled object, calling Get methods after this is illegal.
MACH_EXPORT void machDxcCompileObjectDeinit(MachDxcCompileObject err);

//--------------------
// MachDxcCompileError
//--------------------

/// Returns a pointer to the null-terminated UTF-8 encoded error string. This includes
/// compiler warnings, unless they were disabled in the compile arguments.
MACH_EXPORT char const* machDxcCompileErrorGetString(MachDxcCompileError err);

/// Returns the length of the error string.
MACH_EXPORT size_t machDxcCompileErrorGetStringLength(MachDxcCompileError err);

/// Deinitializes the error, calling Get methods after this is illegal.
MACH_EXPORT void machDxcCompileErrorDeinit(MachDxcCompileError err);

#ifdef __cplusplus
} // extern "C"
Expand Down
90 changes: 89 additions & 1 deletion src/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,94 @@ const c = @cImport(
@cInclude("mach_dxc.h"),
);

const Compiler = struct {
handle: c.MachDxcCompiler,

pub fn init() Compiler {
const handle = c.machDxcInit();
return .{ .handle = handle };
}

pub fn deinit(compiler: Compiler) void {
c.machDxcDeinit(compiler.handle);
}

pub fn compile(compiler: Compiler, code: []const u8, args: []const [*:0]const u8) Result {
const result = c.machDxcCompile(compiler.handle, code.ptr, code.len, args.ptr, args.len);
return .{ .handle = result };
}

pub const Result = struct {
handle: c.MachDxcCompileResult,

pub fn deinit(result: Result) void {
c.machDxcCompileResultDeinit(result.handle);
}

pub fn getError(result: Result) ?Error {
if (c.machDxcCompileResultGetError(result.handle)) |err| return .{ .handle = err };
return null;
}

pub fn getObject(result: Result) Object {
return .{ .handle = c.machDxcCompileResultGetObject(result.handle) };
}

pub const Error = struct {
handle: c.MachDxcCompileError,

pub fn deinit(err: Error) void {
c.machDxcCompileErrorDeinit(err.handle);
}

pub fn getString(err: Error) []const u8 {
return c.machDxcCompileErrorGetString(err.handle)[0..c.machDxcCompileErrorGetStringLength(err.handle)];
}
};

pub const Object = struct {
handle: c.MachDxcCompileObject,

pub fn deinit(obj: Object) void {
c.machDxcCompileObjectDeinit(obj.handle);
}

pub fn getBytes(obj: Object) []const u8 {
return c.machDxcCompileObjectGetBytes(obj.handle)[0..c.machDxcCompileObjectGetBytesLength(obj.handle)];
}
};
};
};

test {
c.machDxcFoo();
const std = @import("std");

const code =
\\ Texture1D<float4> tex[5] : register(t3);
\\ SamplerState SS[3] : register(s2);
\\
\\ [RootSignature("DescriptorTable(SRV(t3, numDescriptors=5)), DescriptorTable(Sampler(s2, numDescriptors=3))")]
\\ float4 main(int i : A, float j : B) : SV_TARGET
\\ {
\\ float4 r = tex[NonUniformResourceIndex(i)].Sample(SS[NonUniformResourceIndex(i)], i);
\\ r += tex[NonUniformResourceIndex(j)].Sample(SS[i], j+2);
\\ return r;
\\ };
;
const args = &[_][*:0]const u8{ "-E", "main", "-T", "ps_6_0", "-D", "MYDEFINE=1", "-Qstrip_debug", "-Qstrip_reflect" };

const compiler = Compiler.init();
defer compiler.deinit();

const result = compiler.compile(code, args);
if (result.getError()) |err| {
defer err.deinit();
std.debug.print("compiler error: {s}\n", .{err.getString()});
return;
}

const object = result.getObject();
defer object.deinit();

try std.testing.expectEqual(@as(usize, 2392), object.getBytes().len);
}

0 comments on commit d9e236d

Please sign in to comment.