diff --git a/src/TiledArray/conversions/make_array.h b/src/TiledArray/conversions/make_array.h index 1295e6f8e4..306b61ee34 100644 --- a/src/TiledArray/conversions/make_array.h +++ b/src/TiledArray/conversions/make_array.h @@ -28,6 +28,7 @@ #include "TiledArray/array_impl.h" #include "TiledArray/external/madness.h" +#include "TiledArray/pmap/replicated_pmap.h" #include "TiledArray/shape.h" #include "TiledArray/type_traits.h" @@ -73,7 +74,7 @@ template ::value>::type* = nullptr> inline Array make_array( World& world, const detail::trange_t& trange, - const std::shared_ptr >& pmap, Op&& op) { + const std::shared_ptr>& pmap, Op&& op) { typedef typename Array::value_type value_type; typedef typename value_type::range_type range_type; @@ -150,10 +151,10 @@ template ::value>::type* = nullptr> inline Array make_array( World& world, const detail::trange_t& trange, - const std::shared_ptr >& pmap, Op&& op) { + const std::shared_ptr>& pmap, Op&& op) { typedef typename Array::value_type value_type; typedef typename Array::ordinal_type ordinal_type; - typedef std::pair > datum_type; + typedef std::pair> datum_type; // Create a vector to hold local tiles std::vector tiles; @@ -241,6 +242,41 @@ inline Array make_array(World& world, const detail::trange_t& trange, op); } +/// a make_array variant that uses a sequence of tiles +/// to construct a DistArray with default pmap +template +Array make_array(World& world, const detail::trange_t& tiled_range, + Tiles begin, Tiles end, bool replicated) { + Array array; + using Tuple = std::remove_reference_t; + using Index = std::tuple_element_t<0, Tuple>; + using shape_type = typename Array::shape_type; + + std::shared_ptr pmap; + if (replicated) { + size_t ntiles = tiled_range.tiles_range().volume(); + pmap = std::make_shared(world, ntiles); + } + + if constexpr (shape_type::is_dense()) { + array = Array(world, tiled_range, pmap); + } else { + std::vector> tile_norms; + for (Tiles it = begin; it != end; ++it) { + auto [index, tile] = *it; + tile_norms.push_back({index, tile.norm()}); + } + shape_type shape(world, tile_norms, tiled_range); + array = Array(world, tiled_range, shape, pmap); + } + for (Tiles it = begin; it != end; ++it) { + auto [index, tile] = *it; + if (array.is_zero(index)) continue; + array.set(index, tile); + } + return array; +} + } // namespace TiledArray #endif // TILEDARRAY_CONVERSIONS_MAKE_ARRAY_H__INCLUDED diff --git a/src/TiledArray/dist_array.h b/src/TiledArray/dist_array.h index cb9d094f34..6c583a795b 100644 --- a/src/TiledArray/dist_array.h +++ b/src/TiledArray/dist_array.h @@ -1878,39 +1878,6 @@ auto norm2(const DistArray& a) { return std::sqrt(squared_norm(a)); } -template -Array make_array(World& world, const detail::trange_t& tiled_range, - Tiles begin, Tiles end, bool replicated) { - Array array; - using Tuple = std::remove_reference_t; - using Index = std::tuple_element_t<0, Tuple>; - using shape_type = typename Array::shape_type; - - std::shared_ptr pmap; - if (replicated) { - size_t ntiles = tiled_range.tiles_range().volume(); - pmap = std::make_shared(world, ntiles); - } - - if constexpr (shape_type::is_dense()) { - array = Array(world, tiled_range, pmap); - } else { - std::vector> tile_norms; - for (Tiles it = begin; it != end; ++it) { - auto [index, tile] = *it; - tile_norms.push_back({index, tile.norm()}); - } - shape_type shape(world, tile_norms, tiled_range); - array = Array(world, tiled_range, shape, pmap); - } - for (Tiles it = begin; it != end; ++it) { - auto [index, tile] = *it; - if (array.is_zero(index)) continue; - array.set(index, tile); - } - return array; -} - template DistArray replicated(const DistArray& a) { auto& world = a.world(); diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index e16c076db4..1f6a9e13f2 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -1,6 +1,7 @@ #ifndef TILEDARRAY_EINSUM_TILEDARRAY_H__INCLUDED #define TILEDARRAY_EINSUM_TILEDARRAY_H__INCLUDED +#include "TiledArray/conversions/make_array.h" #include "TiledArray/dist_array.h" #include "TiledArray/einsum/index.h" #include "TiledArray/einsum/range.h"