diff --git a/resources/Materials/TestSuite/pbrlib/surfaceshader/network_surfaceshader.mtlx b/resources/Materials/TestSuite/pbrlib/surfaceshader/network_surfaceshader.mtlx new file mode 100644 index 0000000000..fb5202b37c --- /dev/null +++ b/resources/Materials/TestSuite/pbrlib/surfaceshader/network_surfaceshader.mtlx @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/source/MaterialXGenShader/GenContext.h b/source/MaterialXGenShader/GenContext.h index f1de04ace6..1a936ac0dd 100644 --- a/source/MaterialXGenShader/GenContext.h +++ b/source/MaterialXGenShader/GenContext.h @@ -125,6 +125,24 @@ class MX_GENSHADER_API GenContext return _closureContexts.size() ? _closureContexts.back() : nullptr; } + /// Push a parent node onto the stack + void pushParentNode(ConstNodePtr node) + { + _parentNodes.push_back(node); + } + + /// Pop the current parent node from the stack. + void popParentNode() + { + _parentNodes.pop_back(); + } + + /// Return the current stack of parent nodes. + const vector& getParentNodes() + { + return _parentNodes; + } + /// Add user data to the context to make it /// available during shader generator. void pushUserData(const string& name, GenUserDataPtr data) @@ -216,6 +234,7 @@ class MX_GENSHADER_API GenContext std::unordered_map _outputSuffix; vector _closureContexts; + vector _parentNodes; ApplicationVariableHandler _applicationVariableHandler; }; diff --git a/source/MaterialXGenShader/ShaderGraph.cpp b/source/MaterialXGenShader/ShaderGraph.cpp index 41ac0da8e7..bcfe1ae441 100644 --- a/source/MaterialXGenShader/ShaderGraph.cpp +++ b/source/MaterialXGenShader/ShaderGraph.cpp @@ -93,7 +93,7 @@ void ShaderGraph::createConnectedNodes(const ElementPtr& downstreamElement, ShaderNode* newNode = getNode(newNodeName); if (!newNode) { - newNode = createNode(*upstreamNode, context); + newNode = createNode(upstreamNode, context); } // @@ -443,171 +443,6 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const NodeGraph& n return graph; } -ShaderGraphPtr ShaderGraph::createSurfaceShader( - const string& name, - const ShaderGraph* parent, - NodePtr node, - GenContext& context, - ElementPtr& root) -{ - NodeDefPtr nodeDef = node->getNodeDef(EMPTY_STRING, true); - if (!nodeDef) - { - throw ExceptionShaderGenError("Could not find a nodedef for shader node '" + node->getName() + - "' with category '" + node->getCategory() + "'"); - } - - ShaderGraphPtr graph = std::make_shared(parent, name, node->getDocument(), context.getReservedWords()); - - // Create input sockets - graph->addInputSockets(*nodeDef, context); - - // Create output sockets - graph->addOutputSockets(*nodeDef); - - // Create this shader node in the graph. - const string& newNodeName = node->getName(); - ShaderNodePtr newNode = ShaderNode::create(graph.get(), newNodeName, *nodeDef, context); - newNode->initialize(*node, *nodeDef, context); - graph->addNode(newNode); - - // Share metadata. - graph->setMetadata(newNode->getMetadata()); - - // Connect it to the graph output - ShaderGraphOutputSocket* outputSocket = graph->getOutputSocket(); - outputSocket->makeConnection(newNode->getOutput()); - outputSocket->setPath(node->getNamePath()); - - ColorManagementSystemPtr colorManagementSystem = context.getShaderGenerator().getColorManagementSystem(); - string targetColorSpace = context.getOptions().targetColorSpaceOverride.empty() ? - node->getDocument()->getColorSpace() : - context.getOptions().targetColorSpaceOverride; - - const string& targetDistanceUnit = context.getOptions().targetDistanceUnit; - UnitSystemPtr unitSystem = context.getShaderGenerator().getUnitSystem(); - - // Set node input values onto graph input sockets - for (const InputPtr& nodeDefInput : nodeDef->getActiveInputs()) - { - ShaderGraphInputSocket* inputSocket = graph->getInputSocket(nodeDefInput->getName()); - ShaderInput* input = newNode->getInput(nodeDefInput->getName()); - if (!inputSocket || !input) - { - throw ExceptionShaderGenError("Shader input '" + nodeDefInput->getName() + "' doesn't match an existing input on graph '" + graph->getName() + "'"); - } - - InputPtr nodeInput = node->getInput(nodeDefInput->getName()); - if (nodeInput) - { - // Copy value from binding - ValuePtr nodeInputValue = nodeInput->getResolvedValue(); - if (nodeInputValue) - { - inputSocket->setValue(nodeInputValue); - input->setBindInput(); - graph->populateColorTransformMap(colorManagementSystem, input, nodeInput, targetColorSpace, true); - graph->populateUnitTransformMap(unitSystem, input, nodeInput, targetDistanceUnit, true); - } - inputSocket->setPath(nodeInput->getNamePath()); - input->setPath(inputSocket->getPath()); - const string& nodeInputUnit = nodeInput->getUnit(); - if (!nodeInputUnit.empty()) - { - inputSocket->setUnit(nodeInputUnit); - input->setUnit(nodeInputUnit); - } - const string& nodeColorspace = nodeInput->getColorSpace(); - if (!nodeColorspace.empty()) - { - inputSocket->setColorSpace(nodeColorspace); - input->setColorSpace(nodeColorspace); - } - } - - // Check if the input is a uniform - bool isUniform = nodeDefInput->getIsUniform(); - if (isUniform) - { - inputSocket->makeConnection(input); - } - else - { - GeomPropDefPtr geomprop = nodeDefInput->getDefaultGeomProp(); - if (geomprop) - { - inputSocket->setGeomProp(geomprop->getName()); - input->setGeomProp(geomprop->getName()); - } - - // If no explicit connection, connect to geometric node if a geomprop is used - // or otherwise to the graph interface. - const string& connection = nodeInput ? nodeInput->getOutputString() : EMPTY_STRING; - if (connection.empty()) - { - if (geomprop) - { - graph->addDefaultGeomNode(input, *geomprop, context); - } - else - { - inputSocket->makeConnection(input); - } - } - } - - // Share metadata. - inputSocket->setMetadata(input->getMetadata()); - } - - // Add shader node paths and unit value - const string& nodePath = node->getNamePath(); - for (auto nodeInput : nodeDef->getActiveInputs()) - { - const string& inputName = nodeInput->getName(); - const string path = nodePath + NAME_PATH_SEPARATOR + inputName; - const string& unit = nodeInput->getUnit(); - const string& colorSpace = nodeInput->getColorSpace(); - ShaderInput* input = newNode->getInput(inputName); - if (input) - { - if (input->getPath().empty()) - { - input->setPath(path); - } - if (input->getUnit().empty() && !unit.empty()) - { - input->setUnit(unit); - } - if (input->getColorSpace().empty() && !colorSpace.empty()) - { - input->setColorSpace(colorSpace); - } - } - ShaderGraphInputSocket* inputSocket = graph->getInputSocket(inputName); - if (inputSocket) - { - if (inputSocket->getPath().empty()) - { - inputSocket->setPath(path); - } - if (inputSocket->getUnit().empty() && !unit.empty()) - { - inputSocket->setUnit(unit); - } - if (inputSocket->getColorSpace().empty() && !colorSpace.empty()) - { - inputSocket->setColorSpace(colorSpace); - } - } - } - - // Start traversal from this shader node - root = node; - - return graph; -} - ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const string& name, ElementPtr element, GenContext& context) { ShaderGraphPtr graph; @@ -674,104 +509,95 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const string& name else if (element->isA()) { NodePtr node = element->asA(); - - // Handle shader nodes different from other nodes - if (node->getType() == SURFACE_SHADER_TYPE_STRING) + NodeDefPtr nodeDef = node->getNodeDef(); + if (!nodeDef) { - graph = createSurfaceShader(name, parent, node, context, root); + throw ExceptionShaderGenError("Could not find a nodedef for node '" + node->getName() + "'"); } - else - { - NodeDefPtr nodeDef = node->getNodeDef(); - if (!nodeDef) - { - throw ExceptionShaderGenError("Could not find a nodedef for node '" + node->getName() + "'"); - } - graph = std::make_shared(parent, name, element->getDocument(), context.getReservedWords()); + graph = std::make_shared(parent, name, element->getDocument(), context.getReservedWords()); - // Create input sockets - graph->addInputSockets(*nodeDef, context); + // Create input sockets + graph->addInputSockets(*nodeDef, context); - // Create output sockets - graph->addOutputSockets(*nodeDef); + // Create output sockets + graph->addOutputSockets(*nodeDef); - // Create this shader node in the graph. - ShaderNodePtr newNode = ShaderNode::create(graph.get(), node->getName(), *nodeDef, context); - graph->addNode(newNode); + // Create this shader node in the graph. + ShaderNodePtr newNode = ShaderNode::create(graph.get(), node->getName(), *nodeDef, context); + graph->addNode(newNode); - // Share metadata. - graph->setMetadata(newNode->getMetadata()); + // Share metadata. + graph->setMetadata(newNode->getMetadata()); - // Connect it to the graph outputs - for (size_t i = 0; i < newNode->numOutputs(); ++i) + // Connect it to the graph outputs + for (size_t i = 0; i < newNode->numOutputs(); ++i) + { + ShaderGraphOutputSocket* outputSocket = graph->getOutputSocket(i); + outputSocket->makeConnection(newNode->getOutput(i)); + outputSocket->setPath(node->getNamePath()); + } + + // Handle node input ports + for (const InputPtr& nodedefInput : nodeDef->getActiveInputs()) + { + ShaderGraphInputSocket* inputSocket = graph->getInputSocket(nodedefInput->getName()); + ShaderInput* input = newNode->getInput(nodedefInput->getName()); + if (!inputSocket || !input) { - ShaderGraphOutputSocket* outputSocket = graph->getOutputSocket(i); - outputSocket->makeConnection(newNode->getOutput(i)); - outputSocket->setPath(node->getNamePath()); + throw ExceptionShaderGenError("Node input '" + nodedefInput->getName() + "' doesn't match an existing input on graph '" + graph->getName() + "'"); } - // Handle node input ports - for (const InputPtr& nodedefInput : nodeDef->getActiveInputs()) + // Copy data from node element to shadergen representation + InputPtr nodeInput = node->getInput(nodedefInput->getName()); + if (nodeInput) { - ShaderGraphInputSocket* inputSocket = graph->getInputSocket(nodedefInput->getName()); - ShaderInput* input = newNode->getInput(nodedefInput->getName()); - if (!inputSocket || !input) + ValuePtr value = nodeInput->getResolvedValue(); + if (value) { - throw ExceptionShaderGenError("Node input '" + nodedefInput->getName() + "' doesn't match an existing input on graph '" + graph->getName() + "'"); - } - - // Copy data from node element to shadergen representation - InputPtr nodeInput = node->getInput(nodedefInput->getName()); - if (nodeInput) - { - ValuePtr value = nodeInput->getResolvedValue(); - if (value) - { - const string& valueString = value->getValueString(); - std::pair enumResult; - const TypeDesc* type = TypeDesc::get(nodedefInput->getType()); - const string& enumNames = nodedefInput->getAttribute(ValueElement::ENUM_ATTRIBUTE); - if (context.getShaderGenerator().getSyntax().remapEnumeration(valueString, type, enumNames, enumResult)) - { - inputSocket->setValue(enumResult.second); - } - else - { - inputSocket->setValue(value); - } - } - - const string path = nodeInput->getNamePath(); - if (!path.empty()) - { - inputSocket->setPath(path); - input->setPath(path); - } - const string& unit = nodeInput->getUnit(); - if (!unit.empty()) + const string& valueString = value->getValueString(); + std::pair enumResult; + const TypeDesc* type = TypeDesc::get(nodedefInput->getType()); + const string& enumNames = nodedefInput->getAttribute(ValueElement::ENUM_ATTRIBUTE); + if (context.getShaderGenerator().getSyntax().remapEnumeration(valueString, type, enumNames, enumResult)) { - inputSocket->setUnit(unit); - input->setUnit(unit); + inputSocket->setValue(enumResult.second); } - const string& colorSpace = nodeInput->getColorSpace(); - if (!colorSpace.empty()) + else { - inputSocket->setColorSpace(colorSpace); - input->setColorSpace(colorSpace); + inputSocket->setValue(value); } } - // Connect graph socket to the node input - inputSocket->makeConnection(input); - - // Share metadata. - inputSocket->setMetadata(input->getMetadata()); + const string path = nodeInput->getNamePath(); + if (!path.empty()) + { + inputSocket->setPath(path); + input->setPath(path); + } + const string& unit = nodeInput->getUnit(); + if (!unit.empty()) + { + inputSocket->setUnit(unit); + input->setUnit(unit); + } + const string& colorSpace = nodeInput->getColorSpace(); + if (!colorSpace.empty()) + { + inputSocket->setColorSpace(colorSpace); + input->setColorSpace(colorSpace); + } } - // Set root for upstream dependency traversal - root = node; + // Connect graph socket to the node input + inputSocket->makeConnection(input); + + // Share metadata. + inputSocket->setMetadata(input->getMetadata()); } + + // Set root for upstream dependency traversal + root = node; } if (!graph) @@ -790,23 +616,24 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const string& name return graph; } -ShaderNode* ShaderGraph::createNode(const Node& node, GenContext& context) +ShaderNode* ShaderGraph::createNode(ConstNodePtr node, GenContext& context) { - NodeDefPtr nodeDef = node.getNodeDef(); + NodeDefPtr nodeDef = node->getNodeDef(); if (!nodeDef) { - throw ExceptionShaderGenError("Could not find a nodedef for node '" + node.getName() + "'"); + throw ExceptionShaderGenError("Could not find a nodedef for node '" + node->getName() + "'"); } // Create this node in the graph. - const string& name = node.getName(); - ShaderNodePtr newNode = ShaderNode::create(this, name, *nodeDef, context); - newNode->initialize(node, *nodeDef, context); - _nodeMap[name] = newNode; + context.pushParentNode(node); + ShaderNodePtr newNode = ShaderNode::create(this, node->getName(), *nodeDef, context); + newNode->initialize(*node, *nodeDef, context); + _nodeMap[node->getName()] = newNode; _nodeOrder.push_back(newNode.get()); + context.popParentNode(); // Check if any of the node inputs should be connected to the graph interface - for (ValueElementPtr elem : node.getChildrenOfType()) + for (ValueElementPtr elem : node->getChildrenOfType()) { const string& interfaceName = elem->getInterfaceName(); if (!interfaceName.empty()) @@ -829,7 +656,7 @@ ShaderNode* ShaderGraph::createNode(const Node& node, GenContext& context) for (const InputPtr& nodeDefInput : nodeDef->getActiveInputs()) { ShaderInput* input = newNode->getInput(nodeDefInput->getName()); - InputPtr nodeInput = node.getInput(nodeDefInput->getName()); + InputPtr nodeInput = node->getInput(nodeDefInput->getName()); const string& connection = nodeInput ? nodeInput->getNodeName() : EMPTY_STRING; if (connection.empty() && !input->getConnection()) @@ -843,35 +670,47 @@ ShaderNode* ShaderGraph::createNode(const Node& node, GenContext& context) } // Handle colorspace and unit conversion if needed. + ColorManagementSystemPtr colorManagementSystem = context.getShaderGenerator().getColorManagementSystem(); UnitSystemPtr unitSystem = context.getShaderGenerator().getUnitSystem(); + const string& targetColorSpace = context.getOptions().targetColorSpaceOverride.empty() ? + _document->getActiveColorSpace() : + context.getOptions().targetColorSpaceOverride; const string& targetDistanceUnit = context.getOptions().targetDistanceUnit; - ColorManagementSystemPtr colorManagementSystem = context.getShaderGenerator().getColorManagementSystem(); - string targetColorSpace = context.getOptions().targetColorSpaceOverride.empty() ? - _document->getActiveColorSpace() : - context.getOptions().targetColorSpaceOverride; - - for (InputPtr input : node.getInputs()) + for (InputPtr input : node->getInputs()) { - if (input->getType() == FILENAME_TYPE_STRING) + if (input->hasValue() || input->hasInterfaceName()) { - ShaderOutput* shaderOutput = newNode->getOutput(); - if (shaderOutput) + string sourceColorSpace = input->getActiveColorSpace(); + if (input->getType() == FILENAME_TYPE_STRING && node->isColorType()) { - string colorSpace = populateColorTransformMap(colorManagementSystem, shaderOutput, input, targetColorSpace, false); - ShaderInput* shaderInput = newNode->getInput(input->getName()); - if (shaderInput && !colorSpace.empty()) + // Adjust the source color space for filename interface names. + if (input->hasInterfaceName()) { - shaderInput->setColorSpace(colorSpace); + for (ConstNodePtr parentNode : context.getParentNodes()) + { + if (!parentNode->isColorType()) + { + InputPtr interfaceInput = parentNode->getInput(input->getInterfaceName()); + string interfaceColorSpace = interfaceInput ? interfaceInput->getActiveColorSpace() : EMPTY_STRING; + if (!interfaceColorSpace.empty()) + { + sourceColorSpace = interfaceColorSpace; + } + } + } } + + ShaderOutput* shaderOutput = newNode->getOutput(); + populateColorTransformMap(colorManagementSystem, shaderOutput, sourceColorSpace, targetColorSpace, false); populateUnitTransformMap(unitSystem, shaderOutput, input, targetDistanceUnit, false); } - } - else - { - ShaderInput* shaderInput = newNode->getInput(input->getName()); - populateColorTransformMap(colorManagementSystem, shaderInput, input, targetColorSpace, true); - populateUnitTransformMap(unitSystem, shaderInput, input, targetDistanceUnit, true); + else + { + ShaderInput* shaderInput = newNode->getInput(input->getName()); + populateColorTransformMap(colorManagementSystem, shaderInput, sourceColorSpace, targetColorSpace, true); + populateUnitTransformMap(unitSystem, shaderInput, input, targetDistanceUnit, true); + } } } @@ -1229,48 +1068,44 @@ void ShaderGraph::setVariableNames(GenContext& context) } } -string ShaderGraph::populateColorTransformMap(ColorManagementSystemPtr colorManagementSystem, ShaderPort* shaderPort, - ValueElementPtr input, const string& targetColorSpace, bool asInput) +void ShaderGraph::populateColorTransformMap(ColorManagementSystemPtr colorManagementSystem, ShaderPort* shaderPort, + const string& sourceColorSpace, const string& targetColorSpace, bool asInput) { - if (targetColorSpace.empty()) + if (!shaderPort || + sourceColorSpace.empty() || + targetColorSpace.empty() || + sourceColorSpace == targetColorSpace) { - return EMPTY_STRING; + return; } - const string& sourceColorSpace = input->getActiveColorSpace(); - if (shaderPort && !sourceColorSpace.empty()) + if (*(shaderPort->getType()) == *Type::COLOR3 || *(shaderPort->getType()) == *Type::COLOR4) { - if (*(shaderPort->getType()) == *Type::COLOR3 || *(shaderPort->getType()) == *Type::COLOR4) + // Store the source color space on the shader port. + shaderPort->setColorSpace(sourceColorSpace); + + // Update the color transform map, if a color management system is provided. + if (colorManagementSystem) { - // If we're converting between two identical color spaces than we have no work to do. - if (sourceColorSpace != targetColorSpace) + ColorSpaceTransform transform(sourceColorSpace, targetColorSpace, shaderPort->getType()); + if (colorManagementSystem->supportsTransform(transform)) { - // Cache colorspace on shader port - shaderPort->setColorSpace(sourceColorSpace); - if (colorManagementSystem) + if (asInput) { - ColorSpaceTransform transform(sourceColorSpace, targetColorSpace, shaderPort->getType()); - if (colorManagementSystem->supportsTransform(transform)) - { - if (asInput) - { - _inputColorTransformMap.emplace(static_cast(shaderPort), transform); - } - else - { - _outputColorTransformMap.emplace(static_cast(shaderPort), transform); - } - } - else - { - std::cerr << "Unsupported color space transform from " << - sourceColorSpace << " to " << targetColorSpace << std::endl; - } + _inputColorTransformMap.emplace(static_cast(shaderPort), transform); } + else + { + _outputColorTransformMap.emplace(static_cast(shaderPort), transform); + } + } + else + { + std::cerr << "Unsupported color space transform from " << + sourceColorSpace << " to " << targetColorSpace << std::endl; } } } - return sourceColorSpace; } void ShaderGraph::populateUnitTransformMap(UnitSystemPtr unitSystem, ShaderPort* shaderPort, ValueElementPtr input, const string& globalTargetUnitSpace, bool asInput) diff --git a/source/MaterialXGenShader/ShaderGraph.h b/source/MaterialXGenShader/ShaderGraph.h index 4117599c27..b8ac15fc46 100644 --- a/source/MaterialXGenShader/ShaderGraph.h +++ b/source/MaterialXGenShader/ShaderGraph.h @@ -93,7 +93,7 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode const vector& getOutputSockets() const { return _inputOrder; } /// Create a new node in the graph - ShaderNode* createNode(const Node& node, GenContext& context); + ShaderNode* createNode(ConstNodePtr node, GenContext& context); /// Add input/output sockets ShaderGraphInputSocket* addInputSocket(const string& name, const TypeDesc* type); @@ -112,13 +112,6 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode IdentifierMap& getIdentifierMap() { return _identifiers; } protected: - static ShaderGraphPtr createSurfaceShader( - const string& name, - const ShaderGraph* parent, - NodePtr node, - GenContext& context, - ElementPtr& root); - /// Create node connections corresponding to the connection between a pair of elements. /// @param downstreamElement Element representing the node to connect to. /// @param upstreamElement Element representing the node to connect from @@ -171,9 +164,10 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode /// to avoid name conflicts during shader generation. void setVariableNames(GenContext& context); - /// Populates the input or output color transform map if the provided input/parameter - /// has a color space attribute and has a type of color3 or color4. - string populateColorTransformMap(ColorManagementSystemPtr colorManagementSystem, ShaderPort* shaderPort, ValueElementPtr element, const string& targetColorSpace, bool asInput); + /// Populate the color transform map for the given shader port, if the provided combination of + /// source and target color spaces are supported for its data type. + void populateColorTransformMap(ColorManagementSystemPtr colorManagementSystem, ShaderPort* shaderPort, + const string& sourceColorSpace, const string& targetColorSpace, bool asInput); /// Populates the appropriate unit transform map if the provided input/parameter or output /// has a unit attribute and is of the supported type diff --git a/source/MaterialXGenShader/ShaderTranslator.cpp b/source/MaterialXGenShader/ShaderTranslator.cpp index 2d53d14925..8a03c106cb 100644 --- a/source/MaterialXGenShader/ShaderTranslator.cpp +++ b/source/MaterialXGenShader/ShaderTranslator.cpp @@ -52,9 +52,10 @@ void ShaderTranslator::connectTranslationInputs(NodePtr shader, NodeDefPtr trans throw Exception("Shader input has no associated output or value " + shaderInput->getName()); } - if (shaderInput->hasColorSpace()) + string colorSpace = shaderInput->getActiveColorSpace(); + if (!colorSpace.empty()) { - input->setColorSpace(shaderInput->getColorSpace()); + input->setColorSpace(colorSpace); } if (shaderInput->hasUnit()) {