diff --git a/CMakeLists.txt b/CMakeLists.txt index eebaf5f..06f9417 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,10 +11,16 @@ include_directories(${CMAKE_SOURCE_DIR}/include) # including gtl as a system dependency prevents warnings when compiling it include_directories(SYSTEM ${CMAKE_SOURCE_DIR}/gtl/include) -add_executable(lerw src/main.cpp) +set(COMMON_SOURCES + src/config.cpp +) + +add_executable(lerw_length src/main_length.cpp ${COMMON_SOURCES}) +add_executable(lerw_points src/main_points.cpp ${COMMON_SOURCES}) +foreach(target lerw_length lerw_points) # options from https://github.com/cpp-best-practices/cmake_template -target_compile_options(lerw PUBLIC +target_compile_options(${target} PUBLIC -Wall -Wextra # reasonable and standard -Wshadow # warn the user if a variable declaration shadows one from a parent context @@ -41,20 +47,21 @@ target_compile_options(lerw PUBLIC -Wsuggest-override # warn if an overridden member function is not marked 'override' or 'final' ) -target_compile_options(lerw PUBLIC -O3 -march=native) -target_compile_options(lerw PUBLIC -fconcepts-diagnostics-depth=4) +target_compile_options(${target} PUBLIC -O3 -march=native) +target_compile_options(${target} PUBLIC -fconcepts-diagnostics-depth=4) # silence warnings about portability of the gtl # https://gcc.gnu.org/onlinedocs/gcc-14.1.0/gcc/Warning-Options.html#index-Winterference-size -target_compile_options(lerw PUBLIC -Wno-interference-size) +target_compile_options(${target} PUBLIC -Wno-interference-size) # silence notes from the pstl-implementation # https://stackoverflow.com/a/23995391 -target_compile_options(lerw PUBLIC -fcompare-debug-second) +target_compile_options(${target} PUBLIC -fcompare-debug-second) -target_link_libraries(lerw PRIVATE +target_link_libraries(${target} PRIVATE tbb boost_program_options) +endforeach() # testing find_package(Catch2 3 REQUIRED) @@ -76,4 +83,4 @@ include(Catch) catch_discover_tests(tests) # nix-build wants an 'install' target -install(TARGETS lerw DESTINATION .) +install(TARGETS lerw_length lerw_points DESTINATION .) diff --git a/Makefile b/Makefile index 9ec615b..d07eeb5 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,8 @@ build: install: build interface.py mkdir -p $(INSTALL_DIR)/bin - cp result/lerw $(INSTALL_DIR)/bin + cp result/lerw_length $(INSTALL_DIR)/bin + cp result/lerw_points $(INSTALL_DIR)/bin cp -n interface.py $(INSTALL_DIR) build_manual: diff --git a/include/config.hpp b/include/config.hpp new file mode 100644 index 0000000..be79185 --- /dev/null +++ b/include/config.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +#include "utils.hpp" + +namespace lerw { + +struct Configuration { + Norm norm; + std::size_t dimension; + std::size_t num_samples; + double distance; + double alpha; + std::string output_path; + std::size_t seed; + bool help_requested; +}; + +Configuration parse_command_line(int argc, char *argv[]); + +} // namespace lerw diff --git a/include/generator.hpp b/include/generator.hpp index 7d46f0f..9a8820e 100644 --- a/include/generator.hpp +++ b/include/generator.hpp @@ -8,8 +8,8 @@ namespace lerw { template struct RandomWalkGenerator { - Stopper stopper; Stepper stepper; + Stopper stopper; template constexpr auto operator()(RNG &rng) -> auto { @@ -24,8 +24,8 @@ template struct RandomWalkGenerator { template struct LoopErasedRandomWalkGenerator { - Stopper stopper; Stepper stepper; + Stopper stopper; template constexpr auto operator()(RNG &rng) -> auto { diff --git a/include/lerw.hpp b/include/lerw.hpp index 6e7a943..eef77db 100644 --- a/include/lerw.hpp +++ b/include/lerw.hpp @@ -73,26 +73,43 @@ template struct LengthSelector { template using LengthType = typename LengthSelector::type; -struct LERWComputer { +struct DistanceLerwComputer { std::function rng_factory; std::size_t N; double alpha; double distance; - template auto compute() const { + template + auto compute(Projection projection = {}) const { using point_t = PointType; - return compute_lerw_lengths( - [alpha = alpha]() { - return LDStepper{LengthType{alpha}, DirectionType{}}; + return compute_lengths( + [alpha = alpha, distance = distance] { + return LoopErasedRandomWalkGenerator{ + LDStepper{LengthType{alpha}, + DirectionType{}}, + DistanceStopper{distance}}; }, - [distance = distance]() { return DistanceStopper{distance}; }, - rng_factory, N); + rng_factory, projection, N); } }; -template +struct LengthLerwComputer { + double alpha; + size_t length; + + template + auto compute(RNG &rng) const { + using point_t = PointType; + return LoopErasedRandomWalkGenerator{ + LDStepper{LengthType{alpha}, + DirectionType{}}, + LengthStopper{length}}(rng); + } +}; + +template auto compute_lengths(GeneratorFactory &&generator_factory, - RNGFactory &&rng_factory, - size_t N) -> std::vector { + RNGFactory &&rng_factory, Projection projection, + size_t N) { auto generators = std::vector{}; auto rngs = std::vector{}; @@ -105,52 +122,13 @@ auto compute_lengths(GeneratorFactory &&generator_factory, std::vector lengths(N); - std::transform( - std::execution::par_unseq, generators.begin(), generators.end(), - rngs.begin(), lengths.begin(), - [](auto generator, auto rng) { return generator(rng).size(); }); + std::transform(std::execution::par_unseq, generators.begin(), + generators.end(), rngs.begin(), lengths.begin(), + [projection](auto generator, auto rng) { + return projection(generator(rng)); + }); return lengths; } -template -auto compute_lerw_lengths(StepperFactory &&stepper_factory, - StopperFactory &&stopper_factory, - RNGFactory &&rng_factory, - std::size_t n_samples) -> auto { - auto generator_factory = [&stopper_factory, &stepper_factory] { - return LoopErasedRandomWalkGenerator{stopper_factory(), stepper_factory()}; - }; - return compute_lengths(generator_factory, rng_factory, n_samples); -} - -template -auto compute_average_length(GeneratorFactory &&generator_factory, - RNGFactory &&rng_factory, size_t N) -> double { - assert(N != 0); - return std::ranges::fold_left_first( - compute_lengths(std::move(generator_factory), - std::move(rng_factory), N), - std::plus{}) / - static_cast(N); -} - -template -auto compute_lerw_average_lengths(StepperFactory &&stepper_factory, - RNGFactory &&rng_factory, - const std::vector &distances, - std::size_t n_samples) -> auto { - auto results = std::vector>{}; - for (const auto &d : distances) { - auto stopper_factory = [d] { return DistanceStopper{d}; }; - auto l = compute_average_length( - [&stopper_factory, &stepper_factory] { - return LoopErasedRandomWalkGenerator{stopper_factory(), - stepper_factory()}; - }, - rng_factory, n_samples); - results.emplace_back(d, l); - } - return results; -} } // namespace lerw diff --git a/include/utils.hpp b/include/utils.hpp index d617905..63b7280 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -80,4 +80,10 @@ template constexpr auto norm(T... args) -> auto { return norm_selector{}(args...); } + +// helper for switch +constexpr auto switch_pair(std::size_t dimension, Norm norm) -> std::size_t { + return (dimension << 2) + static_cast(norm); +} + } // namespace lerw diff --git a/interface.py b/interface.py index a4c6044..9b0d7c1 100644 --- a/interface.py +++ b/interface.py @@ -7,7 +7,8 @@ # relative to the location where this is called from DATA_DIR: Path = Path("data") -CPP_EXECUTABLE: Path = Path("bin") / "lerw" +CPP_LENGTHS_EXE: Path = Path("bin") / "lerw_length" +CPP_POINTS_EXE: Path = Path("bin") / "lerw_points" class Norm(Enum): @@ -16,6 +17,71 @@ class Norm(Enum): L2 = 2 +def get_walk( + number_of_steps: int, + alpha: float, + norm: Norm, + seed: int = 3, + recompute: bool = False, +) -> npt.NDArray[np.int64]: + """Get Points of a single walk. + + Arguments match output of " --help". + """ + DATA_DIR.mkdir(exist_ok=True) + dimension = 2 + distance = 0.0 # ignored by cpp + + filename = _format_filename( + "walk", dimension, distance, number_of_steps, alpha, norm, seed + ) + file_path = DATA_DIR / filename + + if recompute: + file_path.unlink(missing_ok=True) + + if not file_path.exists(): + cmd = [ + Path.cwd() / CPP_POINTS_EXE, + "--dimension", + dimension, + "--distance", + distance, + "--number_of_walks", + number_of_steps, + "--alpha", + alpha, + "--norm", + norm.name, + "--output", + file_path, + "--seed", + seed, + ] + + result = subprocess.run( + list(map(str, cmd)), + capture_output=True, + text=True, + check=False, + ) + + # Check for any output, which indicates an error + if result.stdout or result.stderr or result.returncode != 0: + print( + f"Failed to call C++:\n{result.stderr}\n\n{result.stdout}", + file=sys.stderr, + ) + raise subprocess.CalledProcessError( + returncode=result.returncode or 1, + cmd=cmd, + output=result.stdout, + stderr=result.stderr, + ) + + return np.genfromtxt(file_path, dtype=np.int64, comments="#", delimiter=",") + + def get_walk_lengths( dimension: int, distance: float, @@ -31,7 +97,9 @@ def get_walk_lengths( """ DATA_DIR.mkdir(exist_ok=True) - filename = _format_filename(dimension, distance, number_of_walks, alpha, norm, seed) + filename = _format_filename( + "walks", dimension, distance, number_of_walks, alpha, norm, seed + ) file_path = DATA_DIR / filename if recompute: @@ -39,7 +107,7 @@ def get_walk_lengths( if not file_path.exists(): cmd = [ - Path.cwd() / CPP_EXECUTABLE, + Path.cwd() / CPP_LENGTHS_EXE, "--dimension", dimension, "--distance", @@ -80,6 +148,7 @@ def get_walk_lengths( def _format_filename( + prefix: str, dimension: int, distance: float, number_of_walks: int, @@ -87,10 +156,24 @@ def _format_filename( norm: Norm, seed: int = 42, ) -> str: - return f"walks_dim{dimension}_dist{distance}_n{number_of_walks}_a{alpha}_{norm.name}_rng{seed}.txt" + return f"{prefix}_dim{dimension}_dist{distance}_n{number_of_walks}_a{alpha}_{norm.name}_rng{seed}.txt" def test(): + num_steps = 10 + args = { + "alpha": 0.5, + "norm": Norm.L2, + "seed": 2, + } + file = Path(DATA_DIR) / _format_filename("walk", 2, 0.0, num_steps, **args) + file.unlink(missing_ok=True) + + walk = get_walk(num_steps, **args) + assert file.exists() + assert walk.shape == (11, 2) + assert all(walk[10, :] == [16, -1077]) + args = { "dimension": 2, "distance": 5000, @@ -99,7 +182,7 @@ def test(): "norm": Norm.L2, "seed": 2, } - file = Path(DATA_DIR) / _format_filename(**args) + file = Path(DATA_DIR) / _format_filename("walks", **args) file.unlink(missing_ok=True) walks = get_walk_lengths(**args) diff --git a/src/config.cpp b/src/config.cpp new file mode 100644 index 0000000..08ddba7 --- /dev/null +++ b/src/config.cpp @@ -0,0 +1,69 @@ +#include "config.hpp" + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnull-dereference" +#include +#pragma GCC diagnostic pop +#include +#include + +namespace po = boost::program_options; + +namespace lerw { + +Configuration parse_command_line(int argc, char *argv[]) { + Configuration config{.norm = Norm::L2, + .dimension = 2, + .num_samples = 1000, + .distance = 1000, + .alpha = 0.5, + .output_path = "", + .seed = 42, + .help_requested = false}; + + po::options_description desc("Allowed options"); + desc.add_options()("help", "produce help message")( + "norm,n", + po::value()->default_value("L2")->notifier( + [&config](const std::string &n) { config.norm = parse_norm(n); }), + "norm (L1, L2, or LINF)")( + "dimension,D", + po::value(&config.dimension)->default_value(config.dimension), + "dimension of the lattice")( + "number_of_walks,N", + po::value(&config.num_samples)->default_value(config.num_samples), + "number of walks")( + "distance,R", + po::value(&config.distance)->default_value(config.distance), + "distance from the origin when the walk is stopped")( + "alpha,a", po::value(&config.alpha)->default_value(config.alpha), + "shape parameter (must be > 0)")( + "seed,s", + po::value(&config.seed)->default_value(config.seed), + "random number generator seed")( + "output,o", po::value(&config.output_path), + "path to output file (if not specified, writes to stdout)"); + + try { + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + config.help_requested = vm.count("help") > 0; + if (config.help_requested) { + std::cout << desc << "\n"; + } + + if (config.alpha <= 0) { + throw std::runtime_error("alpha must be greater than 0"); + } + + } catch (const std::exception &e) { + throw std::runtime_error(std::string("Error parsing command line: ") + + e.what()); + } + + return config; +} + +} // namespace lerw diff --git a/src/main.cpp b/src/main.cpp deleted file mode 100644 index 7d94561..0000000 --- a/src/main.cpp +++ /dev/null @@ -1,129 +0,0 @@ -#include -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wnull-dereference" -#include -#pragma GCC diagnostic pop -#include -#include -#include -#include -#include - -#include "lerw.hpp" -#include "utils.hpp" - -using namespace lerw; -namespace po = boost::program_options; - -// helper for switch -constexpr auto switch_pair(std::size_t dimension, Norm norm) -> std::size_t { - return (dimension << 2) + static_cast(norm); -} - -auto main(int argc, char *argv[]) -> int { - Norm norm = Norm::L2; - std::size_t dimension = 2; - std::size_t N = 1000; // number of samples - double distance = 1000; // distance - double alpha = 0.5; // shape parameter - std::string output_path; - std::size_t seed = 42; // default seed value - - po::options_description desc("Allowed options"); - desc.add_options()("help", "produce help message")( - "norm,n", - po::value()->default_value("L2")->notifier( - [&norm](const std::string &n) { norm = parse_norm(n); }), - "norm (L1, L2, or LINF)")( - "dimension,D", po::value(&dimension)->default_value(dimension), - "dimension of the lattice")("number_of_walks,N", - po::value(&N)->default_value(N), - "number of walks")( - "distance,R", po::value(&distance)->default_value(distance), - "distance from the origin when the walk is stopped")( - "alpha,a", po::value(&alpha)->default_value(alpha), - "shape parameter (must be > 0)")( - "seed,s", po::value(&seed)->default_value(seed), - "random number generator seed")( - "output,o", po::value(&output_path), - "path to output file (if not specified, writes to stdout)"); - - boost::program_options::variables_map vm; - try { - po::store(po::parse_command_line(argc, argv, desc), vm); - po::notify(vm); - } catch (const std::exception &e) { - std::cerr << "Error: " << e.what() << "\n"; - return 1; - } - - if (vm.count("help")) { - std::cout << desc << "\n"; - return 0; - } - - if (alpha <= 0) { - std::cerr << "Error: alpha must be greater than 0\n"; - return 1; - } - - std::ofstream output_file; - std::ostream *out = &std::cout; // Default to cout - if (vm.count("output")) { - output_file.open(output_path); - if (!output_file) { - std::cerr << "Error: Could not open output file: " << output_path << "\n"; - return 1; - } - out = &output_file; - } - - auto seed_rng = std::mt19937{seed}; - auto computer = LERWComputer{[&seed_rng] { return std::mt19937{seed_rng()}; }, - N, alpha, distance}; - - const auto lengths = [&] { - switch (switch_pair(dimension, norm)) { - case switch_pair(1, Norm::L1): - return computer.compute<1, Norm::L1>(); - case switch_pair(2, Norm::L1): - return computer.compute<2, Norm::L1>(); - case switch_pair(3, Norm::L1): - return computer.compute<3, Norm::L1>(); - case switch_pair(4, Norm::L1): - return computer.compute<4, Norm::L1>(); - case switch_pair(5, Norm::L1): - return computer.compute<5, Norm::L1>(); - case switch_pair(1, Norm::L2): - return computer.compute<1, Norm::L2>(); - case switch_pair(2, Norm::L2): - return computer.compute<2, Norm::L2>(); - case switch_pair(3, Norm::L2): - return computer.compute<3, Norm::L2>(); - case switch_pair(4, Norm::L2): - return computer.compute<4, Norm::L2>(); - case switch_pair(5, Norm::L2): - return computer.compute<5, Norm::L2>(); - // TODO: LINF with d=1 is broken - // case switch_pair(1, Norm::LINF): - // return computer.compute<1, Norm::LINF>(); - case switch_pair(2, Norm::LINF): - return computer.compute<2, Norm::LINF>(); - case switch_pair(3, Norm::LINF): - return computer.compute<3, Norm::LINF>(); - case switch_pair(4, Norm::LINF): - return computer.compute<4, Norm::LINF>(); - case switch_pair(5, Norm::LINF): - return computer.compute<5, Norm::LINF>(); - default: - throw std::invalid_argument("Unsupported dimension/norm choice"); - } - }(); - - std::println(*out, "# D={}, R={}, N={}, α={}, Norm={}, seed={}", dimension, - distance, N, alpha, norm_to_string(norm), seed); - - for (auto l : lengths) { - std::println(*out, "{}", l); - } -} diff --git a/src/main_length.cpp b/src/main_length.cpp new file mode 100644 index 0000000..1ff94ac --- /dev/null +++ b/src/main_length.cpp @@ -0,0 +1,97 @@ +#include +#include +#include +#include +#include + +#include "config.hpp" +#include "lerw.hpp" +#include "utils.hpp" + +using namespace lerw; + +auto main(int argc, char *argv[]) -> int { + auto config = parse_command_line(argc, argv); + + if (config.help_requested) { + return 0; + } + + std::ofstream output_file; + std::ostream *out = &std::cout; // Default to cout + + if (!config.output_path.empty()) { + output_file.open(config.output_path); + if (!output_file) { + std::cerr << "Error: Could not open output file: " << config.output_path + << "\n"; + return 1; + } + out = &output_file; + } + + auto seed_rng = std::mt19937{config.seed}; + auto computer = + DistanceLerwComputer{[&seed_rng] { return std::mt19937{seed_rng()}; }, + config.num_samples, config.alpha, config.distance}; + + const auto lengths = [&] { + switch (switch_pair(config.dimension, config.norm)) { + case switch_pair(1, Norm::L1): + return computer.compute<1, Norm::L1>( + [](auto walk) { return walk.size(); }); + case switch_pair(2, Norm::L1): + return computer.compute<2, Norm::L1>( + [](auto walk) { return walk.size(); }); + case switch_pair(3, Norm::L1): + return computer.compute<3, Norm::L1>( + [](auto walk) { return walk.size(); }); + case switch_pair(4, Norm::L1): + return computer.compute<4, Norm::L1>( + [](auto walk) { return walk.size(); }); + case switch_pair(5, Norm::L1): + return computer.compute<5, Norm::L1>( + [](auto walk) { return walk.size(); }); + case switch_pair(1, Norm::L2): + return computer.compute<1, Norm::L2>( + [](auto walk) { return walk.size(); }); + case switch_pair(2, Norm::L2): + return computer.compute<2, Norm::L2>( + [](auto walk) { return walk.size(); }); + case switch_pair(3, Norm::L2): + return computer.compute<3, Norm::L2>( + [](auto walk) { return walk.size(); }); + case switch_pair(4, Norm::L2): + return computer.compute<4, Norm::L2>( + [](auto walk) { return walk.size(); }); + case switch_pair(5, Norm::L2): + return computer.compute<5, Norm::L2>( + [](auto walk) { return walk.size(); }); + // TODO: LINF with d=1 is broken + // case switch_pair(1, Norm::LINF): + // return computer.compute<1, Norm::LINF>(); + case switch_pair(2, Norm::LINF): + return computer.compute<2, Norm::LINF>( + [](auto walk) { return walk.size(); }); + case switch_pair(3, Norm::LINF): + return computer.compute<3, Norm::LINF>( + [](auto walk) { return walk.size(); }); + case switch_pair(4, Norm::LINF): + return computer.compute<4, Norm::LINF>( + [](auto walk) { return walk.size(); }); + case switch_pair(5, Norm::LINF): + return computer.compute<5, Norm::LINF>( + [](auto walk) { return walk.size(); }); + default: + throw std::invalid_argument("Unsupported dimension/norm choice"); + } + }(); + + std::println(*out, "# D={}, R={}, N={}, α={}, Norm={}, seed={}", + config.dimension, config.distance, config.num_samples, + config.alpha, norm_to_string(config.norm), config.seed); + + for (auto l : lengths) { + std::println(*out, "{}", l); + } +} diff --git a/src/main_points.cpp b/src/main_points.cpp new file mode 100644 index 0000000..df34077 --- /dev/null +++ b/src/main_points.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include + +#include "config.hpp" +#include "lerw.hpp" +#include "point.hpp" +#include "utils.hpp" + +using namespace lerw; + +auto main(int argc, char *argv[]) -> int { + // NOTE: We only do one walk. We ignore the distance-parameter + // and use the num_samples-parameter to mean the length of the walk + auto config = parse_command_line(argc, argv); + + if (config.help_requested) { + return 0; + } + + std::ofstream output_file; + std::ostream *out = &std::cout; // Default to cout + + if (!config.output_path.empty()) { + output_file.open(config.output_path); + if (!output_file) { + std::cerr << "Error: Could not open output file: " << config.output_path + << "\n"; + return 1; + } + out = &output_file; + } + + auto rng = std::mt19937{config.seed}; + auto computer = LengthLerwComputer{config.alpha, config.num_samples}; + + const auto points = [&] { + switch (switch_pair(config.dimension, config.norm)) { + case switch_pair(2, Norm::L1): + return computer.compute<2, Norm::L1>(rng); + case switch_pair(2, Norm::L2): + return computer.compute<2, Norm::L2>(rng); + case switch_pair(2, Norm::LINF): + return computer.compute<2, Norm::LINF>(rng); + default: + throw std::invalid_argument("Unsupported dimension/norm choice"); + } + }(); + + std::println(*out, "# D={}, R={}, N={}, α={}, Norm={}, seed={}", + config.dimension, config.distance, config.num_samples, + config.alpha, norm_to_string(config.norm), config.seed); + + for (auto p : points) { + std::println(*out, "{}, {}", p.x, p.y); + } +} diff --git a/tests/generator.cpp b/tests/generator.cpp index 93dd1aa..257db25 100644 --- a/tests/generator.cpp +++ b/tests/generator.cpp @@ -2,10 +2,10 @@ #include #include -#include "generator.hpp" -#include "ldstepper.hpp" #include "directions.hpp" #include "distributions.hpp" +#include "generator.hpp" +#include "ldstepper.hpp" #include "point.hpp" #include "stepper.hpp" #include "stopper.hpp" @@ -53,7 +53,7 @@ struct LoopingMockStepper { TEST_CASE("RandomWalkGenerator basic functionality") { auto rng = std::mt19937{42}; SECTION("Walk with zero steps") { - RandomWalkGenerator generator{MockStopper{0}, MockStepper{}}; + RandomWalkGenerator generator{ MockStepper{}, MockStopper{0}}; const auto walk = generator(rng); REQUIRE(walk.size() == 1); @@ -61,7 +61,7 @@ TEST_CASE("RandomWalkGenerator basic functionality") { } SECTION("Walk with one step") { - RandomWalkGenerator generator{MockStopper{1}, MockStepper{}}; + RandomWalkGenerator generator{ MockStepper{}, MockStopper{1}}; const auto walk = generator(rng); REQUIRE(walk.size() == 2); @@ -71,7 +71,7 @@ TEST_CASE("RandomWalkGenerator basic functionality") { SECTION("Walk with multiple steps") { const size_t num_steps = 5; - RandomWalkGenerator generator{MockStopper{num_steps}, MockStepper{}}; + RandomWalkGenerator generator{ MockStepper{}, MockStopper{num_steps}}; const std::vector expected = {0, 1, 2, 3, 4, 5}; const auto walk = generator(rng); @@ -81,7 +81,7 @@ TEST_CASE("RandomWalkGenerator basic functionality") { SECTION("Verify stepper is called correct number of times") { const size_t num_steps = 4; - RandomWalkGenerator generator{MockStopper{num_steps}, MockStepper{}}; + RandomWalkGenerator generator{ MockStepper{}, MockStopper{num_steps}}; const auto walk = generator(rng); REQUIRE(generator.stepper.step_count == num_steps); @@ -89,7 +89,10 @@ TEST_CASE("RandomWalkGenerator basic functionality") { SECTION("Large number of steps") { const size_t large_steps = 1000; - RandomWalkGenerator generator{MockStopper{large_steps}, MockStepper{}}; + RandomWalkGenerator generator{ + MockStepper{}, + MockStopper{large_steps} + }; const auto walk = generator(rng); REQUIRE(walk.size() == large_steps + 1); @@ -103,7 +106,7 @@ TEST_CASE("LoopErasedRandomWalkGenerator") { auto rng = std::mt19937{42}; SECTION("Generates walk of correct length") { - LoopErasedRandomWalkGenerator generator{stopper, MockStepper{}}; + LoopErasedRandomWalkGenerator generator{MockStepper{}, stopper}; const std::vector expected = {0, 1, 2, 3, 4, 5}; const auto walk = generator(rng); @@ -112,7 +115,7 @@ TEST_CASE("LoopErasedRandomWalkGenerator") { } SECTION("Simple loop detection and erasure") { - LoopErasedRandomWalkGenerator generator{stopper, LoopingMockStepper{}}; + LoopErasedRandomWalkGenerator generator{LoopingMockStepper{}, stopper}; auto walk = generator(rng); // Check that the loop was erased (stepper did 1 -> 2 -> 3 -> 0 -> 1) @@ -121,14 +124,14 @@ TEST_CASE("LoopErasedRandomWalkGenerator") { } SECTION("Stops at maximum steps") { - LoopErasedRandomWalkGenerator generator{stopper, MockStepper{}}; + LoopErasedRandomWalkGenerator generator{MockStepper{}, stopper}; auto walk = generator(rng); REQUIRE(walk.size() == max_steps + 1); } SECTION("Handles zero steps") { - LoopErasedRandomWalkGenerator generator{MockStopper{0}, MockStepper{}}; + LoopErasedRandomWalkGenerator generator{MockStepper{}, MockStopper{0}}; auto walk = generator(rng); REQUIRE(walk.size() == 1); // Should still contain start point @@ -139,25 +142,25 @@ TEST_CASE("LoopErasedRandomWalkGenerator") { TEST_CASE("ConstructDifferentCombinations") { // TODO: I want these gone, they should be LDSteppers with L1Direction const auto a = LoopErasedRandomWalkGenerator{ - DistanceStopper{10.0}, NearestNeighborStepper{}}; + NearestNeighborStepper{}, DistanceStopper{10.0}}; const auto b = LoopErasedRandomWalkGenerator{ - DistanceStopper{10.0}, NearestNeighborStepper{}}; + NearestNeighborStepper{}, DistanceStopper{10.0}}; const auto c = LoopErasedRandomWalkGenerator{ - DistanceStopper{10.0}, NearestNeighborStepper{}}; + NearestNeighborStepper{}, DistanceStopper{10.0}}; - const auto d = LoopErasedRandomWalkGenerator{DistanceStopper{10.0}, - LDStepper{ + const auto d = LoopErasedRandomWalkGenerator{LDStepper{ Pareto{2.0}, L2Direction{}, - }}; - const auto e = LoopErasedRandomWalkGenerator{DistanceStopper{10.0}, - LDStepper{ + }, + DistanceStopper{10.0}}; + const auto e = LoopErasedRandomWalkGenerator{LDStepper{ Pareto{2.0}, L2Direction{}, - }}; - const auto f = LoopErasedRandomWalkGenerator{DistanceStopper{10.0}, - LDStepper{ + }, + DistanceStopper{10.0}}; + const auto f = LoopErasedRandomWalkGenerator{LDStepper{ Pareto{2.0}, L2Direction{}, - }}; + }, + DistanceStopper{10.0}}; }