Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,14 @@ be cheap to copy. When returning shared pointers, they should be returned by
value rather than by reference. Methods should be marked as `const` whenever
they do not modify the object's state.

#### Thread Safety

MaterialX classes support multiple concurrent readers, but not concurrent
reads and writes, following the pattern of standard C++ containers. This
design enables efficient parallel processing in read-heavy workloads such
as shader generation and scene traversal, while keeping the implementation
simple and avoiding the overhead of fine-grained locking.

#### Exception Handling

Exceptions should be used for exceptional conditions rather than for normal
Expand Down
188 changes: 103 additions & 85 deletions source/MaterialXCore/Document.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <MaterialXCore/Document.h>

#include <atomic>
#include <mutex>

MATERIALX_NAMESPACE_BEGIN
Expand All @@ -29,82 +30,120 @@ class Document::Cache
{
public:
Cache() :
valid(false)
_valid(false)
{
}
~Cache() = default;

void setDocument(weak_ptr<Document> document)
{
_doc = document;
invalidate();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if another thread tries to fetch cached data in the miniscule moment after _doc is set, but before invalidate() has had the time to set _valid to false?
I would recommend using a C++17 std::shared_mutex instead for robustness. It allows multiple-reader/single-writer semantics.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a really good suggestion, @JGamache-autodesk, and I've updated the PR to use this approach.

}

void invalidate()
{
_valid.store(false, std::memory_order_relaxed);
}

vector<PortElementPtr> getMatchingPorts(const string& nodeName)
{
refresh();
auto it = _portElementMap.find(nodeName);
return (it != _portElementMap.end()) ? it->second : vector<PortElementPtr>();
}

vector<NodeDefPtr> getMatchingNodeDefs(const string& nodeName)
{
refresh();
auto it = _nodeDefMap.find(nodeName);
return (it != _nodeDefMap.end()) ? it->second : vector<NodeDefPtr>();
}

vector<InterfaceElementPtr> getMatchingImplementations(const string& nodeDef)
{
refresh();
auto it = _implementationMap.find(nodeDef);
return (it != _implementationMap.end()) ? it->second : vector<InterfaceElementPtr>();
}

private:
void refresh()
{
// Thread synchronization for multiple concurrent readers of a single document.
std::lock_guard<std::mutex> guard(mutex);
// Perform a lock-free read and return if the cache is valid.
if (_valid.load(std::memory_order_acquire))
{
return;
}

if (!valid)
// Acquire a lock and double-check the valid flag for robustness.
std::lock_guard<std::mutex> lock(_mutex);
if (_valid.load(std::memory_order_relaxed))
{
// Clear the existing cache.
portElementMap.clear();
nodeDefMap.clear();
implementationMap.clear();
return;
}

// Traverse the document to build a new cache.
for (ElementPtr elem : doc.lock()->traverseTree())
{
const string& nodeName = elem->getAttribute(PortElement::NODE_NAME_ATTRIBUTE);
const string& nodeGraphName = elem->getAttribute(PortElement::NODE_GRAPH_ATTRIBUTE);
const string& nodeString = elem->getAttribute(NodeDef::NODE_ATTRIBUTE);
const string& nodeDefString = elem->getAttribute(InterfaceElement::NODE_DEF_ATTRIBUTE);
// Verify that the document is still valid.
auto doc = _doc.lock();
if (!doc)
{
return;
}

if (!nodeName.empty())
{
PortElementPtr portElem = elem->asA<PortElement>();
if (portElem)
{
portElementMap[portElem->getQualifiedName(nodeName)].push_back(portElem);
}
}
else
// Clear the existing cache.
_portElementMap.clear();
_nodeDefMap.clear();
_implementationMap.clear();

// Traverse the document to build a new cache.
for (ElementPtr elem : doc->traverseTree())
{
const string& nodeName = elem->getAttribute(PortElement::NODE_NAME_ATTRIBUTE);
const string& nodeGraphName = elem->getAttribute(PortElement::NODE_GRAPH_ATTRIBUTE);
const string& nodeString = elem->getAttribute(NodeDef::NODE_ATTRIBUTE);
const string& nodeDefString = elem->getAttribute(InterfaceElement::NODE_DEF_ATTRIBUTE);

const string& portKey = !nodeName.empty() ? nodeName : nodeGraphName;
if (!portKey.empty())
{
PortElementPtr portElem = elem->asA<PortElement>();
if (portElem)
{
if (!nodeGraphName.empty())
{
PortElementPtr portElem = elem->asA<PortElement>();
if (portElem)
{
portElementMap[portElem->getQualifiedName(nodeGraphName)].push_back(portElem);
}
}
_portElementMap[portElem->getQualifiedName(portKey)].push_back(portElem);
}
if (!nodeString.empty())
}
if (!nodeString.empty())
{
NodeDefPtr nodeDef = elem->asA<NodeDef>();
if (nodeDef)
{
NodeDefPtr nodeDef = elem->asA<NodeDef>();
if (nodeDef)
{
nodeDefMap[nodeDef->getQualifiedName(nodeString)].push_back(nodeDef);
}
_nodeDefMap[nodeDef->getQualifiedName(nodeString)].push_back(nodeDef);
}
if (!nodeDefString.empty())
}
if (!nodeDefString.empty())
{
InterfaceElementPtr interface = elem->asA<InterfaceElement>();
if (interface)
{
InterfaceElementPtr interface = elem->asA<InterfaceElement>();
if (interface)
if (interface->isA<Implementation>() || interface->isA<NodeGraph>())
{
if (interface->isA<Implementation>() || interface->isA<NodeGraph>())
{
implementationMap[interface->getQualifiedName(nodeDefString)].push_back(interface);
}
_implementationMap[interface->getQualifiedName(nodeDefString)].push_back(interface);
}
}
}

valid = true;
}

// Release semantics ensure all map writes are visible before valid becomes true.
_valid.store(true, std::memory_order_release);
}

public:
weak_ptr<Document> doc;
std::mutex mutex;
bool valid;
std::unordered_map<string, std::vector<PortElementPtr>> portElementMap;
std::unordered_map<string, std::vector<NodeDefPtr>> nodeDefMap;
std::unordered_map<string, std::vector<InterfaceElementPtr>> implementationMap;
private:
weak_ptr<Document> _doc;
std::mutex _mutex;
std::atomic<bool> _valid;
std::unordered_map<string, std::vector<PortElementPtr>> _portElementMap;
std::unordered_map<string, std::vector<NodeDefPtr>> _nodeDefMap;
std::unordered_map<string, std::vector<InterfaceElementPtr>> _implementationMap;
};

//
Expand All @@ -124,7 +163,7 @@ Document::~Document()
void Document::initialize()
{
_root = getSelf();
_cache->doc = getDocument();
_cache->setDocument(getDocument());

clearContent();
setVersionIntegers(MATERIALX_MAJOR_VERSION, MATERIALX_MINOR_VERSION);
Expand Down Expand Up @@ -284,18 +323,7 @@ std::pair<int, int> Document::getVersionIntegers() const

vector<PortElementPtr> Document::getMatchingPorts(const string& nodeName) const
{
// Refresh the cache.
_cache->refresh();

// Return all port elements matching the given node name.
if (_cache->portElementMap.count(nodeName))
{
return _cache->portElementMap.at(nodeName);
}
else
{
return vector<PortElementPtr>();
}
return _cache->getMatchingPorts(nodeName);
}

ValuePtr Document::getGeomPropValue(const string& geomPropName, const string& geom) const
Expand Down Expand Up @@ -342,19 +370,14 @@ vector<OutputPtr> Document::getMaterialOutputs() const
vector<NodeDefPtr> Document::getMatchingNodeDefs(const string& nodeName) const
{
// Recurse to data library if present.
vector<NodeDefPtr> matchingNodeDefs = hasDataLibrary() ?
vector<NodeDefPtr> matchingNodeDefs = hasDataLibrary() ?
getDataLibrary()->getMatchingNodeDefs(nodeName) :
vector<NodeDefPtr>();

// Refresh the cache.
_cache->refresh();
// Append all nodedefs matching the given node name.
vector<NodeDefPtr> localNodeDefs = _cache->getMatchingNodeDefs(nodeName);
matchingNodeDefs.insert(matchingNodeDefs.end(), localNodeDefs.begin(), localNodeDefs.end());

// Return all nodedefs matching the given node name.
if (_cache->nodeDefMap.count(nodeName))
{
matchingNodeDefs.insert(matchingNodeDefs.end(), _cache->nodeDefMap.at(nodeName).begin(), _cache->nodeDefMap.at(nodeName).end());
}

return matchingNodeDefs;
}

Expand All @@ -364,15 +387,10 @@ vector<InterfaceElementPtr> Document::getMatchingImplementations(const string& n
vector<InterfaceElementPtr> matchingImplementations = hasDataLibrary() ?
getDataLibrary()->getMatchingImplementations(nodeDef) :
vector<InterfaceElementPtr>();

// Refresh the cache.
_cache->refresh();

// Return all implementations matching the given nodedef string.
if (_cache->implementationMap.count(nodeDef))
{
matchingImplementations.insert(matchingImplementations.end(), _cache->implementationMap.at(nodeDef).begin(), _cache->implementationMap.at(nodeDef).end());
}
// Append all implementations matching the given nodedef string.
vector<InterfaceElementPtr> localImpls = _cache->getMatchingImplementations(nodeDef);
matchingImplementations.insert(matchingImplementations.end(), localImpls.begin(), localImpls.end());

return matchingImplementations;
}
Expand All @@ -388,7 +406,7 @@ bool Document::validate(string* message) const

void Document::invalidateCache()
{
_cache->valid = false;
_cache->invalidate();
}

//
Expand Down