From 016a4689f8296973dd2066fc84cb6deabce6947a Mon Sep 17 00:00:00 2001 From: Lukas Stockner Date: Tue, 7 Jan 2025 00:22:58 +0100 Subject: [PATCH] Generate symbol derivatives for outputs when requested (#1916) Signed-off-by: Lukas Stockner --- src/liboslexec/batched_analysis.cpp | 7 ++----- src/liboslexec/batched_llvm_instance.cpp | 16 ++++------------ src/liboslexec/llvm_instance.cpp | 16 ++++------------ src/liboslexec/oslexec_pvt.h | 15 +++++++++++++++ src/liboslexec/runtimeoptimize.cpp | 9 +++++++++ 5 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/liboslexec/batched_analysis.cpp b/src/liboslexec/batched_analysis.cpp index 9f76c1acf..626c7e78e 100644 --- a/src/liboslexec/batched_analysis.cpp +++ b/src/liboslexec/batched_analysis.cpp @@ -2175,11 +2175,8 @@ struct Analyzer { const SymLocationDesc* symloc = nullptr; if (interpolate_param) { // See if userdata input placement has been used for this symbol - ustring layersym = ustring::fmtformat("{}.{}", inst()->layername(), - s.name()); - symloc = m_ba.group().find_symloc(layersym, SymArena::UserData); - if (!symloc) - symloc = m_ba.group().find_symloc(s.name(), SymArena::UserData); + symloc = m_ba.group().find_symloc(s.name(), inst()->layername(), + SymArena::UserData); if (symloc != nullptr) { // We copy values from userdata pre-placement which always succeeds // We must track this write, not because it will need to be masked diff --git a/src/liboslexec/batched_llvm_instance.cpp b/src/liboslexec/batched_llvm_instance.cpp index a8fb3eb8c..771a07408 100644 --- a/src/liboslexec/batched_llvm_instance.cpp +++ b/src/liboslexec/batched_llvm_instance.cpp @@ -1190,11 +1190,8 @@ BatchedBackendLLVM::llvm_assign_initial_value( llvm::Value* got_userdata = nullptr; // See if userdata input placement has been used for this symbol - ustring layersym = ustring::fmtformat("{}.{}", inst()->layername(), - sym.name()); - symloc = group().find_symloc(layersym, SymArena::UserData); - if (!symloc) - symloc = group().find_symloc(sym.name(), SymArena::UserData); + symloc = group().find_symloc(sym.name(), inst()->layername(), + SymArena::UserData); if (symloc) { // We had a userdata pre-placement record for this variable. // Just copy from the correct offset location! @@ -2361,13 +2358,8 @@ BatchedBackendLLVM::build_llvm_instance(bool groupentry) { if (!s.renderer_output()) // Skip if not a renderer output continue; - // Try to look up the sym among the outputs with the full layer.name - // specification first. If that fails, look for name only. - ustring layersym = ustring::fmtformat("{}.{}", inst()->layername(), - s.name()); - auto symloc = group().find_symloc(layersym, SymArena::Outputs); - if (!symloc) - symloc = group().find_symloc(s.name(), SymArena::Outputs); + auto symloc = group().find_symloc(s.name(), inst()->layername(), + SymArena::Outputs); if (!symloc) { // std::cout << "No output copy for " << s.name() // << " because no symloc was found\n"; diff --git a/src/liboslexec/llvm_instance.cpp b/src/liboslexec/llvm_instance.cpp index 8554925dd..ea18acc44 100644 --- a/src/liboslexec/llvm_instance.cpp +++ b/src/liboslexec/llvm_instance.cpp @@ -818,11 +818,8 @@ BackendLLVM::llvm_assign_initial_value(const Symbol& sym, bool force) llvm::Value* got_userdata = nullptr; // See if userdata input placement has been used for this symbol - ustring layersym = ustring::fmtformat("{}.{}", inst()->layername(), - sym.name()); - symloc = group().find_symloc(layersym, SymArena::UserData); - if (!symloc) - symloc = group().find_symloc(sym.name(), SymArena::UserData); + symloc = group().find_symloc(sym.name(), inst()->layername(), + SymArena::UserData); if (symloc) { // We had a userdata pre-placement record for this variable. // Just copy from the correct offset location! @@ -1793,13 +1790,8 @@ BackendLLVM::build_llvm_instance(bool groupentry) { if (!s.renderer_output()) // Skip if not a renderer output continue; - // Try to look up the sym among the outputs with the full layer.name - // specification first. If that fails, look for name only. - ustring layersym = ustring::fmtformat("{}.{}", inst()->layername(), - s.name()); - auto symloc = group().find_symloc(layersym, SymArena::Outputs); - if (!symloc) - symloc = group().find_symloc(s.name(), SymArena::Outputs); + auto symloc = group().find_symloc(s.name(), inst()->layername(), + SymArena::Outputs); if (!symloc) { // std::cout << "No output copy for " << s.name() // << " because no symloc was found\n"; diff --git a/src/liboslexec/oslexec_pvt.h b/src/liboslexec/oslexec_pvt.h index d1a0f35a7..7a8dedc81 100644 --- a/src/liboslexec/oslexec_pvt.h +++ b/src/liboslexec/oslexec_pvt.h @@ -1931,6 +1931,21 @@ class ShaderGroup { return nullptr; } + // Find the SymLocationDesc for this named param but only if it matches + // the arena type, returning its pointer or nullptr if that name is not + // found. + // Try to look up the sym with the full layer.name specification first. + // If that fails, try again based on name only. + const SymLocationDesc* find_symloc(ustring name, ustring layer, + SymArena arena) const + { + ustring layersym = ustring::fmtformat("{}.{}", layer, name); + auto symloc = find_symloc(layersym, arena); + if (!symloc) + symloc = find_symloc(name, arena); + return symloc; + } + // Given a data block for interactive params, allocate space for it to // live with the group and copy the initial data. void setup_interactive_arena(cspan paramblock); diff --git a/src/liboslexec/runtimeoptimize.cpp b/src/liboslexec/runtimeoptimize.cpp index 2248337db..d420ebd5f 100644 --- a/src/liboslexec/runtimeoptimize.cpp +++ b/src/liboslexec/runtimeoptimize.cpp @@ -2698,6 +2698,15 @@ RuntimeOptimizer::track_variable_dependencies() if (s.symtype() == SymTypeGlobal && s.everwritten() && !s.typespec().is_closure_based() && s.mangled() != Strings::N) s.has_derivs(true); + // If the renderer requests a symbol to be written as an output with + // derivs, mark them to be generated. + if (s.renderer_output()) { + auto symloc = group().find_symloc(s.name(), inst()->layername(), + SymArena::Outputs); + if (symloc && symloc->derivs && s.everwritten() + && !s.typespec().is_closure_based()) + s.has_derivs(true); + } if (s.has_derivs()) add_dependency(symdeps, DerivSym, snum); ++snum;