diff --git a/GEMINI.md b/GEMINI.md index 57931f957f..eae9943099 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -26,3 +26,10 @@ This file provides a context for the XLS project. **Documentation:** * OSS docs are in `docs_src` and rendered with `mkdocs` at [https://google.github.io/xls/](https://google.github.io/xls/). + +**‼️ Agent Instructions ‼️** + +* **NodeMap/NodeSet Usage:** New code should generally use `NodeMap` or + `NodeSet` instead of `absl::flat_hash_map` or `absl::flat_hash_set` when the + key is `Node*` and the value is not trivially copyable. Existing code should + not be modified unless specifically directed to do so. diff --git a/xls/common/BUILD b/xls/common/BUILD index 835d488a55..e26ae38f25 100644 --- a/xls/common/BUILD +++ b/xls/common/BUILD @@ -640,6 +640,11 @@ cc_library( ], ) +cc_library( + name = "pointer_utils", + hdrs = ["pointer_utils.h"], +) + cc_test( name = "stopwatch_test", srcs = ["stopwatch_test.cc"], diff --git a/xls/common/pointer_utils.h b/xls/common/pointer_utils.h new file mode 100644 index 0000000000..6645d5019e --- /dev/null +++ b/xls/common/pointer_utils.h @@ -0,0 +1,38 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_COMMON_POINTER_UTILS_H_ +#define XLS_COMMON_POINTER_UTILS_H_ + +#include + +namespace xls { + +namespace internal { +using RawDeletePtr = void (*)(void*); +} + +// type-erased version of unique ptr that keeps track of the appropriate +// destructor. To use reinterpret_cast(x.get()). +using TypeErasedUniquePtr = std::unique_ptr; + +template > +TypeErasedUniquePtr EraseType(std::unique_ptr ptr) { + return TypeErasedUniquePtr( + ptr.release(), [](void* ptr) { Deleter()(reinterpret_cast(ptr)); }); +} + +} // namespace xls + +#endif // XLS_COMMON_POINTER_UTILS_H_ diff --git a/xls/data_structures/inline_bitmap.h b/xls/data_structures/inline_bitmap.h index c60f4c6ef8..c2c13efd32 100644 --- a/xls/data_structures/inline_bitmap.h +++ b/xls/data_structures/inline_bitmap.h @@ -39,6 +39,10 @@ class BitmapView; // A bitmap that has 64-bits of inline storage by default. class InlineBitmap { public: + // How many bits are held in one word. + static constexpr int64_t kWordBits = 64; + // How many bytes are held in one word. + static constexpr int64_t kWordBytes = 8; // Constructs an InlineBitmap of width `bit_count` using the bits in // `word`. If `bit_count` is greater than 64, then all high bits are set to // `fill`. @@ -345,9 +349,6 @@ class InlineBitmap { friend uint64_t GetWordBitsAtForTest(const InlineBitmap& ib, int64_t bit_offset); - static constexpr int64_t kWordBits = 64; - static constexpr int64_t kWordBytes = 8; - // Gets the kWordBits bits following bit_offset with 'Get(bit_offset)' being // the LSB, Get(bit_offset + 1) being the next lsb etc. int64_t GetWordBitsAt(int64_t bit_offset) const; diff --git a/xls/data_structures/transitive_closure.h b/xls/data_structures/transitive_closure.h index 7511e9bbb8..4ee557bdb6 100644 --- a/xls/data_structures/transitive_closure.h +++ b/xls/data_structures/transitive_closure.h @@ -16,6 +16,7 @@ #define XLS_DATA_STRUCTURES_TRANSITIVE_CLOSURE_H_ #include +#include #include #include "absl/container/flat_hash_map.h" @@ -25,6 +26,8 @@ namespace xls { +class Node; + namespace internal { // Compute the transitive closure of a relation. template @@ -97,6 +100,28 @@ class DenseIdRelation { absl::Span relation_; }; +template +class NodeRelation { + public: + explicit NodeRelation(NodeRelationBase& relation) : relation_(relation) {} + template + void ForEachKeyValue(F f) const { + for (auto& [j, from_j] : relation_) { + f(j, from_j); + } + } + bool Contains(const absl::flat_hash_set& vs, Node* const& v) const { + return vs.contains(v); + } + void UnionInPlace(absl::flat_hash_set& i, + const absl::flat_hash_set& k) const { + i.insert(k.begin(), k.end()); + } + + private: + NodeRelationBase& relation_; +}; + } // namespace internal template @@ -110,6 +135,19 @@ HashRelation TransitiveClosure(HashRelation v) { return v; } +// Compute the transitive closure of a relation represented as an explicit +// adjacency list using NodeMap. +// +// This has some terrible template stuff to ensure we don't actually need to +// include NodeMap. +template +NodeRelation TransitiveClosure(NodeRelation v) + requires(NodeRelation::kIsNodeMap) +{ + internal::TransitiveClosure(internal::NodeRelation(v)); + return v; +} + // TODO(allight): Using a more efficient bitmap format like croaring might give // a speedup here. using DenseIdRelation = absl::Span; diff --git a/xls/ir/BUILD b/xls/ir/BUILD index a4fabda9ed..a238add092 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -638,9 +638,11 @@ cc_library( "//xls/common:casts", "//xls/common:iterator_range", "//xls/common:math_util", + "//xls/common:pointer_utils", "//xls/common:visitor", "//xls/common/status:ret_check", "//xls/common/status:status_macros", + "//xls/data_structures:inline_bitmap", "//xls/data_structures:leaf_type_tree", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -2491,3 +2493,40 @@ cc_test( "@googletest//:gtest", ], ) + +cc_library( + name = "node_map", + hdrs = ["node_map.h"], + deps = [ + ":ir", + "//xls/common:pointer_utils", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:hash_container_defaults", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + ], +) + +cc_test( + name = "node_map_test", + srcs = ["node_map_test.cc"], + deps = [ + ":benchmark_support", + ":bits", + ":function_builder", + ":ir", + ":ir_matcher", + ":ir_test_base", + ":node_map", + "//xls/common:xls_gunit_main", + "//xls/common/status:matchers", + "//xls/common/status:status_macros", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@google_benchmark//:benchmark", + "@googletest//:gtest", + ], +) diff --git a/xls/ir/node.cc b/xls/ir/node.cc index da8e82add1..4ada070b3b 100644 --- a/xls/ir/node.cc +++ b/xls/ir/node.cc @@ -36,6 +36,7 @@ #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xls/common/casts.h" +#include "xls/common/pointer_utils.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/ir/change_listener.h" @@ -997,4 +998,28 @@ bool Node::OpIn(absl::Span choices) const { Package* Node::package() const { return function_base()->package(); } +std::optional Node::TakeUserData(int64_t idx) { + DCHECK(function_base()->package()->IsLiveUserDataId(idx)) << idx; + if (user_data_.size() <= idx) { + return std::nullopt; + } + std::optional data = std::move(user_data_[idx]); + user_data_[idx] = std::nullopt; + return data; +} +void* Node::GetUserData(int64_t idx) { + DCHECK(function_base()->package()->IsLiveUserDataId(idx)) << idx; + if (user_data_.size() <= idx) { + user_data_.resize(idx + 1); + } + const auto& v = user_data_[idx]; + return v ? v->get() : nullptr; +} +void Node::SetUserData(int64_t idx, TypeErasedUniquePtr data) { + DCHECK(function_base()->package()->IsLiveUserDataId(idx)) << idx; + if (user_data_.size() <= idx) { + user_data_.resize(idx + 1); + } + user_data_[idx] = std::move(data); +} } // namespace xls diff --git a/xls/ir/node.h b/xls/ir/node.h index 4b79928eff..b4c1273b84 100644 --- a/xls/ir/node.h +++ b/xls/ir/node.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "absl/container/inlined_vector.h" #include "absl/log/check.h" @@ -31,6 +32,7 @@ #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xls/common/casts.h" +#include "xls/common/pointer_utils.h" #include "xls/common/status/status_macros.h" #include "xls/ir/change_listener.h" #include "xls/ir/op.h" @@ -316,6 +318,28 @@ class Node { absl::Format(&sink, "%s", node.GetName()); } + // User-data access functions. Should not be directly used. Use NodeMap + // instead. + // + // Extreme care should be used when interacting with these functions and the + // Package ones since this is basically doing manual memory management. + + // Get the pointer associated with this indexes user data or nullptr if + // never set. Use HasUserData to see if anything has ever been set. + // + // idx must be a value returned by Package::AllocateNodeUserData which has not + // had ReleaseNodeUserDataId called on it. + void* GetUserData(int64_t idx); + // Sets user data at idx to 'data'. + void SetUserData(int64_t idx, TypeErasedUniquePtr data); + // Removes user data at idx from the node. Returns std::nullopt if nothing has + // been set. + std::optional TakeUserData(int64_t idx); + // Checks if anything has ever been set at the given user data. + bool HasUserData(int64_t idx) { + return user_data_.size() > idx && user_data_[idx].has_value(); + } + protected: // FunctionBase needs to be a friend to access RemoveUser for deleting nodes // from the graph. @@ -368,6 +392,15 @@ class Node { // Set of users sorted by node_id for stability. absl::InlinedVector users_; + + private: + std::vector> user_data_; + + // Clear all user data. + void ClearUserData() { user_data_.clear(); }; + + // for ClearUserData + friend class Package; }; inline NodeRef::NodeRef(Node* node) diff --git a/xls/ir/node_map.h b/xls/ir/node_map.h new file mode 100644 index 0000000000..9e73880afd --- /dev/null +++ b/xls/ir/node_map.h @@ -0,0 +1,629 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_IR_NODE_MAP_H_ +#define XLS_IR_NODE_MAP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/hash_container_defaults.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "xls/common/pointer_utils.h" +#include "xls/ir/block.h" // IWYU pragma: keep +#include "xls/ir/function.h" // IWYU pragma: keep +#include "xls/ir/node.h" +#include "xls/ir/package.h" +#include "xls/ir/proc.h" // IWYU pragma: keep + +namespace xls { + +using ForceAllowNodeMap = std::false_type; +namespace internal { +template +struct ErrorOnSlower { + static_assert( + !kIsLikelySlower::value, + "NodeMap is likely slower than absl::flat_hash_map for this Value type " + "because it is trivially copyable. To override declare node map as " + "NodeMap<{ValueT}, ForceAllowNodeMap>. Care should be taken to validate " + "that this is actually a performance win however."); +}; +}; // namespace internal + +// `xls::NodeMap` is a map-like interface for holding mappings from `xls::Node*` +// to `ValueT`. It is designed to be a partial drop-in replacement for +// `absl::flat_hash_map` but with better performance in XLS +// workloads. +// +// `NodeMap` achieves this performance by storing `ValueT` logically within the +// `Node` object itself as 'user-data'. This avoids hashing `Node*` and reduces +// cache misses compared to `absl::flat_hash_map`. A read of a value requires +// only 4 pointer reads. +// +// Notable Differences from `absl::flat_hash_map`: +// +// * All operations inherently perform pointer reads on any Node* typed values +// in key positions. This means that attempting to use deallocated Node*s as +// keys **in any way** (including just calling contains, etc) is UB. +// * The node-map has pointer stability of its values as well as iterator +// stability (except for iterators pointing to an entry which is removed +// either by a call to erase or by removing the node which is the entries +// key). +// * `reserve()` is not available as `NodeMap` does not require upfront storage +// allocation in the same way as `absl::flat_hash_map`. +// * Iteration order is from most-recently inserted to least-recently inserted. +// * All keys in a `NodeMap` must come from the same `Package`. +// * If a `Node` is deleted from its function/package, any associated data in +// any `NodeMap` is deallocated and it is removed from the map. +// * Each node with data in any `NodeMap` has an internal vector to hold +// user-data for all live maps. This extra space is not cleaned up until +// package destruction, based on the assumption that only a small number of +// maps will be simultaneously live. +// +// WARNING: This map is not thread safe. Also destruction of a node which has a +// value mapped to it is a modification of the map and needs to be taken into +// account if using this map in a multi-threaded context. +// +// NB This does a lot of very unsafe stuff internally to store the data using +// node user-data. +template > +class NodeMap : public internal::ErrorOnSlower { + private: + // Intrusive list node to hold the actual data allowing us to iterate. + struct DataHolder { + template + DataHolder(Node* n, Args&&... args) + : value(std::piecewise_construct, std::forward_as_tuple(n), + std::forward_as_tuple(std::forward(args)...)), + iter() {} + + ~DataHolder() { + // Remove itself from the list on deletion. + if (configured_list) { + configured_list->erase(iter); + } + } + + std::pair value; + // Intrusive list node to allow for iteration that's somewhat fast. + std::list::iterator iter; + std::list* configured_list = nullptr; + }; + + class ConstIterator; + class Iterator { + public: + using difference_type = ptrdiff_t; + using value_type = std::pair; + using reference = value_type&; + using pointer = value_type*; + using element_type = value_type; + using const_reference = const value_type&; + using iterator_category = std::forward_iterator_tag; + + Iterator() : iter_() {} + Iterator(std::list::iterator iter) : iter_(iter) {} + Iterator(const Iterator& other) : iter_(other.iter_) {} + Iterator& operator=(const Iterator& other) { + iter_ = other.iter_; + return *this; + } + Iterator& operator++() { + ++iter_; + return *this; + } + Iterator operator++(int) { + Iterator tmp = *this; + ++(*this); + return tmp; + } + reference operator*() const { return (*iter_)->value; } + pointer get() const { return &**this; } + pointer operator->() const { return &(*iter_)->value; } + friend bool operator==(const Iterator& a, const Iterator& b) { + return a.iter_ == b.iter_; + } + friend bool operator!=(const Iterator& a, const Iterator& b) { + return !(a == b); + } + + private: + std::list::iterator iter_; + friend class ConstIterator; + }; + class ConstIterator { + public: + using difference_type = ptrdiff_t; + using value_type = std::pair; + using reference = const value_type&; + using pointer = const value_type*; + using element_type = value_type; + using const_reference = const value_type&; + using iterator_category = std::forward_iterator_tag; + ConstIterator() : iter_() {} + ConstIterator(std::list::const_iterator iter) : iter_(iter) {} + ConstIterator(std::list::iterator iter) : iter_(iter) {} + ConstIterator(const ConstIterator& other) : iter_(other.iter_) {} + ConstIterator(const Iterator& other) : iter_(other.iter_) {} + ConstIterator& operator=(const ConstIterator& other) { + iter_ = other.iter_; + return *this; + } + ConstIterator& operator++() { + ++iter_; + return *this; + } + ConstIterator operator++(int) { + ConstIterator tmp = *this; + ++(*this); + return tmp; + } + reference operator*() const { return (*iter_)->value; } + pointer get() const { return &**this; } + pointer operator->() const { return &(*iter_)->value; } + friend bool operator==(const ConstIterator& a, const ConstIterator& b) { + return a.iter_ == b.iter_; + } + friend bool operator!=(const ConstIterator& a, const ConstIterator& b) { + return !(a == b); + } + + private: + std::list::const_iterator iter_; + }; + + public: + static constexpr bool kIsNodeMap = true; + using size_type = size_t; + using difference_type = ptrdiff_t; + using key_equal = absl::DefaultHashContainerEq; + using value_type = std::pair; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + using iterator = Iterator; + using const_iterator = ConstIterator; + + // Creates an empty `NodeMap` associated with the given `Package`. + explicit NodeMap(Package* pkg) + : pkg_(pkg), + id_(pkg->AllocateNodeUserDataId()), + values_(std::make_unique>()) {} + // Creates an empty `NodeMap` which will become associated with a `Package` + // upon first insertion. + NodeMap() + : pkg_(nullptr), + id_(-1), + values_(std::make_unique>()) {} + // Releases all data held by this map and informs the package that this + // map's user-data ID can be reused. + ~NodeMap() { + // Release all the data. + clear(); + } + // Copy constructor. + NodeMap(const NodeMap& other) + : pkg_(other.pkg_), + id_(pkg_ != nullptr ? pkg_->AllocateNodeUserDataId() : -1), + values_(std::make_unique>()) { + for (auto& [k, v] : other) { + this->insert(k, v); + } + } + // Copy assignment. + NodeMap& operator=(const NodeMap& other) { + if (HasPackage()) { + clear(); + } else { + pkg_ = other.pkg_; + id_ = pkg_ != nullptr ? pkg_->AllocateNodeUserDataId() : -1; + } + for (auto& [k, v] : other) { + this->insert(k, v); + } + return *this; + } + // Move constructor. + NodeMap(NodeMap&& other) { + pkg_ = other.pkg_; + id_ = other.id_; + values_ = std::move(other.values_); + other.id_ = -1; + other.pkg_ = nullptr; + } + // Move assignment. + NodeMap& operator=(NodeMap&& other) { + pkg_ = other.pkg_; + id_ = other.id_; + values_ = std::move(other.values_); + other.pkg_ = nullptr; + other.id_ = -1; + return *this; + } + // Range constructor. + template + NodeMap(It first, It last) + : pkg_(nullptr), + id_(-1), + values_(std::make_unique>()) { + for (auto it = first; it != last; ++it) { + insert(it->first, it->second); + } + } + + // Range constructor with explicit package. + template + NodeMap(Package* pkg, It first, It last) + : pkg_(pkg), + id_(pkg->AllocateNodeUserDataId()), + values_(std::make_unique>()) { + for (auto it = first; it != last; ++it) { + insert(it->first, it->second); + } + } + // Constructs a `NodeMap` from an `absl::flat_hash_map`. + NodeMap(const absl::flat_hash_map& other) + : pkg_(nullptr), + id_(-1), + values_(std::make_unique>()) { + for (auto& [k, v] : other) { + this->insert(k, v); + } + } + // Assigns contents from an `absl::flat_hash_map`. + NodeMap& operator=(const absl::flat_hash_map& other) { + clear(); + for (auto& [k, v] : other) { + this->insert(k, v); + } + } + // Initializer list constructor. + NodeMap(std::initializer_list init) + : pkg_(nullptr), + id_(-1), + values_(std::make_unique>()) { + for (const auto& pair : init) { + insert(pair.first, pair.second); + } + } + + // Returns true if the map has an associated package. + bool HasPackage() const { return pkg_ != nullptr; } + + // Returns true if the map contains no elements. + bool empty() const { + CheckValidId(); + return values_->empty(); + } + // Returns the number of elements in the map. + size_t size() const { + CheckValidId(); + return values_->size(); + } + + // Returns true if the map contains an element with key `n`. + bool contains(Node* n) const { + CheckValidId(n); + if (!HasPackage()) { + return false; + } + return n->HasUserData(id_); + } + // Returns 1 if the map contains an element with key `n`, 0 otherwise. + size_t count(Node* n) const { + CheckValidId(n); + if (!HasPackage()) { + return 0; + } + return n->HasUserData(id_) ? 1 : 0; + } + // Returns a reference to the value mapped to key `n`. If no such element + // exists, this function CHECK-fails. + ValueT& at(Node* n) { + EnsureValidId(n); + CHECK(contains(n)) << "Nothing was ever set for " << n; + return reinterpret_cast(n->GetUserData(id_))->value.second; + } + // Returns a reference to the value mapped to key `n`, inserting a + // default-constructed value if `n` is not already present. + ValueT& operator[](Node* n) { + EnsureValidId(n); + if (contains(n)) { + return at(n); + } + auto holder = std::make_unique(n); + DataHolder* holder_ptr = holder.get(); + n->SetUserData(id_, EraseType(std::move(holder))); + values_->push_front(holder_ptr); + holder_ptr->iter = values_->begin(); + holder_ptr->configured_list = values_.get(); + return holder_ptr->value.second; + } + // Returns a const reference to the value mapped to key `n`. If no such + // element exists, this function CHECK-fails. + const ValueT& operator[](Node* n) const { return at(n); } + // Returns a const reference to the value mapped to key `n`. If no such + // element exists, this function CHECK-fails. + const ValueT& at(Node* n) const { + CheckValidId(n); + CHECK(contains(n)) << "Nothing was ever set for " << n; + return reinterpret_cast(n->GetUserData(id_))->value.second; + } + + // Erases the element with key `n` if it exists. + void erase(Node* n) { + CheckValidId(n); + if (contains(n)) { + std::optional data = n->TakeUserData(id_); + DCHECK(data); + // The DataHolder could remove itself from the list when its destructor + // runs. It seems better to just be explicit + DataHolder* holder = reinterpret_cast(data->get()); + DCHECK(holder->configured_list == values_.get()); + holder->configured_list = nullptr; + values_->erase(holder->iter); + } + } + + // Erases the element pointed to by `it`. Returns an iterator to the + // element following the erased element. + const_iterator erase(const_iterator it) { + CheckValidId(it->first); + auto res = it; + ++res; + erase(it->first); + return res; + } + + // Erases the element pointed to by `it`. Returns an iterator to the + // element following the erased element. + iterator erase(iterator it) { + CheckValidId(it->first); + auto res = it; + ++res; + erase(it->first); + return res; + } + + // Removes all elements from the map. + ABSL_ATTRIBUTE_REINITIALIZES void clear() { + CheckValidId(); + if (pkg_ == nullptr) { + return; + } + for (DataHolder* v : *values_) { + std::optional data = + v->value.first->TakeUserData(id_); + // We can't remove the current iterator position. + reinterpret_cast(data->get())->configured_list = nullptr; + } + values_->clear(); + // Release the id. + pkg_->ReleaseNodeUserDataId(id_); + } + // Swaps the contents of this map with `other`. + void swap(NodeMap& other) { + std::swap(pkg_, other.pkg_); + std::swap(id_, other.id_); + values_.swap(other.values_); + } + + // Returns an iterator to the first element in the map. + iterator begin() { + CheckValidId(); + return Iterator(values_->begin()); + } + // Returns an iterator to the element following the last element in the map. + iterator end() { + CheckValidId(); + return Iterator(values_->end()); + } + // Returns a const iterator to the first element in the map. + const_iterator cbegin() const { + CheckValidId(); + return ConstIterator(values_->cbegin()); + } + // Returns a const iterator to the element following the last element in the + // map. + const_iterator cend() const { + CheckValidId(); + return ConstIterator(values_->cend()); + } + // Returns a const iterator to the first element in the map. + const_iterator begin() const { + CheckValidId(); + return cbegin(); + } + // Returns a const iterator to the element following the last element in the + // map. + const_iterator end() const { + CheckValidId(); + return cend(); + } + + // Finds an element with key `n`. + // Returns an iterator to the element if found, or `end()` otherwise. + iterator find(Node* n) { + CheckValidId(n); + if (!contains(n)) { + return end(); + } + return Iterator(reinterpret_cast(n->GetUserData(id_))->iter); + } + // Finds an element with key `n`. + // Returns a const iterator to the element if found, or `end()` otherwise. + const_iterator find(Node* n) const { + CheckValidId(n); + if (!contains(n)) { + return end(); + } + return ConstIterator( + reinterpret_cast(n->GetUserData(id_))->iter); + } + + // Inserts a key-value pair into the map if the key does not already exist. + // Returns a pair consisting of an iterator to the inserted element (or to + // the element that prevented the insertion) and a bool denoting whether + // the insertion took place. + std::pair insert(Node* n, ValueT value) { + EnsureValidId(n); + if (contains(n)) { + return std::make_pair(find(n), false); + } + auto holder = std::make_unique(n, std::move(value)); + DataHolder* holder_ptr = holder.get(); + n->SetUserData(id_, EraseType(std::move(holder))); + values_->push_front(holder_ptr); + holder_ptr->iter = values_->begin(); + holder_ptr->configured_list = values_.get(); + return std::make_pair(begin(), true); + } + + // Inserts a key-value pair into the map or assigns to the existing value if + // the key already exists. + // Returns a pair consisting of an iterator to the inserted element (or to + // the element that prevented the insertion) and a bool denoting whether + // the insertion took place. + std::pair insert_or_assign(Node* n, ValueT value) { + EnsureValidId(n); + if (contains(n)) { + Iterator f = find(n); + f->second = std::move(value); + return std::make_pair(f, false); + } + auto holder = std::make_unique(n, std::move(value)); + DataHolder* holder_ptr = holder.get(); + n->SetUserData(id_, EraseType(std::move(holder))); + values_->push_front(holder_ptr); + holder_ptr->iter = values_->begin(); + holder_ptr->configured_list = values_.get(); + return std::make_pair(begin(), true); + } + + // Inserts an element constructed in-place if the key does not already exist. + // Note: Unlike `try_emplace`, `emplace` may construct `ValueT` from `args` + // even if insertion does not occur. + // Returns a pair consisting of an iterator to the inserted element (or to + // the element that prevented the insertion) and a bool denoting whether + // the insertion took place. + template + std::pair emplace(Node* n, Args&&... args) { + EnsureValidId(n); + // If key already exists, construct elements but don't insert. + // This is to match std::map::emplace behavior where element construction + // might happen before check for duplication and value is discarded. + auto holder = std::make_unique(n, std::forward(args)...); + if (contains(n)) { + return std::make_pair(find(n), false); + } + DataHolder* holder_ptr = holder.get(); + n->SetUserData(id_, EraseType(std::move(holder))); + values_->push_front(holder_ptr); + holder_ptr->iter = values_->begin(); + holder_ptr->configured_list = values_.get(); + return std::make_pair(begin(), true); + } + + // Inserts an element constructed in-place if the key does not already exist. + // If the key already exists, no element is constructed. + // Returns a pair consisting of an iterator to the inserted element (or to + // the element that prevented the insertion) and a bool denoting whether + // the insertion took place. + template + std::pair try_emplace(Node* n, Args&&... args) { + EnsureValidId(n); + if (contains(n)) { + return std::make_pair(find(n), false); + } + auto holder = std::make_unique(n, std::forward(args)...); + DataHolder* holder_ptr = holder.get(); + n->SetUserData(id_, EraseType(std::move(holder))); + values_->push_front(holder_ptr); + holder_ptr->iter = values_->begin(); + holder_ptr->configured_list = values_.get(); + return std::make_pair(begin(), true); + } + + friend bool operator==(const Iterator& a, const ConstIterator& b) { + return a.iter_ == b.iter_; + } + friend bool operator==(const ConstIterator& a, const Iterator& b) { + return a.iter_ == b.iter_; + } + friend bool operator!=(const Iterator& a, const ConstIterator& b) { + return a.iter_ != b.iter_; + } + friend bool operator!=(const ConstIterator& a, const Iterator& b) { + return a.iter_ != b.iter_; + } + + private: + void CheckValid() const { +#ifdef DEBUG + CHECK(HasPackage()); + CheckValidId(); +#endif + } + void CheckValidId() const { +#ifdef DEBUG + if (pkg_ != nullptr) { + CHECK(pkg_->IsLiveUserDataId(id_)) << id_; + } +#endif + } + + // Check that this map has a valid id and correct package. + void CheckValidId(Node* n) const { +#ifdef DEBUG + CheckValidId(); + if (HasPackage()) { + CHECK_EQ(n->package(), pkg_) + << "Incorrect package for " << n << " got " << n->package()->name() + << " expected " << pkg_->name(); + } +#endif + } + // Force this map to have a user-data id if it doesn't already + void EnsureValidId(Node* n) { + if (!HasPackage()) { + pkg_ = n->package(); + DCHECK(pkg_ != nullptr) + << "Cannot add a node " << n << " without a package."; + id_ = pkg_->AllocateNodeUserDataId(); + } + CheckValidId(n); + } + + Package* pkg_; + int64_t id_; + std::unique_ptr> values_; +}; + +} // namespace xls + +#endif // XLS_IR_NODE_MAP_H_ diff --git a/xls/ir/node_map_test.cc b/xls/ir/node_map_test.cc new file mode 100644 index 0000000000..6580c0993f --- /dev/null +++ b/xls/ir/node_map_test.cc @@ -0,0 +1,422 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/ir/node_map.h" + +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xls/common/status/matchers.h" +#include "xls/common/status/status_macros.h" +#include "xls/ir/benchmark_support.h" +#include "xls/ir/bits.h" +#include "xls/ir/function_builder.h" +#include "xls/ir/ir_matcher.h" +#include "xls/ir/ir_test_base.h" +#include "xls/ir/package.h" + +namespace m = ::xls::op_matchers; + +namespace xls { +namespace { + +using testing::_; +using testing::Eq; +using testing::Pair; +using testing::UnorderedElementsAre; + +template +using TestNodeMap = NodeMap; + +struct EmplaceOnly { + inline static int constructions = 0; + int x; + explicit EmplaceOnly(int i) : x(i) { ++constructions; } + EmplaceOnly(const EmplaceOnly&) = delete; + EmplaceOnly& operator=(const EmplaceOnly&) = delete; + EmplaceOnly(EmplaceOnly&&) = default; + EmplaceOnly& operator=(EmplaceOnly&&) = default; + bool operator==(const EmplaceOnly& other) const { return x == other.x; } +}; + +class NodeMapTest : public IrTestBase {}; + +TEST_F(NodeMapTest, Basic) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + auto c = fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + // Set. + map[a.node()] = 1; + map[b.node()] = 2; + + EXPECT_THAT(map.at(a.node()), Eq(1)); + EXPECT_THAT(map.at(b.node()), Eq(2)); + EXPECT_TRUE(map.contains(a.node())); + EXPECT_TRUE(map.contains(b.node())); + EXPECT_FALSE(map.contains(c.node())); + EXPECT_EQ(map.count(a.node()), 1); + EXPECT_EQ(map.count(b.node()), 1); + EXPECT_EQ(map.count(c.node()), 0); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)))); + + // Update. + map[a.node()] += 5; + + EXPECT_THAT(map.at(a.node()), Eq(6)); + EXPECT_THAT(map.at(b.node()), Eq(2)); + EXPECT_TRUE(map.contains(a.node())); + EXPECT_TRUE(map.contains(b.node())); + EXPECT_FALSE(map.contains(c.node())); + EXPECT_EQ(map.count(a.node()), 1); + EXPECT_EQ(map.count(b.node()), 1); + EXPECT_EQ(map.count(c.node()), 0); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(6)), + Pair(m::Param("bar"), Eq(2)))); + // erase. + map.erase(a.node()); + + EXPECT_THAT(map.at(b.node()), Eq(2)); + EXPECT_TRUE(map.contains(b.node())); + EXPECT_FALSE(map.contains(c.node())); + EXPECT_FALSE(map.contains(a.node())); + EXPECT_EQ(map.count(a.node()), 0); + EXPECT_EQ(map.count(b.node()), 1); + EXPECT_EQ(map.count(c.node()), 0); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, Find) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + auto c = fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + // Set. + map[a.node()] = 1; + map[b.node()] = 2; + + EXPECT_THAT(map.find(c.node()), Eq(map.end())); + EXPECT_THAT(map.find(a.node()), + testing::Pointee(Pair(m::Param("foo"), Eq(1)))); + EXPECT_THAT(map.find(b.node()), + testing::Pointee(Pair(m::Param("bar"), Eq(2)))); + // Update with iterator. + map.find(a.node())->second = 33; + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(33)), + Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, Copy) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + { + TestNodeMap map1; + // Set. + map1[a.node()] = 1; + map1[b.node()] = 2; + map = map1; + map[a.node()] = 4; + map1[b.node()] = 6; + EXPECT_THAT(map1, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(6)))); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(4)), + Pair(m::Param("bar"), Eq(2)))); + } + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(4)), + Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, IterConstructor) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + std::vector> v = {{a.node(), 1}, {b.node(), 2}}; + TestNodeMap map(v.begin(), v.end()); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, Move) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + auto c = fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + + std::optional> opt_map; + { + TestNodeMap map1; + // Set. + map1[a.node()] = 1; + map1[b.node()] = 2; + opt_map.emplace(std::move(map1)); + EXPECT_FALSE(map1.HasPackage()); + } + + TestNodeMap map(*std::move(opt_map)); + + EXPECT_THAT(map.find(c.node()), Eq(map.end())); + EXPECT_THAT(map.find(a.node()), + testing::Pointee(Pair(m::Param("foo"), Eq(1)))); + EXPECT_THAT(map.find(b.node()), + testing::Pointee(Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, Insert) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + auto c = fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + // Set. + map[a.node()] = 1; + map[b.node()] = 2; + + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)))); + EXPECT_THAT(map.insert(c.node(), 3), Pair(_, true)); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)), + Pair(m::Add(), Eq(3)))); + + EXPECT_THAT(map.insert(a.node(), 3), Pair(map.find(a.node()), false)); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)), + Pair(m::Add(), Eq(3)))); + + EXPECT_THAT(map.insert_or_assign(a.node(), 3), + Pair(map.find(a.node()), false)); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(3)), + Pair(m::Param("bar"), Eq(2)), + Pair(m::Add(), Eq(3)))); + + map.erase(c.node()); + EXPECT_THAT(map.insert_or_assign(c.node(), 7), Pair(_, true)); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(3)), + Pair(m::Param("bar"), Eq(2)), + Pair(m::Add(), Eq(7)))); +} + +TEST_F(NodeMapTest, Emplace) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + EmplaceOnly::constructions = 0; + auto [it, inserted] = map.emplace(a.node(), 42); + EXPECT_TRUE(inserted); + EXPECT_TRUE(map.contains(a.node())); + EXPECT_EQ(map.at(a.node()).x, 42); + EXPECT_EQ(it->second.x, 42); + EXPECT_EQ(EmplaceOnly::constructions, 1); + // Emplace on existing fails but constructs argument. + auto [it2, inserted2] = map.emplace(a.node(), 44); + EXPECT_FALSE(inserted2); + EXPECT_EQ(map.at(a.node()).x, 42); + EXPECT_EQ(it2->second.x, 42); + EXPECT_EQ(it2, it); + EXPECT_EQ(EmplaceOnly::constructions, 2); +} + +TEST_F(NodeMapTest, TryEmplace) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + EmplaceOnly::constructions = 0; + auto [it, inserted] = map.try_emplace(a.node(), 42); + EXPECT_TRUE(inserted); + EXPECT_TRUE(map.contains(a.node())); + EXPECT_EQ(map.at(a.node()).x, 42); + EXPECT_EQ(it->second.x, 42); + EXPECT_EQ(EmplaceOnly::constructions, 1); + // try_emplace on existing fails and does not construct argument. + auto [it2, inserted2] = map.try_emplace(a.node(), 44); + EXPECT_FALSE(inserted2); + EXPECT_EQ(map.at(a.node()).x, 42); + EXPECT_EQ(it2->second.x, 42); + EXPECT_EQ(it2, it); + EXPECT_EQ(EmplaceOnly::constructions, 1); +} + +TEST_F(NodeMapTest, Swap) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map1; + TestNodeMap map2; + map1[a.node()] = 1; + map2[b.node()] = 2; + map1.swap(map2); + EXPECT_THAT(map1, UnorderedElementsAre(Pair(m::Param("bar"), Eq(2)))); + EXPECT_THAT(map2, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)))); +} + +TEST_F(NodeMapTest, InitializerList) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map{{a.node(), 1}, {b.node(), 2}}; + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, NodeDeletionRemovesMapElement) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + auto c = fb.Param("baz", p->GetBitsType(32)); + // Remove doesn't like getting rid of the last node. + fb.Param("other", p->GetBitsType(32)); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + TestNodeMap map{{a.node(), 1}, {b.node(), 2}, {c.node(), 3}}; + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)), + Pair(m::Param("baz"), Eq(3)))); + XLS_ASSERT_OK(f->RemoveNode(b.node())); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("baz"), Eq(3)))); + XLS_ASSERT_OK(f->RemoveNode(a.node())); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("baz"), Eq(3)))); + XLS_ASSERT_OK(f->RemoveNode(c.node())); + EXPECT_THAT(map, UnorderedElementsAre()); +} + +absl::Status GenerateFunction(Package* p, benchmark::State& state) { + FunctionBuilder fb("benchmark", p); + XLS_RETURN_IF_ERROR( + benchmark_support::GenerateChain( + fb, state.range(0), 2, benchmark_support::strategy::BinaryAdd(), + benchmark_support::strategy::SharedLiteral(UBits(32, 32))) + .status()); + return fb.Build().status(); +} + +template +void BM_ReadSome(benchmark::State& v, Setup setup) { + Package p("benchmark"); + XLS_ASSERT_OK(GenerateFunction(&p, v)); + Function* f = p.functions()[0].get(); + for (auto s : v) { + Map map; + setup(map); + int64_t i = 0; + for (Node* n : f->nodes()) { + if (i++ % 4 == 0) { + map[n].value = i; + } + } + for (int64_t i = 0; i < v.range(1); ++i) { + for (Node* n : f->nodes()) { + auto v = map.find(n); + if (v != map.end()) { + benchmark::DoNotOptimize(v->second.value); + } + benchmark::DoNotOptimize(v); + } + } + for (Node* n : f->nodes()) { + if (i++ % 3 == 0) { + map.erase(n); + } + if (i++ % 7 == 0) { + map[n].value = i; + } + } + for (int64_t i = 0; i < v.range(1); ++i) { + for (Node* n : f->nodes()) { + auto v = map.find(n); + if (v != map.end()) { + benchmark::DoNotOptimize(v->second.value); + } + benchmark::DoNotOptimize(v); + } + } + } +} + +// Simulate a typical xls map value which has real destructors etc. +struct TestValue { + int64_t value; + TestValue() ABSL_ATTRIBUTE_NOINLINE : value(12) {} + explicit TestValue(int64_t v) ABSL_ATTRIBUTE_NOINLINE : value(v) { + benchmark::DoNotOptimize(v); + } + TestValue(const TestValue& v) ABSL_ATTRIBUTE_NOINLINE : value(v.value) { + benchmark::DoNotOptimize(v); + } + ~TestValue() ABSL_ATTRIBUTE_NOINLINE { benchmark::DoNotOptimize(value); } + TestValue(TestValue&& v) ABSL_ATTRIBUTE_NOINLINE : value(v.value) { + benchmark::DoNotOptimize(v); + } +}; +void BM_ReadSomeNodeMap(benchmark::State& v) { + BM_ReadSome>(v, [](auto& a) {}); +} +void BM_ReadSomeFlatMap(benchmark::State& v) { + BM_ReadSome>(v, [](auto& a) {}); +} +void BM_ReadSomeFlatMapReserve(benchmark::State& v) { + BM_ReadSome>( + v, [&v](auto& map) { map.reserve(v.range(0)); }); +} +BENCHMARK(BM_ReadSomeNodeMap)->RangePair(100, 100000, 1, 100); +BENCHMARK(BM_ReadSomeFlatMap)->RangePair(100, 100000, 1, 100); +BENCHMARK(BM_ReadSomeFlatMapReserve)->RangePair(100, 100000, 1, 100); + +} // namespace +} // namespace xls diff --git a/xls/ir/package.cc b/xls/ir/package.cc index 3d665f4ab1..452415c582 100644 --- a/xls/ir/package.cc +++ b/xls/ir/package.cc @@ -29,6 +29,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -37,6 +38,7 @@ #include "absl/types/span.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" +#include "xls/data_structures/inline_bitmap.h" #include "xls/ir/block.h" #include "xls/ir/call_graph.h" #include "xls/ir/channel.h" @@ -56,7 +58,7 @@ namespace xls { -Package::Package(std::string_view name) : name_(name) {} +Package::Package(std::string_view name) : name_(name), user_data_ids_(64) {} Package::~Package() = default; @@ -950,6 +952,50 @@ TransformMetricsProto TransformMetrics::ToProto() const { return ret; } +namespace { +#ifdef NDEBUG +static constexpr bool kDebugMode = false; +#else +static constexpr bool kDebugMode = true; +#endif +} // namespace + +void Package::ReleaseNodeUserDataId(int64_t id) { + CHECK(user_data_ids_.Get(id)) << "id: " << id; + user_data_ids_.Set(id, false); + if constexpr (kDebugMode) { + for (FunctionBase* fb : GetFunctionBases()) { + for (Node* n : fb->nodes()) { + CHECK(!n->HasUserData(id)) + << "id: " << id << " node: " << n->ToString(); + } + } + } +} +int64_t Package::AllocateNodeUserDataId() { + if (user_data_ids_.IsAllOnes()) { + int64_t size = user_data_ids_.bit_count(); + user_data_ids_ = std::move(user_data_ids_).WithSize(size + 64); + LOG(WARNING) << "Excessive user data live use: " << (size + 1) + << " live data!"; + } + int64_t off = 0; + // Find first word with a false bit. + while (!(~user_data_ids_.GetWord(off / InlineBitmap::kWordBits))) { + off += InlineBitmap::kWordBits; + } + // Find the byte. + while (!(~user_data_ids_.GetByte(off / 8))) { + off += 8; + } + // Find the bit. + while (user_data_ids_.Get(off)) { + off++; + } + user_data_ids_.Set(off); + return off; +} + // Printers for fuzztest use. namespace { void WriteParseFunction(const Package& p, std::ostream* os) { diff --git a/xls/ir/package.h b/xls/ir/package.h index ddede3db2d..02a32e519a 100644 --- a/xls/ir/package.h +++ b/xls/ir/package.h @@ -16,11 +16,13 @@ #define XLS_IR_PACKAGE_H_ #include +#include #include #include #include #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -28,6 +30,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xls/data_structures/inline_bitmap.h" #include "xls/ir/channel.h" #include "xls/ir/channel_ops.h" #include "xls/ir/fileno.h" @@ -448,6 +451,21 @@ class Package { } TransformMetrics& transform_metrics() { return transform_metrics_; } + // Allocate a new user data id. This function will not reuse an id until + // ReleaseNodeUserDataId is called on it. + int64_t AllocateNodeUserDataId(); + + // Releases the user data id and allows it to be reused. + // + // NB This must be called once for each value returned by + // AllocateNodeUserDataId. + // + // When this is called all nodes with user data *MUST* have *already* had + // TakeUserData called on them to delete the user data associated with them. + // On DEBUG builds this is CHECKed. + void ReleaseNodeUserDataId(int64_t id); + bool IsLiveUserDataId(int64_t id) { return user_data_ids_.Get(id); } + private: std::vector GetChannelNames() const; @@ -493,6 +511,9 @@ class Package { // Metrics which record the total number of transformations to the package. TransformMetrics transform_metrics_ = {0}; + + // Bitmap containing allocated user data ids. + InlineBitmap user_data_ids_; }; // Printers for fuzztest use. diff --git a/xls/passes/BUILD b/xls/passes/BUILD index a14d9db3a7..e5f8985a17 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -948,6 +948,7 @@ cc_library( "//xls/ir:interval", "//xls/ir:interval_ops", "//xls/ir:interval_set", + "//xls/ir:node_map", "//xls/ir:op", "//xls/ir:ternary", "//xls/ir:type", @@ -1306,6 +1307,7 @@ cc_library( "//xls/common/status:status_macros", "//xls/data_structures:leaf_type_tree", "//xls/ir", + "//xls/ir:node_map", "//xls/ir:op", "//xls/ir:type", "@com_google_absl//absl/algorithm:container", @@ -1339,6 +1341,7 @@ cc_library( "//xls/common/status:status_macros", "//xls/data_structures:inline_bitmap", "//xls/ir", + "//xls/ir:node_map", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1523,13 +1526,13 @@ cc_library( "//xls/ir:bits", "//xls/ir:interval", "//xls/ir:interval_ops", + "//xls/ir:node_map", "//xls/ir:op", "//xls/ir:ternary", "//xls/ir:type", "//xls/ir:value", "//xls/ir:value_utils", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -1597,6 +1600,7 @@ cc_library( "//xls/ir:interval", "//xls/ir:interval_ops", "//xls/ir:interval_set", + "//xls/ir:node_map", "//xls/ir:op", "//xls/ir:partial_info", "//xls/ir:partial_ops", @@ -1605,7 +1609,6 @@ cc_library( "//xls/ir:value", "//xls/ir:value_utils", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -2450,8 +2453,8 @@ cc_library( "//xls/ir", "//xls/ir:bits", "//xls/ir:bits_ops", + "//xls/ir:node_map", "//xls/ir:type", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", @@ -3884,9 +3887,9 @@ cc_library( "//xls/common/status:status_macros", "//xls/data_structures:leaf_type_tree", "//xls/ir", + "//xls/ir:node_map", "//xls/ir:type", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -4169,6 +4172,7 @@ cc_library( "//xls/ir:interval", "//xls/ir:interval_ops", "//xls/ir:interval_set", + "//xls/ir:node_map", "//xls/ir:op", "//xls/ir:state_element", "//xls/ir:ternary", @@ -4177,7 +4181,6 @@ cc_library( "//xls/ir:value_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -4355,12 +4358,12 @@ cc_library( "//xls/ir:bits", "//xls/ir:interval_ops", "//xls/ir:interval_set", + "//xls/ir:node_map", "//xls/ir:ternary", "//xls/ir:type", "//xls/ir:value", "//xls/ir:value_utils", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", diff --git a/xls/passes/bit_count_query_engine.cc b/xls/passes/bit_count_query_engine.cc index 71c244b9dc..1265d25860 100644 --- a/xls/passes/bit_count_query_engine.cc +++ b/xls/passes/bit_count_query_engine.cc @@ -22,7 +22,6 @@ #include #include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/types/span.h" @@ -36,6 +35,7 @@ #include "xls/ir/interval_ops.h" #include "xls/ir/interval_set.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" #include "xls/ir/nodes.h" #include "xls/ir/ternary.h" #include "xls/ir/type.h" @@ -529,10 +529,9 @@ LeafTypeTree BitCountQueryEngine::ComputeInfo( CHECK_OK(vis.InjectValue(operand, operand_info)); } CHECK_OK(node->VisitSingleNode(&vis)); - absl::flat_hash_map< - Node*, std::unique_ptr>> - result = std::move(vis).ToStoredValues(); - return std::move(*result.at(node)).ToOwned(); + NodeMap> result = + std::move(vis).ToStoredValues(); + return std::move(result.at(node)).ToOwned(); } absl::Status BitCountQueryEngine::MergeWithGiven( diff --git a/xls/passes/bit_provenance_analysis.h b/xls/passes/bit_provenance_analysis.h index 1e5c09b2c2..4b959badc9 100644 --- a/xls/passes/bit_provenance_analysis.h +++ b/xls/passes/bit_provenance_analysis.h @@ -20,7 +20,6 @@ #include #include -#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" @@ -28,6 +27,7 @@ #include "xls/data_structures/leaf_type_tree.h" #include "xls/ir/function_base.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" #include "xls/passes/query_engine.h" namespace xls { @@ -158,18 +158,15 @@ class BitProvenanceAnalysis { // Get all the sources for a given node. LeafTypeTreeView GetBitSources(Node* n) const { CHECK(IsTracked(n)) << n; - return sources_.at(n)->AsView(); + return sources_.at(n).AsView(); } private: explicit BitProvenanceAnalysis( - absl::flat_hash_map< - Node*, std::unique_ptr>>&& sources) + NodeMap>&& sources) : sources_(std::move(sources)) {} // Map from a node to the nodes which are the source of each of its bits. - absl::flat_hash_map>> - sources_; + NodeMap> sources_; }; } // namespace xls diff --git a/xls/passes/context_sensitive_range_query_engine.cc b/xls/passes/context_sensitive_range_query_engine.cc index 5064367a8a..a2e0566245 100644 --- a/xls/passes/context_sensitive_range_query_engine.cc +++ b/xls/passes/context_sensitive_range_query_engine.cc @@ -43,6 +43,7 @@ #include "xls/ir/interval_ops.h" #include "xls/ir/interval_set.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" #include "xls/ir/nodes.h" #include "xls/ir/op.h" #include "xls/ir/ternary.h" diff --git a/xls/passes/dataflow_visitor.h b/xls/passes/dataflow_visitor.h index 0de2f2018f..7148abe33a 100644 --- a/xls/passes/dataflow_visitor.h +++ b/xls/passes/dataflow_visitor.h @@ -22,7 +22,6 @@ #include #include -#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/types/span.h" @@ -33,6 +32,7 @@ #include "xls/ir/bits_ops.h" #include "xls/ir/dfs_visitor.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" #include "xls/ir/nodes.h" #include "xls/ir/type.h" #include "xls/passes/stateless_query_engine.h" @@ -322,16 +322,15 @@ class DataflowVisitor : public DfsVisitorWithDefault { absl::Status HandleTupleIndex(TupleIndex* tuple_index) override { return SetValue( tuple_index, - map_.at(tuple_index->operand(0))->AsView({tuple_index->index()})); + map_.at(tuple_index->operand(0)).AsView({tuple_index->index()})); } // Returns the leaf type tree value associated with `node`. LeafTypeTreeView GetValue(Node* node) const { - return map_.at(node)->AsView(); + return map_.at(node).AsView(); } - absl::flat_hash_map>> - ToStoredValues() && { + NodeMap> ToStoredValues() && { return std::move(map_); } @@ -386,20 +385,17 @@ class DataflowVisitor : public DfsVisitorWithDefault { // Sets the leaf type tree value associated with `node`. absl::Status SetValue(Node* node, LeafTypeTreeView value) { XLS_RET_CHECK_EQ(node->GetType(), value.type()); - map_.insert_or_assign( - node, std::make_unique>(value.AsShared())); + map_.insert_or_assign(node, value.AsShared()); return absl::OkStatus(); } absl::Status SetValue(Node* node, LeafTypeTree&& value) { XLS_RET_CHECK_EQ(node->GetType(), value.type()); - map_.insert_or_assign(node, std::make_unique>( - std::move(value).AsShared())); + map_.insert_or_assign(node, std::move(value).AsShared()); return absl::OkStatus(); } absl::Status SetValue(Node* node, SharedLeafTypeTree&& value) { XLS_RET_CHECK_EQ(node->GetType(), value.type()); - map_.insert_or_assign( - node, std::make_unique>(std::move(value))); + map_.insert_or_assign(node, std::move(value)); return absl::OkStatus(); } @@ -519,7 +515,7 @@ class DataflowVisitor : public DfsVisitorWithDefault { // Storage for the leaf type tree values associated with each node; must be // pointer-stable so that values can be shared (by populating some values as // Views of others). - absl::flat_hash_map>> map_; + NodeMap> map_; }; } // namespace xls diff --git a/xls/passes/lazy_ternary_query_engine.cc b/xls/passes/lazy_ternary_query_engine.cc index 9ee5682757..7aa5dfb98c 100644 --- a/xls/passes/lazy_ternary_query_engine.cc +++ b/xls/passes/lazy_ternary_query_engine.cc @@ -21,7 +21,6 @@ #include #include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -35,6 +34,7 @@ #include "xls/ir/interval_ops.h" #include "xls/ir/lsb_or_msb.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" #include "xls/ir/nodes.h" #include "xls/ir/op.h" #include "xls/ir/ternary.h" @@ -674,9 +674,8 @@ TernaryTree LazyTernaryQueryEngine::ComputeInfo( } CHECK_OK(node->VisitSingleNode(&visitor)); } - absl::flat_hash_map> result = - std::move(visitor).ToStoredValues(); - return std::move(*result.at(node)).ToOwned(); + NodeMap result = std::move(visitor).ToStoredValues(); + return std::move(result.at(node)).ToOwned(); } absl::Status LazyTernaryQueryEngine::MergeWithGiven( diff --git a/xls/passes/node_dependency_analysis.cc b/xls/passes/node_dependency_analysis.cc index 490eab4ba2..1344f6aa56 100644 --- a/xls/passes/node_dependency_analysis.cc +++ b/xls/passes/node_dependency_analysis.cc @@ -28,22 +28,23 @@ #include "xls/data_structures/inline_bitmap.h" #include "xls/ir/function_base.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" #include "xls/ir/topo_sort.h" namespace xls { namespace { +using NodeSet = absl::flat_hash_set; + // Perform actual analysis. // f is the function to analyze. We only care about getting results for // 'interesting_nodes' if the set is non-empty (otherwise all nodes are // searched). Succs returns the nodes that depend on the argument. iter is the // iterator to walk the function in the topological order defined by preds. template -std::tuple, - absl::flat_hash_map> -AnalyzeDependents(FunctionBase* f, - const absl::flat_hash_set& interesting_nodes, +std::tuple, absl::flat_hash_map> +AnalyzeDependents(FunctionBase* f, const NodeSet& interesting_nodes, Successors succs, absl::Span topo_sort) { absl::flat_hash_map node_ids; node_ids.reserve(f->node_count()); @@ -67,8 +68,7 @@ AnalyzeDependents(FunctionBase* f, VLOG(3) << "Analyzing dependents of " << f->node_count() << " nodes with " << interesting_nodes.size() << " interesting."; int64_t bitmap_size = f->node_count(); - absl::flat_hash_map results; - results.reserve(f->node_count()); + NodeMap results; for (Node* n : topo_sort) { auto [it, inserted] = results.try_emplace(n, bitmap_size); InlineBitmap& bm = it->second; @@ -86,8 +86,13 @@ AnalyzeDependents(FunctionBase* f, } } // To avoid any bugs delete everything that's not specifically requested. - absl::erase_if(results, - [&](auto& pair) { return !is_interesting(pair.first); }); + for (auto it = results.begin(); it != results.end();) { + if (!is_interesting(it->first)) { + it = results.erase(it); + } else { + ++it; + } + } return {results, node_ids}; } @@ -103,7 +108,7 @@ absl::StatusOr NodeDependencyAnalysis::GetDependents( NodeDependencyAnalysis NodeDependencyAnalysis::BackwardDependents( FunctionBase* fb, absl::Span nodes) { - absl::flat_hash_set interesting(nodes.begin(), nodes.end()); + NodeSet interesting(nodes.begin(), nodes.end()); auto [dependents, node_ids] = AnalyzeDependents( fb, interesting, /*succs=*/[](Node* node) { return node->users(); }, TopoSort(fb)); @@ -112,7 +117,7 @@ NodeDependencyAnalysis NodeDependencyAnalysis::BackwardDependents( NodeDependencyAnalysis NodeDependencyAnalysis::ForwardDependents( FunctionBase* fb, absl::Span nodes) { - absl::flat_hash_set interesting(nodes.begin(), nodes.end()); + NodeSet interesting(nodes.begin(), nodes.end()); auto [dependents, node_ids] = AnalyzeDependents( fb, interesting, /*succs=*/[](Node* node) { return node->operands(); }, ReverseTopoSort(fb)); diff --git a/xls/passes/node_dependency_analysis.h b/xls/passes/node_dependency_analysis.h index 5a59addfe3..b747ff2a13 100644 --- a/xls/passes/node_dependency_analysis.h +++ b/xls/passes/node_dependency_analysis.h @@ -27,6 +27,7 @@ #include "xls/data_structures/inline_bitmap.h" #include "xls/ir/function_base.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" namespace xls { @@ -107,15 +108,14 @@ class NodeDependencyAnalysis { } private: - NodeDependencyAnalysis(bool is_forwards, - absl::flat_hash_map dependents, + NodeDependencyAnalysis(bool is_forwards, NodeMap dependents, absl::flat_hash_map node_ids) : is_forward_(is_forwards), dependents_(std::move(dependents)), node_indices_(std::move(node_ids)) {} bool is_forward_; - absl::flat_hash_map dependents_; + NodeMap dependents_; absl::flat_hash_map node_indices_; }; diff --git a/xls/passes/partial_info_query_engine.cc b/xls/passes/partial_info_query_engine.cc index ebd6a2ffd4..fbf3d21b3b 100644 --- a/xls/passes/partial_info_query_engine.cc +++ b/xls/passes/partial_info_query_engine.cc @@ -23,7 +23,6 @@ #include #include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -40,6 +39,7 @@ #include "xls/ir/interval_set.h" #include "xls/ir/lsb_or_msb.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" #include "xls/ir/nodes.h" #include "xls/ir/op.h" #include "xls/ir/partial_information.h" @@ -740,10 +740,9 @@ LeafTypeTree PartialInfoQueryEngine::ComputeInfo( } CHECK_OK(node->VisitSingleNode(&visitor)); } - absl::flat_hash_map>> - result = std::move(visitor).ToStoredValues(); - return std::move(*result.at(node)).ToOwned(); + NodeMap> result = + std::move(visitor).ToStoredValues(); + return std::move(result.at(node)).ToOwned(); } absl::Status PartialInfoQueryEngine::MergeWithGiven( diff --git a/xls/passes/proc_state_range_query_engine.cc b/xls/passes/proc_state_range_query_engine.cc index 5631a23b64..5964023484 100644 --- a/xls/passes/proc_state_range_query_engine.cc +++ b/xls/passes/proc_state_range_query_engine.cc @@ -22,7 +22,6 @@ #include #include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -42,6 +41,7 @@ #include "xls/ir/interval_ops.h" #include "xls/ir/interval_set.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" #include "xls/ir/nodes.h" #include "xls/ir/op.h" #include "xls/ir/proc.h" @@ -64,7 +64,7 @@ namespace { class ProcStateGivens : public RangeDataProvider, public TernaryDataProvider { public: - ProcStateGivens(Proc* proc, absl::flat_hash_map intervals) + ProcStateGivens(Proc* proc, NodeMap intervals) : proc_(proc), intervals_(std::move(intervals)) {} absl::Status IterateFunction(DfsVisitor* visitor) override { return proc_->Accept(visitor); @@ -93,7 +93,7 @@ class ProcStateGivens : public RangeDataProvider, public TernaryDataProvider { private: Proc* proc_; - absl::flat_hash_map intervals_; + NodeMap intervals_; }; // A givens that restricts the iteration to only values that hit the proc-state @@ -101,8 +101,7 @@ class ProcStateGivens : public RangeDataProvider, public TernaryDataProvider { class ProcStateEvolutionGivens : public ProcStateGivens { public: ProcStateEvolutionGivens(absl::Span reverse_topo_sort, - Node* target, - absl::flat_hash_map intervals, + Node* target, NodeMap intervals, const DependencyBitmap& interesting_nodes) : ProcStateGivens(target->function_base()->AsProcOrDie(), std::move(intervals)), @@ -141,7 +140,7 @@ ExtractContextSensitiveRange( const NodeDependencyAnalysis& next_dependent_information) { Node* pred = *next->predicate(); XLS_ASSIGN_OR_RETURN( - (absl::flat_hash_map results), + NodeMap results, PropagateGivensBackwards(rqe, proc, {{pred, IntervalSet::Precise(UBits(1, 1))}}, reverse_topo_sort)); @@ -288,9 +287,7 @@ class ConstantValueIrInterpreter // it's hard to imagine many procs with more than a handful of constant set // values which are still narrowable. static constexpr int64_t kSegmentLimit = 8; - const absl::flat_hash_map< - Node*, std::unique_ptr>>>& - values() const { + const NodeMap>>& values() const { return map_; } absl::Status DefaultHandler(Node* n) override { @@ -519,7 +516,7 @@ absl::StatusOr> FindConstantUpdateValues( continue; } LeafTypeTreeView> v = - values.at(n->value())->AsView(); + values.at(n->value()).AsView(); XLS_RET_CHECK(v.type()->IsBits()); for (const Bits& b : v.Get({})) { param_values.insert(b); @@ -855,8 +852,7 @@ absl::StatusOr ProcStateRangeQueryEngine ::Populate( TernaryQueryEngine spec_ternary; RangeQueryEngine spec_range; - absl::flat_hash_map state_read_intervals; - state_read_intervals.reserve(final_range_data.size()); + NodeMap state_read_intervals; for (const auto& [state_element, range] : final_range_data) { state_read_intervals[proc->GetStateRead(state_element)] = range.interval_set.Get({}); diff --git a/xls/passes/token_dependency_pass.cc b/xls/passes/token_dependency_pass.cc index 168fe6c580..4de4eb3372 100644 --- a/xls/passes/token_dependency_pass.cc +++ b/xls/passes/token_dependency_pass.cc @@ -67,7 +67,7 @@ absl::StatusOr TokenDependencyPass::RunOnFunctionBaseInternal( continue; } for (const absl::flat_hash_set& sources : - provenance.at(child)->elements()) { + provenance.at(child).elements()) { for (Node* element : sources) { token_deps[element].insert(node); } diff --git a/xls/passes/token_provenance_analysis.cc b/xls/passes/token_provenance_analysis.cc index d0332f1c78..754bc8b442 100644 --- a/xls/passes/token_provenance_analysis.cc +++ b/xls/passes/token_provenance_analysis.cc @@ -32,6 +32,7 @@ #include "xls/data_structures/leaf_type_tree.h" #include "xls/ir/function_base.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" #include "xls/ir/op.h" #include "xls/ir/type.h" #include "xls/passes/dataflow_visitor.h" @@ -159,7 +160,7 @@ std::string ToString(const TokenProvenance& provenance) { } lines.push_back(absl::StrFormat( " %s : {%s}", node->GetName(), - provenance.at(node)->ToString( + provenance.at(node).ToString( [](const absl::flat_hash_set& sources) { std::vector sorted_sources(sources.begin(), sources.end()); absl::c_sort(sorted_sources, Node::NodeIdLessThan()); @@ -179,7 +180,7 @@ absl::StatusOr ComputeTokenDAG(FunctionBase* f) { for (Node* operand : node->operands()) { if (operand->GetType()->IsToken()) { const absl::flat_hash_set& child = - provenance.at(operand)->Get({}); + provenance.at(operand).Get({}); dag[node].insert(child.cbegin(), child.cend()); } } diff --git a/xls/passes/token_provenance_analysis.h b/xls/passes/token_provenance_analysis.h index 4e1cfc18d9..0ee69c481a 100644 --- a/xls/passes/token_provenance_analysis.h +++ b/xls/passes/token_provenance_analysis.h @@ -20,18 +20,17 @@ #include #include "absl/container/btree_set.h" -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "xls/data_structures/leaf_type_tree.h" #include "xls/ir/function_base.h" #include "xls/ir/node.h" +#include "xls/ir/node_map.h" namespace xls { -using TokenProvenance = absl::flat_hash_map< - Node*, std::unique_ptr>>>; +using TokenProvenance = NodeMap>>; // Compute, for each token-type in the given `FunctionBase*`, what // side-effecting node(s) contributed to that token. If a leaf type in one of @@ -45,7 +44,7 @@ std::string ToString(const TokenProvenance& provenance); // nodes (/ AfterAll) that their token inputs immediately came from. Note that // this skips over intermediate movement of tokens through tuples, `identity`, // or selects. -using TokenDAG = absl::flat_hash_map>; +using TokenDAG = NodeMap>; // Compute the immediate preceding side-effecting nodes (including proc token // param and `after_all`s) for each side-effecting node (and after_all). Note diff --git a/xls/passes/token_provenance_analysis_test.cc b/xls/passes/token_provenance_analysis_test.cc index 733ef70354..a36034eb10 100644 --- a/xls/passes/token_provenance_analysis_test.cc +++ b/xls/passes/token_provenance_analysis_test.cc @@ -75,25 +75,25 @@ TEST_F(TokenProvenanceAnalysisTest, Simple) { XLS_ASSERT_OK_AND_ASSIGN(TokenProvenance provenance, TokenProvenanceAnalysis(proc)); - EXPECT_THAT(provenance.at(token.node())->Get({}), + EXPECT_THAT(provenance.at(token.node()).Get({}), UnorderedElementsAre(token.node())); - EXPECT_THAT(provenance.at(recv.node())->Get({0}), + EXPECT_THAT(provenance.at(recv.node()).Get({0}), UnorderedElementsAre(recv.node())); - EXPECT_THAT(provenance.at(recv.node())->Get({1}), IsEmpty()); - EXPECT_THAT(provenance.at(tuple.node())->Get({0}), + EXPECT_THAT(provenance.at(recv.node()).Get({1}), IsEmpty()); + EXPECT_THAT(provenance.at(tuple.node()).Get({0}), UnorderedElementsAre(recv.node())); - EXPECT_THAT(provenance.at(tuple.node())->Get({1}), IsEmpty()); - EXPECT_THAT(provenance.at(tuple.node())->Get({2, 0}), IsEmpty()); - EXPECT_THAT(provenance.at(tuple.node())->Get({2, 1}), IsEmpty()); - EXPECT_THAT(provenance.at(tuple.node())->Get({3, 0}), + EXPECT_THAT(provenance.at(tuple.node()).Get({1}), IsEmpty()); + EXPECT_THAT(provenance.at(tuple.node()).Get({2, 0}), IsEmpty()); + EXPECT_THAT(provenance.at(tuple.node()).Get({2, 1}), IsEmpty()); + EXPECT_THAT(provenance.at(tuple.node()).Get({3, 0}), UnorderedElementsAre(t2.node())); - EXPECT_THAT(provenance.at(t3.node())->Get({}), + EXPECT_THAT(provenance.at(t3.node()).Get({}), UnorderedElementsAre(t3.node())); - EXPECT_THAT(provenance.at(t4.node())->Get({}), + EXPECT_THAT(provenance.at(t4.node()).Get({}), UnorderedElementsAre(t4.node())); - EXPECT_THAT(provenance.at(t5.node())->Get({}), + EXPECT_THAT(provenance.at(t5.node()).Get({}), UnorderedElementsAre(t5.node())); - EXPECT_THAT(provenance.at(t6.node())->Get({}), + EXPECT_THAT(provenance.at(t6.node()).Get({}), UnorderedElementsAre(t6.node())); } @@ -112,7 +112,7 @@ TEST_F(TokenProvenanceAnalysisTest, VeryLongChain) { // The proc only consists of a token param and token-typed identity // operations. for (Node* node : proc->nodes()) { - EXPECT_THAT(provenance.at(node)->Get({}), + EXPECT_THAT(provenance.at(node).Get({}), UnorderedElementsAre(token.node())); } } @@ -257,12 +257,12 @@ TEST_F(TokenProvenanceAnalysisTest, SelectOfTokens) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({select, select, selector})); XLS_ASSERT_OK_AND_ASSIGN(TokenProvenance provenance, TokenProvenanceAnalysis(proc)); - EXPECT_THAT(provenance.at(token1.node())->Get({}), + EXPECT_THAT(provenance.at(token1.node()).Get({}), UnorderedElementsAre(token1.node())); - EXPECT_THAT(provenance.at(token2.node())->Get({}), + EXPECT_THAT(provenance.at(token2.node()).Get({}), UnorderedElementsAre(token2.node())); - EXPECT_THAT(provenance.at(selector.node())->Get({}), IsEmpty()); - EXPECT_THAT(provenance.at(select.node())->Get({}), + EXPECT_THAT(provenance.at(selector.node()).Get({}), IsEmpty()); + EXPECT_THAT(provenance.at(select.node()).Get({}), UnorderedElementsAre(token1.node(), token2.node())); } diff --git a/xls/scheduling/mutual_exclusion_pass.cc b/xls/scheduling/mutual_exclusion_pass.cc index 4b4a97e32d..d2e4934f14 100644 --- a/xls/scheduling/mutual_exclusion_pass.cc +++ b/xls/scheduling/mutual_exclusion_pass.cc @@ -245,7 +245,7 @@ std::optional GetPredicate(Node* node) { return predicate; }; -using NodeRelation = absl::flat_hash_map>; +using NodeRelation = NodeMap>; // Find all nodes that the given node transitively depends on. absl::flat_hash_set DependenciesOf(Node* root) { @@ -326,7 +326,7 @@ absl::StatusOr ComputeMergableEffects(FunctionBase* f) { } NodeRelation result; - NodeRelation transitive_closure = TransitiveClosure(token_dag); + NodeRelation transitive_closure = TransitiveClosure(token_dag); for (Node* node : ReverseTopoSort(f)) { if (node->Is() || node->Is()) { std::string_view channel_name = GetChannelName(node); diff --git a/xls/visualization/ir_viz/BUILD b/xls/visualization/ir_viz/BUILD index 105fd37a40..88355b95b1 100644 --- a/xls/visualization/ir_viz/BUILD +++ b/xls/visualization/ir_viz/BUILD @@ -129,6 +129,7 @@ cc_library( "//xls/estimators/delay_model:analyze_critical_path", "//xls/estimators/delay_model:delay_estimator", "//xls/ir", + "//xls/ir:node_map", "//xls/ir:op", "//xls/passes:bdd_query_engine", "//xls/passes:partial_info_query_engine", diff --git a/xls/visualization/ir_viz/ir_to_proto.cc b/xls/visualization/ir_viz/ir_to_proto.cc index e6d62c9245..33fa5da622 100644 --- a/xls/visualization/ir_viz/ir_to_proto.cc +++ b/xls/visualization/ir_viz/ir_to_proto.cc @@ -39,6 +39,7 @@ #include "xls/ir/block.h" // IWYU pragma: keep #include "xls/ir/function.h" #include "xls/ir/function_base.h" +#include "xls/ir/node_map.h" #include "xls/ir/nodes.h" #include "xls/ir/op.h" #include "xls/ir/package.h" @@ -202,8 +203,7 @@ absl::StatusOr FunctionBaseToVisualizationProto( PartialInfoQueryEngine(), ProcStateRangeQueryEngine()); XLS_RETURN_IF_ERROR(query_engine.Populate(function).status()); - using NodeDAG = - absl::flat_hash_map>; + using NodeDAG = NodeMap>; NodeDAG node_dag; if (token_dag) { XLS_ASSIGN_OR_RETURN(NodeDAG token_dag, ComputeTokenDAG(function));