diff --git a/.github/ISSUE_TEMPLATE/1-bugreport.yml b/.github/ISSUE_TEMPLATE/1-bugreport.yml new file mode 100644 index 000000000..f3189823e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/1-bugreport.yml @@ -0,0 +1,38 @@ +name: "Bug Report" +description: "File a bug report" +type: "Bug" +body: + - type: markdown + attributes: + value: | + Thank you for taking the time to fill out this bug report! + - type: textarea + id: version + attributes: + label: "Packages versions" + description: "Let us know the versions of any other packages used. For example, which version of AirScript are you using?" + placeholder: "air-script: 0.1.0" + validations: + required: true + - type: textarea + id: bug-description + attributes: + label: "Bug description" + description: "Describe the behavior you are experiencing." + placeholder: "Tell us what happened and what should have happened." + validations: + required: true + - type: textarea + id: reproduce-steps + attributes: + label: "How can this be reproduced?" + description: "If possible, describe how to replicate the unexpected behavior that you see." + placeholder: "Steps!" + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please copy and paste any relevant log output. This is automatically formatted as code, no need for backticks. + render: shell diff --git a/.github/ISSUE_TEMPLATE/2-feature-request.yml b/.github/ISSUE_TEMPLATE/2-feature-request.yml new file mode 100644 index 000000000..bf545356b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/2-feature-request.yml @@ -0,0 +1,20 @@ +name: "Feature request" +description: "Request new goodies" +type: "Feature" +body: + - type: markdown + attributes: + value: | + Thank you for taking the time to fill a feature request! + - type: textarea + id: scenario-why + attributes: + label: "Feature description" + validations: + required: true + - type: textarea + id: scenario-how + attributes: + label: "Why is this feature needed?" 
+ validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/3-task.yml b/.github/ISSUE_TEMPLATE/3-task.yml new file mode 100644 index 000000000..f347c3710 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/3-task.yml @@ -0,0 +1,36 @@ +name: "Task" +description: "Work item" +type: "Task" +body: + - type: markdown + attributes: + value: | + A task should be less than a week worth of work! + - type: textarea + id: task-what + attributes: + label: "What should be done?" + placeholder: "Add support for new constraint type" + validations: + required: true + - type: textarea + id: task-how + attributes: + label: "How should it be done?" + placeholder: "Implement the new constraint type in the parser and codegen" + validations: + required: true + - type: textarea + id: task-done + attributes: + label: "When is this task done?" + placeholder: "The task is done when the new constraint type is fully implemented and tested" + validations: + required: true + - type: textarea + id: task-related + attributes: + label: "Additional context" + description: "Add context to the tasks. E.g. other related tasks or relevant discussions on PRs/chats." + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..0086358db --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: true diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..f015b57bf --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,11 @@ +## Describe your changes + + +## Checklist before requesting a review +- Repo forked and branch created from `next` according to naming convention. +- Commit messages and codestyle follow [conventions](./CONTRIBUTING.md). +- Commits are [signed](https://docs.github.com/en/authentication/managing-commit-signature-verification/signing-commits). 
+- Relevant issues are linked in the PR description. +- Tests added for new functionality. +- Documentation/comments updated according to changes. +- Updated `CHANGELOG.md` diff --git a/.github/workflows/book.yml b/.github/workflows/book.yml index e9cb896bc..10a5fdede 100644 --- a/.github/workflows/book.yml +++ b/.github/workflows/book.yml @@ -36,17 +36,20 @@ jobs: name: Build documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@main # Installation from source takes a fair while, so we install the binaries directly instead. - name: Install mdbook and plugins uses: taiki-e/install-action@v2 with: - tool: mdbook, mdbook-linkcheck, mdbook-alerts, mdbook-katex, mdbook-mermaid + tool: mdbook@0.4.48, mdbook-linkcheck, mdbook-alerts, mdbook-katex, mdbook-mermaid - name: Build book run: mdbook build docs/ + - name: Test documentation examples + run: make test-docs + # Only Upload documentation if we want to deploy (i.e. push to next). - name: Setup Pages if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/next' }} @@ -58,7 +61,7 @@ jobs: uses: actions/upload-pages-artifact@v3 with: # We specify multiple [output] sections in our book.toml which causes mdbook to create separate folders for each. This moves the generated `html` into its own `html` subdirectory. - path: ./docs/book/html + path: ./docs/target/book/html # Deployment job only runs on push to next. deploy: diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 000000000..a7302db3d --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,33 @@ +# Runs build related jobs. + +name: build + +# Limits workflow concurrency to only the latest commit in the PR. 
+concurrency: + group: "${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}" + cancel-in-progress: true + +on: + push: + branches: [main, next] + pull_request: + types: [opened, reopened, synchronize] + +permissions: + contents: read + +jobs: + check: + name: check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@main + - uses: Swatinem/rust-cache@v2 + with: + # Only update the cache on push onto the next branch. This strikes a nice balance between + # cache hits and cache evictions (github has a 10GB cache limit). + save-if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/next' }} + - name: Install rust + run: rustup update --no-self-update + - name: Run check + run: make check diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml new file mode 100644 index 000000000..aad0fe708 --- /dev/null +++ b/.github/workflows/changelog.yml @@ -0,0 +1,26 @@ +# Runs changelog related jobs. +# CI job heavily inspired by: https://github.com/tarides/changelog-check-action + +name: changelog + +on: + pull_request: + types: [opened, reopened, synchronize, labeled, unlabeled] + +permissions: + contents: read + +jobs: + changelog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@main + with: + fetch-depth: 0 + - name: Check for changes in changelog + env: + BASE_REF: ${{ github.event.pull_request.base.ref }} + NO_CHANGELOG_LABEL: ${{ contains(github.event.pull_request.labels.*.name, 'no changelog') }} + run: ./scripts/check-changelog.sh + shell: bash diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 6e6834a29..000000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,66 +0,0 @@ -name: CI -on: - push: - branches: - - main - pull_request: - types: [opened, reopened, synchronize] - -jobs: - test: - name: Test Rust ${{matrix.toolchain}} on ${{matrix.os}} - runs-on: ${{matrix.os}}-latest - strategy: - 
fail-fast: false - matrix: - toolchain: [stable, nightly] - os: [ubuntu] - steps: - - uses: actions/checkout@main - - name: Install rust - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{matrix.toolchain}} - override: true - - name: Test - uses: actions-rs/cargo@v1 - with: - command: test - - clippy: - name: Clippy - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@main - - name: Install minimal stable with clippy - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - components: clippy - override: true - - - name: Clippy - uses: actions-rs/cargo@v1 - with: - command: clippy - args: --all -- -D clippy::all -D warnings - - rustfmt: - name: rustfmt - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@main - - name: Install minimal stable with rustfmt - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - components: rustfmt - override: true - - - name: rustfmt - uses: actions-rs/cargo@v1 - with: - command: fmt - args: --all -- --check diff --git a/.github/workflows/link-checker.yml b/.github/workflows/link-checker.yml new file mode 100644 index 000000000..2d46b01ff --- /dev/null +++ b/.github/workflows/link-checker.yml @@ -0,0 +1,56 @@ +name: Check documentation links + +on: + workflow_call: + inputs: + fail-fast: + required: false + default: true + type: boolean + schedule: + # runs once a day at 01:00 UTC + - cron: "0 1 * * *" + +permissions: + contents: read + issues: write + +jobs: + linkChecker: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set fail-fast flag depending on trigger + id: flags + run: | + if [ "${{ github.event_name }}" = "schedule" ]; then + echo "fail_fast=false" >> $GITHUB_OUTPUT + else + echo "fail_fast=${{ inputs.fail-fast }}" >> $GITHUB_OUTPUT + fi + + - name: Link Checker + id: lychee + uses: lycheeverse/lychee-action@v2.0.2 + with: + fail: ${{ steps.flags.outputs.fail_fast }} + args: | + --verbose + --exclude-mail + --exclude 
"localhost" + --exclude "127.0.0.1" + --exclude "^file://" + --exclude "github.com/0xMiden/.*" + docs/src/**/*.md + **/*.md + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Open issue on failure (only if fail-fast is false) + if: steps.lychee.outcome == 'failure' && steps.flags.outputs.fail_fast != 'true' + uses: peter-evans/create-issue-from-file@v5 + with: + title: Link Checker Report + content-filepath: ./lychee/out.md + labels: report, automated issue diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 000000000..a08f46ffb --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,65 @@ +# Runs linting related jobs. + +name: lint + +# Limits workflow concurrency to only the latest commit in the PR. +concurrency: + group: "${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}" + cancel-in-progress: true + +on: + push: + branches: [main, next] + pull_request: + types: [opened, reopened, synchronize] + +permissions: + contents: read + +jobs: + clippy: + name: clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@main + - uses: Swatinem/rust-cache@v2 + with: + # Only update the cache on push onto the next branch. This strikes a nice balance between + # cache hits and cache evictions (github has a 10GB cache limit). + save-if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/next' }} + - name: Clippy + run: | + rustup update --no-self-update + rustup component add --toolchain stable clippy + make clippy + + rustfmt: + name: rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@main + - uses: Swatinem/rust-cache@v2 + with: + # Only update the cache on push onto the next branch. This strikes a nice balance between + # cache hits and cache evictions (github has a 10GB cache limit). 
+ save-if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/next' }} + - name: Rustfmt + run: | + rustup update --no-self-update nightly + rustup +nightly component add rustfmt + make format-check + + doc: + name: doc + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@main + - uses: Swatinem/rust-cache@v2 + with: + # Only update the cache on push onto the next branch. This strikes a nice balance between + # cache hits and cache evictions (github has a 10GB cache limit). + save-if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/next' }} + - name: Build docs + run: | + rustup update --no-self-update + make doc diff --git a/.github/workflows/msrv.yml b/.github/workflows/msrv.yml new file mode 100644 index 000000000..8be264da5 --- /dev/null +++ b/.github/workflows/msrv.yml @@ -0,0 +1,34 @@ +name: Check MSRV + +on: + push: + branches: [next] + pull_request: + types: [opened, reopened, synchronize] + +# Limits workflow concurrency to only the latest commit in the PR. +concurrency: + group: "${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}" + cancel-in-progress: true + +permissions: + contents: read + +jobs: + # Check MSRV (aka `rust-version`) in `Cargo.toml` is valid for workspace members + msrv: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install dependencies + run: sudo apt-get update && sudo apt-get install -y jq + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - name: Install cargo-msrv + run: cargo install cargo-msrv + - name: Cache rustup toolchains + run: rustup update + - name: Check MSRV for each workspace member + run: | + chmod +x scripts/check-msrv.sh + ./scripts/check-msrv.sh \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000..2a7543c03 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,36 @@ +# Runs test related jobs. 
+ +name: test + +# Limits workflow concurrency to only the latest commit in the PR. +concurrency: + group: "${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}" + cancel-in-progress: true + +on: + push: + branches: [main, next] + pull_request: + types: [opened, reopened, synchronize] + +permissions: + contents: read + +env: + CARGO_TERM_COLOR: always + +jobs: + test: + name: test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@main + - uses: Swatinem/rust-cache@v2 + with: + # Only update the cache on push onto the next branch. This strikes a nice balance between + # cache hits and cache evictions (github has a 10GB cache limit). + save-if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/next' }} + - name: Install rust + run: rustup update --no-self-update + - name: Run tests + run: make test diff --git a/CHANGELOG.md b/CHANGELOG.md index ad04691e0..6ae931f72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # Changelog +## 0.5.0 (TBD) + +- Incremented MSRV to 1.89. +- Add a pass for common subexpression elimination in the constraint graph (#419). +- Reorder unrolling passes to unroll list comprehensions before match statements (#431). +- Refactored the unrolling pass in MIR (#434). +- Update documentation and tests thereof (#437). +- Add a constant propagation pass after other mir passes (#439). +- Removed `TraceSegmentId::index()` and replaced segment indexing with `TraceShape`/`FullTraceShape` across ACE codegen, MIR-to-AIR pass, and constraints (#442). +- Allow computed indices (#444). +- Fix regressions on MIR and list_comprehensions (#449). +- Fixed a vector unrolling issue in nested match evaluations (#491). +- Fix evaluator argument vector slice expansion (#495). +- Fixed MIR's constant propagation to fold 0^0 to 1 (#509) +- Support importing hierarchical modules (#507, #513, #514). +- Fix MIR inlining loop on deeply nested calls (#524). 
+ ## 0.4.0 (2025-06-20) ### Language @@ -19,6 +36,7 @@ - Introduced initial version of the ACE backend (#370, #380, #386). - Updated Winterfell codegen to the latest version (#388). - Removed obsolete MASM codegen backend (#389). +- Add node to graph referencing reduced public input tables ([#414](https://github.com/0xMiden/air-script/issues/414)) ### Internal diff --git a/Cargo.toml b/Cargo.toml index 5adb3e840..c1bc061b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ authors = ["miden contributors"] license = "MIT" repository = "https://github.com/0xMiden/air-script" edition = "2024" -rust-version = "1.87" +rust-version = "1.90" [workspace.dependencies] anyhow = "1.0" diff --git a/Makefile b/Makefile index 40721cb3b..4d9d7c091 100644 --- a/Makefile +++ b/Makefile @@ -8,16 +8,36 @@ help: WARNINGS=RUSTDOCFLAGS="-D warnings" +# -- building -------------------------------------------------------------------------------------- + +.PHONY: build +build: ## Build the project + cargo build --workspace + +.PHONY: check +check: ## Run type checker + cargo check --workspace --all-targets + +# -- testing -------------------------------------------------------------------------------------- + +.PHONY: test +test: ## Run all tests + cargo test --workspace + +.PHONY: test-docs +test-docs: ## Test documentation examples and build + mdbook test docs + # -- linting -------------------------------------------------------------------------------------- .PHONY: clippy clippy: ## Run Clippy with configs - $(WARNINGS) cargo +nightly clippy --workspace --all-targets --all-features + $(WARNINGS) cargo +stable clippy --workspace --all-targets --all-features .PHONY: fix fix: ## Run Fix with configs - cargo +nightly fix --allow-staged --allow-dirty --all-targets --all-features + cargo +stable fix --allow-staged --allow-dirty --all-targets --all-features .PHONY: format @@ -32,3 +52,14 @@ format-check: ## Run Format using nightly toolchain but only in check mode .PHONY: lint lint: 
format fix clippy ## Run all linting tasks at once (Clippy, fixing, formatting) + +# --- docs ---------------------------------------------------------------------------------------- + +.PHONY: doc +doc: ## Generates & checks documentation + cargo doc --keep-going --release + + +.PHONY: book +book: ## Builds the book & serves documentation site + mdbook serve --open docs diff --git a/README.md b/README.md index bc16a0a9d..81b336b14 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@ # AirScript [![LICENSE](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/0xMiden/air-script/blob/main/LICENSE) -[![CI](https://github.com/0xMiden/air-script/actions/workflows/ci.yml/badge.svg)](https://github.com/0xMiden/air-script/actions/workflows/test.yml) -[![RUST_VERSION](https://img.shields.io/badge/rustc-1.87+-lightgray.svg)](https://www.rust-lang.org/tools/install) +[![test](https://github.com/0xMiden/air-script/actions/workflows/test.yml/badge.svg)](https://github.com/0xMiden/air-script/actions/workflows/test.yml) +[![build](https://github.com/0xMiden/air-script/actions/workflows/build.yml/badge.svg)](https://github.com/0xMiden/air-script/actions/workflows/build.yml) +[![RUST_VERSION](https://img.shields.io/badge/rustc-1.89+-lightgray.svg)](https://www.rust-lang.org/tools/install) [![Crates.io](https://img.shields.io/crates/v/air-script)](https://crates.io/crates/air-script) A domain-specific language for expressing AIR constraints for STARKs, especially for STARK-based virtual machines like [Miden VM](https://github.com/0xMiden/miden-vm). @@ -22,10 +23,57 @@ The project is organized into several crates as follows: | ---------------------- | ----------- | | [Parser](parser) | Contains the parser for AirScript. The parser is used to parse the constraints written in AirScript into an AST. | | [MIR](mir) | Contains the middle intermediate representation (`MIR`). 
The purpose of the `MIR` is to provide a representation of an AirScript program that allows for optimization and translation to `AirIR` containing the `AlgebraicGraph`. | -| [AIR](air) | Contains the IR for AirScript, `AirIR`. `AirIR` is initialized with an AirScript AST, which it converts to an internal representation that can be optimized and used to generate code in multiple target languages. | +| [AIR](air) | Contains the IR for AirScript, `AirIR`. `AirIR` is initialized with an AirScript MIR, which it converts to an internal representation that can be optimized and used to generate code in multiple target languages. | | [Winterfell code generator](codegen/winterfell/) | Contains a code generator targeting the [Winterfell prover](https://github.com/novifinancial/winterfell) Rust library. The Winterfell code generator converts a provided AirScript `AirIR` into Rust code that represents the AIR as a new custom struct that implements Winterfell's `Air` trait. | +| [ACE code generator](codegen/ace/) | Contains a code generator targeting Miden VM's ACE (Arithmetic Circuit Evaluation) chiplet. Converts AirScript constraints into arithmetic circuits optimized for recursive STARK proof verification within Miden assembly programs. | | [AirScript](air-script) | Aggregates all components of the AirScript compiler into a single place and provides a CLI as an executable to transpile AIRs defined in AirScript to the specified target language. Also contains integration tests for AirScript. | +## Documentation and Examples + +AirScript documentation uses mdBook and is located in the `docs/` directory. Examples are stored in `docs/examples/` and are included in the documentation using mdBook's include syntax. + +### Adding New Examples + +To add a new example to the documentation: + +1. **Create the example file**: Add your `.air` file to `docs/examples/` +2. 
**Test compilation**: Ensure the example compiles using the CLI: + ```bash + cargo build --release + target/release/airc transpile docs/examples/your_example.air -o /tmp/test.rs + ``` +3. **Include in documentation**: Add the example to the relevant markdown file in `docs/src/` using: + + ```air + {{#include ../../examples/your_example.air}} + ``` + +### Testing Documentation Examples + +The project includes an integration test that ensures all documentation examples compile successfully: + +```bash +cargo test -p air-script --test docs_sync +``` + +This test automatically: +- Builds the AirScript CLI tool +- Finds all `.air` files in `docs/examples/` +- Transpiles each example to verify compilation +- Cleans up generated files + +The `docs_sync` test runs as part of the CI pipeline to ensure documentation examples remain valid and up-to-date. + +### Documentation Build Test + +The project also includes a documentation build test that runs as part of CI: + +```bash +make test-docs +``` + +This test ensures that the documentation builds correctly and validates that `docs/build.rs` runs properly when the `DOCS_TEST` environment variable is set. + ## Contributing to AirScript AirScript is an open project and we welcome everyone to contribute! If you are interested in contributing to AirScript, please have a look at our [Contribution guidelines](https://github.com/0xMiden/air-script/blob/main/CONTRIBUTING.md). If you want to work on a specific issue, please add a comment on the GitHub issue indicating you are interested before submitting a PR. This will help avoid duplicated effort. If you have thoughts on how to improve AirScript, we'd love to know them. So, please don't hesitate to open issues. 
diff --git a/air-script/Cargo.toml b/air-script/Cargo.toml index 66736081e..8d5d5e5ed 100644 --- a/air-script/Cargo.toml +++ b/air-script/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "air-script" -version = "0.4.0" +version = "0.5.0" description = "AirScript language compiler" authors.workspace = true readme = "README.md" @@ -17,15 +17,15 @@ name = "airc" path = "src/main.rs" [dependencies] -air-codegen-winter = { package = "air-codegen-winter", path = "../codegen/winterfell", version = "0.4" } -air-ir = { package = "air-ir", path = "../air", version = "0.4" } -air-parser = { package = "air-parser", path = "../parser", version = "0.4" } -air-pass = { package = "air-pass", path = "../pass", version = "0.4" } -clap = { version = "4.2", features = ["derive"] } +air-codegen-winter = { package = "air-codegen-winter", path = "../codegen/winterfell", version = "0.5" } +air-ir = { package = "air-ir", path = "../air", version = "0.5" } +air-parser = { package = "air-parser", path = "../parser", version = "0.5" } +air-pass = { package = "air-pass", path = "../pass", version = "0.5" } +clap = { version = "4.5", features = ["derive"] } env_logger = "0.11" log = { version = "0.4", default-features = false } miden-diagnostics = { workspace = true } -mir = { package = "air-mir", path = "../mir", version = "0.4" } +mir = { package = "air-mir", path = "../mir", version = "0.5" } [dev-dependencies] expect-test = "1.4" diff --git a/air-script/README.md b/air-script/README.md index eabb19cfa..1c88fa324 100644 --- a/air-script/README.md +++ b/air-script/README.md @@ -4,19 +4,20 @@ This crate aggregates all components of the AirScript compiler into a single pla ## Basic Usage -An in-depth description of AirScript is available in the full AirScript [documentation](https://0xpolygonmiden.github.io/air-script/). +An in-depth description of AirScript is available in the full AirScript [documentation](https://0xmiden.github.io/air-script/). 
-The compiler has three stages, which can be imported and used independently or together. +The compiler has four stages, which can be imported and used independently or together. 1. [Parser](../parser/): scans and parses AirScript files and builds an AST -2. [IR](../ir/): produces an intermediate representation from an AirScript AST -3. [Code generation](../codegen/): translate an `AirIR` into a specific target language +2. [MIR](../mir/): produces a middle intermediate representation from the AirScript AST +3. [AIR](../air/): produces an intermediate representation from an AirScript MIR +4. [Code generation](../codegen/): translate an `AirIR` into a specific target language - [Winterfell Code Generator](../codegen/winterfell/): generates Rust code targeting the [Winterfell prover](https://github.com/novifinancial/winterfell). Example usage: ```Rust -use air_script::{Air, parse, passes, Pass, transforms, WinterfellCodeGenerator}; +use air_script::{parse, compile, WinterfellCodeGenerator}; use miden_diagnostics::{ term::termcolor::ColorChoice, CodeMap, DefaultEmitter, DiagnosticsHandler, }; @@ -28,21 +29,17 @@ let diagnostics = DiagnosticsHandler::new(Default::default(), codemap.clone(), e // Parse into AST let ast = parse(&diagnostics, codemap, source.as_str()).expect("parsing failed"); -// Lower to IR -let air = { - let mut pipeline = transforms::ConstantPropagation::new(&diagnostics) - .chain(transforms::Inlining::new(&diagnostics)) - .chain(passes::AstToAir::new(&diagnostics)); - pipeline.run(ast).expect("lowering failed") -}; + +// Compile AST into AIR +let air = compile(&diagnostics, ast).expect("compilation failed"); // Generate Rust code targeting the Winterfell prover -let code = WinterfellCodeGenerator::new(&ir).generate().expect("codegen failed"); +let code = WinterfellCodeGenerator.generate(&air).expect("codegen failed"); ``` An example of an AIR defined in AirScript can be found in the `examples/` directory. 
-To run the full transpilation pipeline, the CLI can be used for convenience. +To run the full transpilation the CLI can be used for convenience. ## Command-Line Interface (CLI) diff --git a/air-script/src/cli/transpile.rs b/air-script/src/cli/transpile.rs index cc9025fd2..2ecea6545 100644 --- a/air-script/src/cli/transpile.rs +++ b/air-script/src/cli/transpile.rs @@ -1,8 +1,6 @@ use std::{fs, path::PathBuf, sync::Arc}; -use air_ir::{CodeGenerator, CompileError}; -use air_pass::Pass; - +use air_ir::{CodeGenerator, CompileError, compile}; use clap::{Args, ValueEnum}; use miden_diagnostics::{ CodeMap, DefaultEmitter, DiagnosticsHandler, term::termcolor::ColorChoice, @@ -19,11 +17,6 @@ impl Target { } } } -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] -pub enum Pipeline { - WithMIR, - WithoutMIR, -} #[derive(Args)] pub struct Transpile { @@ -37,19 +30,8 @@ pub struct Transpile { )] output: Option, - #[arg( - short, - long, - help = "Defines the target language, defaults to Winterfell" - )] + #[arg(short, long, help = "Defines the target language, defaults to Winterfell")] target: Option, - - #[arg( - short, - long, - help = "Defines the compilation pipeline (WithMIR or WithoutMIR), defaults to WithMIR" - )] - pipeline: Option, } impl Transpile { @@ -62,37 +44,10 @@ impl Transpile { let emitter = Arc::new(DefaultEmitter::new(ColorChoice::Auto)); let diagnostics = DiagnosticsHandler::new(Default::default(), codemap.clone(), emitter); - let pipeline = self.pipeline.unwrap_or(Pipeline::WithMIR); // Parse from file to internal representation - let air = match pipeline { - Pipeline::WithMIR => { - println!("Transpiling with Mir pipeline..."); - air_parser::parse_file(&diagnostics, codemap, input_path) - .map_err(CompileError::Parse) - .and_then(|ast| { - let mut pipeline = - air_parser::transforms::ConstantPropagation::new(&diagnostics) - .chain(mir::passes::AstToMir::new(&diagnostics)) - .chain(mir::passes::Inlining::new(&diagnostics)) - 
.chain(mir::passes::Unrolling::new(&diagnostics)) - .chain(air_ir::passes::MirToAir::new(&diagnostics)) - .chain(air_ir::passes::BusOpExpand::new(&diagnostics)); - pipeline.run(ast) - }) - } - Pipeline::WithoutMIR => { - println!("Transpiling without Mir pipeline..."); - air_parser::parse_file(&diagnostics, codemap, input_path) - .map_err(CompileError::Parse) - .and_then(|ast| { - let mut pipeline = - air_parser::transforms::ConstantPropagation::new(&diagnostics) - .chain(air_parser::transforms::Inlining::new(&diagnostics)) - .chain(air_ir::passes::AstToAir::new(&diagnostics)); - pipeline.run(ast) - }) - } - }; + let air = air_parser::parse_file(&diagnostics, codemap, input_path) + .map_err(CompileError::Parse) + .and_then(|program| compile(&diagnostics, program)); match air { Ok(air) => { @@ -109,9 +64,10 @@ impl Transpile { let mut path = input_path.clone(); path.set_extension(target.extension()); path - } + }, }; - let code = backend.generate(&air).expect("code generation failed"); + let code = + backend.generate(&air).map_err(|e| format!("code generation failed: {e}"))?; if let Err(err) = fs::write(&output_path, code) { return Err(format!("{err:?}")); } @@ -120,11 +76,11 @@ impl Transpile { println!("============================================================"); Ok(()) - } + }, Err(err) => { diagnostics.emit(err); Err("compilation failed".into()) - } + }, } } } diff --git a/air-script/src/lib.rs b/air-script/src/lib.rs index 906933515..5d477c7d0 100644 --- a/air-script/src/lib.rs +++ b/air-script/src/lib.rs @@ -1,4 +1,3 @@ pub use air_codegen_winter::CodeGenerator as WinterfellCodeGenerator; -pub use air_ir::{Air, CompileError, passes}; +pub use air_ir::{Air, CompileError, compile}; pub use air_parser::{parse, parse_file, transforms}; -pub use air_pass::Pass; diff --git a/air-script/src/main.rs b/air-script/src/main.rs index d4166bcfb..754684312 100644 --- a/air-script/src/main.rs +++ b/air-script/src/main.rs @@ -1,6 +1,7 @@ -use clap::{Parser, Subcommand}; 
use std::io::Write; +use clap::{Parser, Subcommand}; + mod cli; #[derive(Parser)] diff --git a/air-script/tests/binary/binary.rs b/air-script/tests/binary/binary.rs index cbb2fcb97..8b7185bbd 100644 --- a/air-script/tests/binary/binary.rs +++ b/air-script/tests/binary/binary.rs @@ -82,8 +82,8 @@ impl Air for BinaryAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] * main_current[0] - main_current[0] - E::ZERO; - result[1] = main_current[1] * main_current[1] - main_current[1] - E::ZERO; + result[0] = main_current[0] * main_current[0] - main_current[0]; + result[1] = main_current[1] * main_current[1] - main_current[1]; } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/binary/test_air.rs b/air-script/tests/binary/test_air.rs index 43efbdc04..4f40af405 100644 --- a/air-script/tests/binary/test_air.rs +++ b/air-script/tests/binary/test_air.rs @@ -3,7 +3,8 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ - binary::binary::{BinaryAir, PublicInputs}, + binary::binary::PublicInputs, + generate_air_test, helpers::{AirTester, MyTraceTable}, }; @@ -38,17 +39,4 @@ impl AirTester for BinaryAirTester { } } -#[test] -fn test_binary_air() { - let air_tester = Box::new(BinaryAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = BinaryAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!(test_binary_air, 
crate::binary::binary::BinaryAir, BinaryAirTester, 1024); diff --git a/air-script/tests/bitwise/bitwise.rs b/air-script/tests/bitwise/bitwise.rs index c631d2262..84916092d 100644 --- a/air-script/tests/bitwise/bitwise.rs +++ b/air-script/tests/bitwise/bitwise.rs @@ -82,23 +82,23 @@ impl Air for BitwiseAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] * main_current[0] - main_current[0] - E::ZERO; - result[1] = periodic_values[1] * (main_next[0] - main_current[0]) - E::ZERO; - result[2] = main_current[3] * main_current[3] - main_current[3] - E::ZERO; - result[3] = main_current[4] * main_current[4] - main_current[4] - E::ZERO; - result[4] = main_current[5] * main_current[5] - main_current[5] - E::ZERO; - result[5] = main_current[6] * main_current[6] - main_current[6] - E::ZERO; - result[6] = main_current[7] * main_current[7] - main_current[7] - E::ZERO; - result[7] = main_current[8] * main_current[8] - main_current[8] - E::ZERO; - result[8] = main_current[9] * main_current[9] - main_current[9] - E::ZERO; - result[9] = main_current[10] * main_current[10] - main_current[10] - E::ZERO; - result[10] = periodic_values[0] * (main_current[1] - (E::ONE * main_current[3] + E::from(Felt::new(2_u64)) * main_current[4] + E::from(Felt::new(4_u64)) * main_current[5] + E::from(Felt::new(8_u64)) * main_current[6])) - E::ZERO; - result[11] = periodic_values[0] * (main_current[2] - (E::ONE * main_current[7] + E::from(Felt::new(2_u64)) * main_current[8] + E::from(Felt::new(4_u64)) * main_current[9] + E::from(Felt::new(8_u64)) * main_current[10])) - E::ZERO; - result[12] = periodic_values[1] * (main_next[1] - (main_current[1] * E::from(Felt::new(16_u64)) + E::ONE * main_current[3] + E::from(Felt::new(2_u64)) * main_current[4] + E::from(Felt::new(4_u64)) * main_current[5] + E::from(Felt::new(8_u64)) * main_current[6])) - E::ZERO; - 
result[13] = periodic_values[1] * (main_next[2] - (main_current[2] * E::from(Felt::new(16_u64)) + E::ONE * main_current[7] + E::from(Felt::new(2_u64)) * main_current[8] + E::from(Felt::new(4_u64)) * main_current[9] + E::from(Felt::new(8_u64)) * main_current[10])) - E::ZERO; - result[14] = periodic_values[0] * main_current[11] - E::ZERO; - result[15] = periodic_values[1] * (main_current[12] - main_next[11]) - E::ZERO; - result[16] = (E::ONE - main_current[0]) * (main_current[12] - (main_current[11] * E::from(Felt::new(16_u64)) + E::ONE * main_current[3] * main_current[7] + E::from(Felt::new(2_u64)) * main_current[4] * main_current[8] + E::from(Felt::new(4_u64)) * main_current[5] * main_current[9] + E::from(Felt::new(8_u64)) * main_current[6] * main_current[10])) + main_current[0] * (main_current[12] - (main_current[11] * E::from(Felt::new(16_u64)) + E::ONE * (main_current[3] + main_current[7] - E::from(Felt::new(2_u64)) * main_current[3] * main_current[7]) + E::from(Felt::new(2_u64)) * (main_current[4] + main_current[8] - E::from(Felt::new(2_u64)) * main_current[4] * main_current[8]) + E::from(Felt::new(4_u64)) * (main_current[5] + main_current[9] - E::from(Felt::new(2_u64)) * main_current[5] * main_current[9]) + E::from(Felt::new(8_u64)) * (main_current[6] + main_current[10] - E::from(Felt::new(2_u64)) * main_current[6] * main_current[10]))) - E::ZERO; + result[0] = main_current[0] * main_current[0] - main_current[0]; + result[1] = periodic_values[1] * (main_next[0] - main_current[0]); + result[2] = main_current[3] * main_current[3] - main_current[3]; + result[3] = main_current[4] * main_current[4] - main_current[4]; + result[4] = main_current[5] * main_current[5] - main_current[5]; + result[5] = main_current[6] * main_current[6] - main_current[6]; + result[6] = main_current[7] * main_current[7] - main_current[7]; + result[7] = main_current[8] * main_current[8] - main_current[8]; + result[8] = main_current[9] * main_current[9] - main_current[9]; + result[9] = 
main_current[10] * main_current[10] - main_current[10]; + result[10] = periodic_values[0] * (main_current[1] - (main_current[3] + E::from(Felt::new(2_u64)) * main_current[4] + E::from(Felt::new(4_u64)) * main_current[5] + E::from(Felt::new(8_u64)) * main_current[6])); + result[11] = periodic_values[0] * (main_current[2] - (main_current[7] + E::from(Felt::new(2_u64)) * main_current[8] + E::from(Felt::new(4_u64)) * main_current[9] + E::from(Felt::new(8_u64)) * main_current[10])); + result[12] = periodic_values[1] * (main_next[1] - (main_current[1] * E::from(Felt::new(16_u64)) + main_current[3] + E::from(Felt::new(2_u64)) * main_current[4] + E::from(Felt::new(4_u64)) * main_current[5] + E::from(Felt::new(8_u64)) * main_current[6])); + result[13] = periodic_values[1] * (main_next[2] - (main_current[2] * E::from(Felt::new(16_u64)) + main_current[7] + E::from(Felt::new(2_u64)) * main_current[8] + E::from(Felt::new(4_u64)) * main_current[9] + E::from(Felt::new(8_u64)) * main_current[10])); + result[14] = periodic_values[0] * main_current[11]; + result[15] = periodic_values[1] * (main_current[12] - main_next[11]); + result[16] = (E::ONE - main_current[0]) * (main_current[12] - (main_current[11] * E::from(Felt::new(16_u64)) + main_current[3] * main_current[7] + E::from(Felt::new(2_u64)) * main_current[4] * main_current[8] + E::from(Felt::new(4_u64)) * main_current[5] * main_current[9] + E::from(Felt::new(8_u64)) * main_current[6] * main_current[10])) + main_current[0] * (main_current[12] - (main_current[11] * E::from(Felt::new(16_u64)) + main_current[3] + main_current[7] - E::from(Felt::new(2_u64)) * main_current[3] * main_current[7] + E::from(Felt::new(2_u64)) * (main_current[4] + main_current[8] - E::from(Felt::new(2_u64)) * main_current[4] * main_current[8]) + E::from(Felt::new(4_u64)) * (main_current[5] + main_current[9] - E::from(Felt::new(2_u64)) * main_current[5] * main_current[9]) + E::from(Felt::new(8_u64)) * (main_current[6] + main_current[10] - 
E::from(Felt::new(2_u64)) * main_current[6] * main_current[10]))); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/bitwise/test_air.rs b/air-script/tests/bitwise/test_air.rs index 8d16a82b2..16daca900 100644 --- a/air-script/tests/bitwise/test_air.rs +++ b/air-script/tests/bitwise/test_air.rs @@ -5,7 +5,8 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{AuxTraceWithMetadata, Trace, TraceTable, matrix::ColMatrix}; use crate::{ - bitwise::bitwise::{BitwiseAir, PublicInputs}, + bitwise::bitwise::PublicInputs, + generate_air_test, helpers::{AirTester, MyTraceTable}, }; @@ -49,17 +50,4 @@ impl AirTester for BitwiseAirTester { } } -#[test] -fn test_bitwise_air() { - let air_tester = Box::new(BitwiseAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = BitwiseAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!(test_bitwise_air, crate::bitwise::bitwise::BitwiseAir, BitwiseAirTester, 1024); diff --git a/air-script/tests/buses/buses_complex.air b/air-script/tests/buses/buses_complex.air index 10c69697e..401ca8b67 100644 --- a/air-script/tests/buses/buses_complex.air +++ b/air-script/tests/buses/buses_complex.air @@ -25,8 +25,11 @@ integrity_constraints { enf s1^2 = s1; enf s2^2 = s2; - p.insert(1, a) when s1; - p.remove(1, b) when s2; + let vec = [x for x in 0..3]; + let x = sum(vec) + b; + + p.insert(1, x, a) when s1; + p.remove(1, x, b) when s2; p.insert(2, b) when 1 - s1; p.remove(2, a) when 1 - s2; diff --git a/air-script/tests/buses/buses_complex.rs 
b/air-script/tests/buses/buses_complex.rs index ddc29ff8c..d15eeeb8f 100644 --- a/air-script/tests/buses/buses_complex.rs +++ b/air-script/tests/buses/buses_complex.rs @@ -98,7 +98,7 @@ impl Air for BusesAir { let main_next = main_frame.next(); let aux_current = aux_frame.current(); let aux_next = aux_frame.next(); - result[0] = ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[2]) + E::ONE - E::from(main_current[2])) * ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * (E::ONE - E::from(main_current[2])) + E::ONE - (E::ONE - E::from(main_current[2]))) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[3]) + E::ONE - E::from(main_current[3])) * ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (E::ONE - E::from(main_current[3])) + E::ONE - (E::ONE - E::from(main_current[3]))) * aux_next[0]; + result[0] = ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + (E::from(Felt::new(3_u64)) + E::from(main_current[1])) * aux_rand_elements.rand_elements()[2] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[3]) * E::from(main_current[2]) + E::ONE - E::from(main_current[2])) * ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * (E::ONE - E::from(main_current[2])) + E::from(main_current[2])) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + 
(E::from(Felt::new(3_u64)) + E::from(main_current[1])) * aux_rand_elements.rand_elements()[2] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[3]) * E::from(main_current[3]) + E::ONE - E::from(main_current[3])) * ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (E::ONE - E::from(main_current[3])) + E::from(main_current[3])) * aux_next[0]; result[1] = (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(4_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * aux_current[1] + (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(4_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[2]) + (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(4_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[2]) - ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + 
E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(4_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[1]) * aux_rand_elements.rand_elements()[2]) * aux_next[1] + (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(main_current[0]) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[4])); } } \ No newline at end of file diff --git a/air-script/tests/buses/buses_simple.rs b/air-script/tests/buses/buses_simple.rs index f58f3a7e5..b0012a15a 100644 --- a/air-script/tests/buses/buses_simple.rs +++ b/air-script/tests/buses/buses_simple.rs @@ -94,7 +94,7 @@ impl Air for BusesAir { let main_next = main_frame.next(); let aux_current = aux_frame.current(); let aux_next = aux_frame.next(); - result[0] = ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1]) * E::from(main_current[0]) + E::ONE - E::from(main_current[0])) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1]) * (E::ONE - E::from(main_current[0])) + E::ONE - (E::ONE - E::from(main_current[0]))) * aux_next[0]; - result[1] = (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_current[1] + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) - 
((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_next[1] + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(Felt::new(2_u64))); + result[0] = ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1]) * E::from(main_current[0]) + E::ONE - E::from(main_current[0])) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1]) * (E::ONE - E::from(main_current[0])) + E::from(main_current[0])) * aux_next[0]; + result[1] = (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_current[1] + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) - ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_next[1] + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(Felt::new(2_u64))); } } \ No newline at end of file diff --git a/air-script/tests/buses/buses_simple_with_evaluators.air b/air-script/tests/buses/buses_simple_with_evaluators.air new file 
mode 100644 index 000000000..a1f30c0f5 --- /dev/null +++ b/air-script/tests/buses/buses_simple_with_evaluators.air @@ -0,0 +1,36 @@ +def BusesAir + +trace_columns { + main: [a], +} + +buses { + multiset p, + logup q, +} + +public_inputs { + inputs: [2], +} + +boundary_constraints { + enf p.first = unconstrained; + enf q.first = null; + enf p.last = null; + enf q.last = null; +} + +ev bus_p([a]) { + p.insert(1) when a; + p.remove(1) when (1 - a); +} + +ev bus_q([a]) { + q.insert(1, 2) when a; + q.remove(1, 2) with 2; +} + +integrity_constraints { + enf bus_p([a]); + enf bus_q([a]); +} diff --git a/air-script/tests/buses/buses_varlen_boundary_both.air b/air-script/tests/buses/buses_varlen_boundary_both.air index ee6654aa8..d46bf567f 100644 --- a/air-script/tests/buses/buses_varlen_boundary_both.air +++ b/air-script/tests/buses/buses_varlen_boundary_both.air @@ -16,8 +16,8 @@ public_inputs { boundary_constraints { enf p.first = inputs; - enf q.first = inputs; - enf p.last = outputs; + enf q.first = null; + enf p.last = unconstrained; enf q.last = outputs; } diff --git a/air-script/tests/buses/buses_varlen_boundary_both.rs b/air-script/tests/buses/buses_varlen_boundary_both.rs index 673bc3a55..1b4be6ea2 100644 --- a/air-script/tests/buses/buses_varlen_boundary_both.rs +++ b/air-script/tests/buses/buses_varlen_boundary_both.rs @@ -41,10 +41,10 @@ impl BusesAir { self.trace_length() - self.context().num_transition_exemptions() } - pub fn bus_multiset_boundary_varlen<'a, const N: usize, I: IntoIterator + Clone, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: &I) -> E { + pub fn bus_multiset_boundary_varlen<'a, const N: usize, I: IntoIterator, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: I) -> E { let mut bus_p_last: E = E::ONE; let rand = aux_rand_elements.rand_elements(); - for row in public_inputs.clone().into_iter() { + for row in public_inputs { let mut p_last = rand[0]; for (c, p_i) in row.iter().enumerate() { p_last += 
E::from(*p_i) * rand[c + 1]; @@ -54,10 +54,10 @@ impl BusesAir { bus_p_last } - pub fn bus_logup_boundary_varlen<'a, const N: usize, I: IntoIterator + Clone, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: &I) -> E { + pub fn bus_logup_boundary_varlen<'a, const N: usize, I: IntoIterator, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: I) -> E { let mut bus_q_last = E::ZERO; let rand = aux_rand_elements.rand_elements(); - for row in public_inputs.clone().into_iter() { + for row in public_inputs { let mut q_last = rand[0]; for (c, p_i) in row.iter().enumerate() { let p_i = *p_i; @@ -81,7 +81,7 @@ impl Air for BusesAir { let main_degrees = vec![]; let aux_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(1)]; let num_main_assertions = 0; - let num_aux_assertions = 4; + let num_aux_assertions = 3; let context = AirContext::new_multi_segment( trace_info, @@ -106,10 +106,11 @@ impl Air for BusesAir { fn get_aux_assertions>(&self, aux_rand_elements: &AuxRandElements) -> Vec> { let mut result = Vec::new(); - result.push(Assertion::single(0, 0, Self::bus_multiset_boundary_varlen(aux_rand_elements, &self.inputs.iter()))); - result.push(Assertion::single(1, 0, Self::bus_logup_boundary_varlen(aux_rand_elements, &self.inputs.iter()))); - result.push(Assertion::single(0, self.last_step(), Self::bus_multiset_boundary_varlen(aux_rand_elements, &self.outputs.iter()))); - result.push(Assertion::single(1, self.last_step(), Self::bus_logup_boundary_varlen(aux_rand_elements, &self.outputs.iter()))); + let reduced_inputs_multiset = Self::bus_multiset_boundary_varlen(aux_rand_elements, &self.inputs); + let reduced_outputs_logup = Self::bus_logup_boundary_varlen(aux_rand_elements, &self.outputs); + result.push(Assertion::single(0, 0, reduced_inputs_multiset)); + result.push(Assertion::single(1, 0, E::ZERO)); + result.push(Assertion::single(1, self.last_step(), reduced_outputs_logup)); result } @@ -126,7 +127,7 @@ 
impl Air for BusesAir { let main_next = main_frame.next(); let aux_current = aux_frame.current(); let aux_next = aux_frame.next(); - result[0] = ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1]) * E::from(main_current[0]) + E::ONE - E::from(main_current[0])) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1]) * (E::from(main_current[0]) - E::ONE) + E::ONE - (E::from(main_current[0]) - E::ONE)) * aux_next[0]; - result[1] = (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_current[1] + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) - ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * 
(aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_next[1] + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(Felt::new(2_u64))); + result[0] = ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1]) * E::from(main_current[0]) + E::ONE - E::from(main_current[0])) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1]) * (E::from(main_current[0]) - E::ONE) + E::ONE - (E::from(main_current[0]) - E::ONE)) * aux_next[0]; + result[1] = (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_current[1] + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) - 
((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_next[1] + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(Felt::new(2_u64))); } } \ No newline at end of file diff --git a/air-script/tests/buses/buses_varlen_boundary_first.rs b/air-script/tests/buses/buses_varlen_boundary_first.rs index 3b5571a3d..9b36093bf 100644 --- a/air-script/tests/buses/buses_varlen_boundary_first.rs +++ b/air-script/tests/buses/buses_varlen_boundary_first.rs @@ -37,10 +37,10 @@ impl BusesAir { self.trace_length() - self.context().num_transition_exemptions() } - pub fn bus_multiset_boundary_varlen<'a, const N: usize, I: IntoIterator + Clone, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: &I) -> E { + pub fn bus_multiset_boundary_varlen<'a, const N: usize, I: IntoIterator, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: I) -> E { let mut bus_p_last: E = E::ONE; let rand = aux_rand_elements.rand_elements(); - for row in public_inputs.clone().into_iter() { + for row in public_inputs { let mut p_last = rand[0]; for (c, p_i) in row.iter().enumerate() { p_last += E::from(*p_i) * rand[c + 1]; @@ -50,10 +50,10 @@ impl BusesAir { bus_p_last } - pub fn bus_logup_boundary_varlen<'a, const N: usize, I: IntoIterator + Clone, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: &I) -> E { + pub fn 
bus_logup_boundary_varlen<'a, const N: usize, I: IntoIterator, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: I) -> E { let mut bus_q_last = E::ZERO; let rand = aux_rand_elements.rand_elements(); - for row in public_inputs.clone().into_iter() { + for row in public_inputs { let mut q_last = rand[0]; for (c, p_i) in row.iter().enumerate() { let p_i = *p_i; @@ -102,10 +102,12 @@ impl Air for BusesAir { fn get_aux_assertions>(&self, aux_rand_elements: &AuxRandElements) -> Vec> { let mut result = Vec::new(); + let reduced_inputs_multiset = Self::bus_multiset_boundary_varlen(aux_rand_elements, &self.inputs); + let reduced_inputs_logup = Self::bus_logup_boundary_varlen(aux_rand_elements, &self.inputs); + result.push(Assertion::single(0, 0, reduced_inputs_multiset)); result.push(Assertion::single(0, self.last_step(), E::ONE)); + result.push(Assertion::single(1, 0, reduced_inputs_logup)); result.push(Assertion::single(1, self.last_step(), E::ZERO)); - result.push(Assertion::single(0, 0, Self::bus_multiset_boundary_varlen(aux_rand_elements, &self.inputs.iter()))); - result.push(Assertion::single(1, 0, Self::bus_logup_boundary_varlen(aux_rand_elements, &self.inputs.iter()))); result } @@ -122,7 +124,7 @@ impl Air for BusesAir { let main_next = main_frame.next(); let aux_current = aux_frame.current(); let aux_next = aux_frame.next(); - result[0] = ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1]) * E::from(main_current[0]) + E::ONE - E::from(main_current[0])) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1]) * (E::from(main_current[0]) - E::ONE) + E::ONE - (E::from(main_current[0]) - E::ONE)) * aux_next[0]; - result[1] = (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + 
E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_current[1] + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) - ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_next[1] + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(Felt::new(2_u64))); + result[0] = ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1]) * E::from(main_current[0]) + E::ONE - E::from(main_current[0])) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1]) * 
(E::from(main_current[0]) - E::ONE) + E::ONE - (E::from(main_current[0]) - E::ONE)) * aux_next[0]; + result[1] = (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_current[1] + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) - ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_next[1] + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(Felt::new(2_u64))); } } \ No 
newline at end of file diff --git a/air-script/tests/buses/buses_varlen_boundary_last.rs b/air-script/tests/buses/buses_varlen_boundary_last.rs index e2de8b043..686604e08 100644 --- a/air-script/tests/buses/buses_varlen_boundary_last.rs +++ b/air-script/tests/buses/buses_varlen_boundary_last.rs @@ -37,10 +37,10 @@ impl BusesAir { self.trace_length() - self.context().num_transition_exemptions() } - pub fn bus_multiset_boundary_varlen<'a, const N: usize, I: IntoIterator + Clone, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: &I) -> E { + pub fn bus_multiset_boundary_varlen<'a, const N: usize, I: IntoIterator, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: I) -> E { let mut bus_p_last: E = E::ONE; let rand = aux_rand_elements.rand_elements(); - for row in public_inputs.clone().into_iter() { + for row in public_inputs { let mut p_last = rand[0]; for (c, p_i) in row.iter().enumerate() { p_last += E::from(*p_i) * rand[c + 1]; @@ -50,10 +50,10 @@ impl BusesAir { bus_p_last } - pub fn bus_logup_boundary_varlen<'a, const N: usize, I: IntoIterator + Clone, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: &I) -> E { + pub fn bus_logup_boundary_varlen<'a, const N: usize, I: IntoIterator, E: FieldElement>(aux_rand_elements: &AuxRandElements, public_inputs: I) -> E { let mut bus_q_last = E::ZERO; let rand = aux_rand_elements.rand_elements(); - for row in public_inputs.clone().into_iter() { + for row in public_inputs { let mut q_last = rand[0]; for (c, p_i) in row.iter().enumerate() { let p_i = *p_i; @@ -102,10 +102,12 @@ impl Air for BusesAir { fn get_aux_assertions>(&self, aux_rand_elements: &AuxRandElements) -> Vec> { let mut result = Vec::new(); + let reduced_outputs_multiset = Self::bus_multiset_boundary_varlen(aux_rand_elements, &self.outputs); + let reduced_outputs_logup = Self::bus_logup_boundary_varlen(aux_rand_elements, &self.outputs); result.push(Assertion::single(0, 0, E::ONE)); + 
result.push(Assertion::single(0, self.last_step(), reduced_outputs_multiset)); result.push(Assertion::single(1, 0, E::ZERO)); - result.push(Assertion::single(0, self.last_step(), Self::bus_multiset_boundary_varlen(aux_rand_elements, &self.outputs.iter()))); - result.push(Assertion::single(1, self.last_step(), Self::bus_logup_boundary_varlen(aux_rand_elements, &self.outputs.iter()))); + result.push(Assertion::single(1, self.last_step(), reduced_outputs_logup)); result } @@ -122,7 +124,7 @@ impl Air for BusesAir { let main_next = main_frame.next(); let aux_current = aux_frame.current(); let aux_next = aux_frame.next(); - result[0] = ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1]) * E::from(main_current[0]) + E::ONE - E::from(main_current[0])) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1]) * (E::from(main_current[0]) - E::ONE) + E::ONE - (E::from(main_current[0]) - E::ONE)) * aux_next[0]; - result[1] = (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_current[1] + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * 
(aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) - ((aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_next[1] + (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + E::ONE * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(Felt::new(2_u64))); + result[0] = ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1]) * E::from(main_current[0]) + E::ONE - E::from(main_current[0])) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1]) * (E::from(main_current[0]) - E::ONE) + E::ONE - (E::from(main_current[0]) - E::ONE)) * aux_next[0]; + result[1] = (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_current[1] + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * 
(aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) - ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * aux_next[1] + (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(Felt::new(2_u64))); } } \ No newline at end of file diff --git a/air-script/tests/buses/test_air.rs b/air-script/tests/buses/test_air.rs index ad766cfb2..4c27304a2 100644 --- a/air-script/tests/buses/test_air.rs +++ b/air-script/tests/buses/test_air.rs @@ -3,7 +3,8 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{AuxTraceWithMetadata, Trace, TraceTable, matrix::ColMatrix}; use crate::{ - buses::buses_complex::{BusesAir, PublicInputs}, + buses::buses_complex::PublicInputs, + generate_air_test, helpers::{AirTester, MyTraceTable}, }; @@ -45,31 +46,15 @@ impl AirTester for BusesAirTester { fn build_aux_trace(&self, length: usize) -> Option> { let aux_trace_width = 2; - let num_rand_values = 3; + let num_rand_values = 
4; let mut aux_trace = ColMatrix::new(vec![vec![Felt::new(0); length]; aux_trace_width]); aux_trace.update_row(0, &[Felt::new(1), Felt::new(0)]); aux_trace.update_row(length - 2, &[Felt::new(1), Felt::new(0)]); let aux_rand_elements = AuxRandElements::new(vec![Felt::new(0); num_rand_values]); - let aux_trace_with_meta = AuxTraceWithMetadata { - aux_trace, - aux_rand_elements, - }; + let aux_trace_with_meta = AuxTraceWithMetadata { aux_trace, aux_rand_elements }; Some(aux_trace_with_meta) } } -#[test] -fn test_buses_air() { - let air_tester = Box::new(BusesAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = BusesAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!(test_buses_air, crate::buses::buses_complex::BusesAir, BusesAirTester, 1024); diff --git a/air-script/tests/codegen/helpers.rs b/air-script/tests/codegen/helpers.rs index 6efb94f4e..3fc9c8aa6 100644 --- a/air-script/tests/codegen/helpers.rs +++ b/air-script/tests/codegen/helpers.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use air_ir::{CodeGenerator, CompileError}; -use air_pass::Pass; +use air_script::compile; use miden_diagnostics::{ CodeMap, DefaultEmitter, DiagnosticsHandler, term::termcolor::ColorChoice, }; @@ -9,10 +9,6 @@ use miden_diagnostics::{ pub enum Target { Winterfell, } -pub enum Pipeline { - WithMIR, - WithoutMIR, -} pub struct Test { input_path: String, @@ -22,35 +18,15 @@ impl Test { Test { input_path } } - pub fn transpile(&self, target: Target, pipeline: Pipeline) -> Result { + pub fn transpile(&self, target: Target) -> Result { let codemap = Arc::new(CodeMap::new()); let emitter = Arc::new(DefaultEmitter::new(ColorChoice::Auto)); let diagnostics = 
DiagnosticsHandler::new(Default::default(), codemap.clone(), emitter); // Parse from file to internal representation - let air = match pipeline { - Pipeline::WithMIR => air_parser::parse_file(&diagnostics, codemap, &self.input_path) - .map_err(CompileError::Parse) - .and_then(|ast| { - let mut pipeline = - air_parser::transforms::ConstantPropagation::new(&diagnostics) - .chain(mir::passes::AstToMir::new(&diagnostics)) - .chain(mir::passes::Inlining::new(&diagnostics)) - .chain(mir::passes::Unrolling::new(&diagnostics)) - .chain(air_ir::passes::MirToAir::new(&diagnostics)) - .chain(air_ir::passes::BusOpExpand::new(&diagnostics)); - pipeline.run(ast) - })?, - Pipeline::WithoutMIR => air_parser::parse_file(&diagnostics, codemap, &self.input_path) - .map_err(CompileError::Parse) - .and_then(|ast| { - let mut pipeline = - air_parser::transforms::ConstantPropagation::new(&diagnostics) - .chain(air_parser::transforms::Inlining::new(&diagnostics)) - .chain(air_ir::passes::AstToAir::new(&diagnostics)); - pipeline.run(ast) - })?, - }; + let air = air_parser::parse_file(&diagnostics, codemap, &self.input_path) + .map_err(CompileError::Parse) + .and_then(|program| compile(&diagnostics, program))?; let backend: Box> = match target { Target::Winterfell => Box::new(air_codegen_winter::CodeGenerator), diff --git a/air-script/tests/codegen/mod.rs b/air-script/tests/codegen/mod.rs index 4c3b267dd..d802f40b1 100644 --- a/air-script/tests/codegen/mod.rs +++ b/air-script/tests/codegen/mod.rs @@ -1,3 +1,2 @@ mod helpers; -mod winterfell_with_mir; -mod winterfell_wo_mir; +mod winterfell; diff --git a/air-script/tests/codegen/winterfell_with_mir.rs b/air-script/tests/codegen/winterfell.rs similarity index 60% rename from air-script/tests/codegen/winterfell_with_mir.rs rename to air-script/tests/codegen/winterfell.rs index 3968fe1bf..2fa96ccb8 100644 --- a/air-script/tests/codegen/winterfell_with_mir.rs +++ b/air-script/tests/codegen/winterfell.rs @@ -1,13 +1,11 @@ -use 
super::helpers::{Pipeline, Target, Test}; use expect_test::expect_file; -// tests_wo_mir -// ================================================================================================ +use super::helpers::{Target, Test}; #[test] fn binary() { let generated_air = Test::new("tests/binary/binary.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); let expected = expect_file!["../binary/binary.rs"]; @@ -15,19 +13,19 @@ fn binary() { } #[test] -fn buses_simple() { - let generated_air = Test::new("tests/buses/buses_simple.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn bitwise() { + let generated_air = Test::new("tests/bitwise/bitwise.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../buses/buses_simple.rs"]; + let expected = expect_file!["../bitwise/bitwise.rs"]; expected.assert_eq(&generated_air); } #[test] fn buses_complex() { let generated_air = Test::new("tests/buses/buses_complex.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); let expected = expect_file!["../buses/buses_complex.rs"]; @@ -35,29 +33,29 @@ fn buses_complex() { } #[test] -fn buses_varlen_boundary_first() { - let generated_air = Test::new("tests/buses/buses_varlen_boundary_first.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn buses_simple() { + let generated_air = Test::new("tests/buses/buses_simple.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../buses/buses_varlen_boundary_first.rs"]; + let expected = expect_file!["../buses/buses_simple.rs"]; expected.assert_eq(&generated_air); } #[test] -fn buses_varlen_boundary_last() { - let generated_air = Test::new("tests/buses/buses_varlen_boundary_last.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn buses_simple_with_evaluators() { + let generated_air = 
Test::new("tests/buses/buses_simple_with_evaluators.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../buses/buses_varlen_boundary_last.rs"]; + let expected = expect_file!["../buses/buses_simple.rs"]; expected.assert_eq(&generated_air); } #[test] fn buses_varlen_boundary_both() { let generated_air = Test::new("tests/buses/buses_varlen_boundary_both.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); let expected = expect_file!["../buses/buses_varlen_boundary_both.rs"]; @@ -65,49 +63,60 @@ fn buses_varlen_boundary_both() { } #[test] -fn periodic_columns() { - let generated_air = Test::new("tests/periodic_columns/periodic_columns.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn buses_varlen_boundary_first() { + let generated_air = Test::new("tests/buses/buses_varlen_boundary_first.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../periodic_columns/periodic_columns.rs"]; + let expected = expect_file!["../buses/buses_varlen_boundary_first.rs"]; expected.assert_eq(&generated_air); } #[test] -fn pub_inputs() { - let generated_air = Test::new("tests/pub_inputs/pub_inputs.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn buses_varlen_boundary_last() { + let generated_air = Test::new("tests/buses/buses_varlen_boundary_last.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../pub_inputs/pub_inputs.rs"]; + let expected = expect_file!["../buses/buses_varlen_boundary_last.rs"]; expected.assert_eq(&generated_air); } #[test] -fn system() { - let generated_air = Test::new("tests/system/system.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn computed_indices_complex() { + let generated_air = + Test::new("tests/computed_indices/computed_indices_complex.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + 
let expected = expect_file!["../computed_indices/computed_indices_complex.rs"]; + expected.assert_eq(&generated_air); +} + +#[test] +fn computed_indices_simple() { + let generated_air = Test::new("tests/computed_indices/computed_indices_simple.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../system/system.rs"]; + let expected = expect_file!["../computed_indices/computed_indices_simple.rs"]; expected.assert_eq(&generated_air); } #[test] -fn bitwise() { - let generated_air = Test::new("tests/bitwise/bitwise.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn constant_in_range() { + let generated_air = Test::new("tests/constant_in_range/constant_in_range.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../bitwise/bitwise.rs"]; + let expected = expect_file!["../constant_in_range/constant_in_range.rs"]; expected.assert_eq(&generated_air); } #[test] fn constants() { let generated_air = Test::new("tests/constants/constants.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); let expected = expect_file!["../constants/constants.rs"]; @@ -115,19 +124,28 @@ fn constants() { } #[test] -fn constant_in_range() { - let generated_air = Test::new("tests/constant_in_range/constant_in_range.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) - .unwrap(); +fn constraint_comprehension() { + let generated_air = + Test::new("tests/constraint_comprehension/constraint_comprehension.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); - let expected = expect_file!["../constant_in_range/constant_in_range.rs"]; + let expected = expect_file!["../constraint_comprehension/constraint_comprehension.rs"]; + expected.assert_eq(&generated_air); + + let generated_air = + Test::new("tests/constraint_comprehension/cc_with_evaluators.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let 
expected = expect_file!["../constraint_comprehension/constraint_comprehension.rs"]; expected.assert_eq(&generated_air); } #[test] fn evaluators() { let generated_air = Test::new("tests/evaluators/evaluators.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); let expected = expect_file!["../evaluators/evaluators.rs"]; @@ -135,64 +153,76 @@ fn evaluators() { } #[test] -fn fibonacci() { - let generated_air = Test::new("tests/fibonacci/fibonacci.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn evaluators_slice() { + let generated_air = Test::new("tests/evaluators/evaluators_slice.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../fibonacci/fibonacci.rs"]; + let expected = expect_file!["../evaluators/evaluators_slice.rs"]; expected.assert_eq(&generated_air); } #[test] -fn functions_simple() { - let generated_air = Test::new("tests/functions/functions_simple.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn evaluators_nested_slice_call() { + let generated_air = Test::new("tests/evaluators/evaluators_nested_slice_call.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../functions/functions_simple.rs"]; + let expected = expect_file!["../evaluators/evaluators_nested_slice_call.rs"]; expected.assert_eq(&generated_air); } +// TODO: add support for nested slicing in general expressions. 
+// +// #[test] +// fn evaluators_slice_slicing() { +// let generated_air = Test::new("tests/evaluators/evaluators_slice_slicing.air".to_string()) +// .transpile(Target::Winterfell) +// .unwrap(); +// +// let expected = expect_file!["../evaluators/evaluators_slice_slicing.rs"]; +// expected.assert_eq(&generated_air); +// } + #[test] -fn functions_simple_inlined() { - // make sure that the constraints generated using inlined functions are the same as the ones - // generated using regular functions - let generated_air = Test::new("tests/functions/inlined_functions_simple.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn fibonacci() { + let generated_air = Test::new("tests/fibonacci/fibonacci.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../functions/functions_simple.rs"]; + let expected = expect_file!["../fibonacci/fibonacci.rs"]; expected.assert_eq(&generated_air); } #[test] fn functions_complex() { let generated_air = Test::new("tests/functions/functions_complex.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../functions/functions_complex_with_mir.rs"]; + let expected = expect_file!["../functions/functions_complex.rs"]; expected.assert_eq(&generated_air); } #[test] -fn variables() { - let generated_air = Test::new("tests/variables/variables.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn functions_simple() { + let generated_air = Test::new("tests/functions/functions_simple.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../variables/variables.rs"]; + let expected = expect_file!["../functions/functions_simple.rs"]; expected.assert_eq(&generated_air); } #[test] -fn trace_col_groups() { - let generated_air = Test::new("tests/trace_col_groups/trace_col_groups.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) +fn 
functions_simple_inlined() { + // make sure that the constraints generated using inlined functions are the same as the ones + // generated using regular functions + let generated_air = Test::new("tests/functions/inlined_functions_simple.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../trace_col_groups/trace_col_groups.rs"]; + let expected = expect_file!["../functions/functions_simple.rs"]; expected.assert_eq(&generated_air); } @@ -200,7 +230,7 @@ fn trace_col_groups() { fn indexed_trace_access() { let generated_air = Test::new("tests/indexed_trace_access/indexed_trace_access.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); let expected = expect_file!["../indexed_trace_access/indexed_trace_access.rs"]; @@ -210,10 +240,10 @@ fn indexed_trace_access() { #[test] fn list_comprehension() { let generated_air = Test::new("tests/list_comprehension/list_comprehension.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../list_comprehension/list_comprehension_with_mir.rs"]; + let expected = expect_file!["../list_comprehension/list_comprehension.rs"]; expected.assert_eq(&generated_air); } @@ -221,7 +251,7 @@ fn list_comprehension() { fn list_comprehension_nested() { let generated_air = Test::new("tests/list_comprehension/list_comprehension_nested.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); let expected = expect_file!["../list_comprehension/list_comprehension_nested.rs"]; @@ -231,45 +261,107 @@ fn list_comprehension_nested() { #[test] fn list_folding() { let generated_air = Test::new("tests/list_folding/list_folding.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../list_folding/list_folding.rs"]; + 
expected.assert_eq(&generated_air); +} + +#[test] +fn periodic_columns() { + let generated_air = Test::new("tests/periodic_columns/periodic_columns.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../list_folding/list_folding_with_mir.rs"]; + let expected = expect_file!["../periodic_columns/periodic_columns.rs"]; + expected.assert_eq(&generated_air); +} + +#[test] +fn pub_inputs() { + let generated_air = Test::new("tests/pub_inputs/pub_inputs.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../pub_inputs/pub_inputs.rs"]; expected.assert_eq(&generated_air); } #[test] fn selectors() { let generated_air = Test::new("tests/selectors/selectors.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../selectors/selectors_with_mir.rs"]; + let expected = expect_file!["../selectors/selectors.rs"]; expected.assert_eq(&generated_air); let generated_air = Test::new("tests/selectors/selectors_with_evaluators.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../selectors/selectors_with_evaluators_with_mir.rs"]; + let expected = expect_file!["../selectors/selectors_with_evaluators.rs"]; expected.assert_eq(&generated_air); } #[test] -fn constraint_comprehension() { - let generated_air = - Test::new("tests/constraint_comprehension/constraint_comprehension.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) - .unwrap(); +fn selectors_combine_simple() { + let generated_air = Test::new("tests/selectors/selectors_combine_simple.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); - let expected = expect_file!["../constraint_comprehension/constraint_comprehension.rs"]; + let expected = expect_file!["../selectors/selectors_combine_simple.rs"]; expected.assert_eq(&generated_air); +} + 
+#[test] +fn selectors_combine_complex() { + let generated_air = Test::new("tests/selectors/selectors_combine_complex.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../selectors/selectors_combine_complex.rs"]; + expected.assert_eq(&generated_air); +} +#[test] +fn selectors_combine_with_list_comprehensions() { let generated_air = - Test::new("tests/constraint_comprehension/cc_with_evaluators.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithMIR) + Test::new("tests/selectors/selectors_combine_with_list_comprehensions.air".to_string()) + .transpile(Target::Winterfell) .unwrap(); - let expected = expect_file!["../constraint_comprehension/constraint_comprehension.rs"]; + let expected = expect_file!["../selectors/selectors_combine_with_list_comprehensions.rs"]; + expected.assert_eq(&generated_air); +} + +#[test] +fn system() { + let generated_air = Test::new("tests/system/system.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../system/system.rs"]; + expected.assert_eq(&generated_air); +} + +#[test] +fn trace_col_groups() { + let generated_air = Test::new("tests/trace_col_groups/trace_col_groups.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../trace_col_groups/trace_col_groups.rs"]; + expected.assert_eq(&generated_air); +} + +#[test] +fn variables() { + let generated_air = Test::new("tests/variables/variables.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../variables/variables.rs"]; expected.assert_eq(&generated_air); } diff --git a/air-script/tests/codegen/winterfell_wo_mir.rs b/air-script/tests/codegen/winterfell_wo_mir.rs deleted file mode 100644 index 871b90f29..000000000 --- a/air-script/tests/codegen/winterfell_wo_mir.rs +++ /dev/null @@ -1,228 +0,0 @@ -use super::helpers::{Pipeline, Target, Test}; -use expect_test::expect_file; - -// tests_wo_mir -// 
================================================================================================ - -#[test] -fn binary() { - let generated_air = Test::new("tests/binary/binary.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../binary/binary.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn buses_simple() { - Test::new("tests/buses/buses_simple.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .expect_err("Buses should not be supported in the WithoutMIR pipeline"); -} - -#[test] -fn buses_complex() { - Test::new("tests/buses/buses_complex.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .expect_err("Buses should not be supported in the WithoutMIR pipeline"); -} - -#[test] -fn periodic_columns() { - let generated_air = Test::new("tests/periodic_columns/periodic_columns.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../periodic_columns/periodic_columns.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn pub_inputs() { - let generated_air = Test::new("tests/pub_inputs/pub_inputs.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../pub_inputs/pub_inputs.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn system() { - let generated_air = Test::new("tests/system/system.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../system/system.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn bitwise() { - let generated_air = Test::new("tests/bitwise/bitwise.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../bitwise/bitwise.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn constants() { - let generated_air = 
Test::new("tests/constants/constants.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../constants/constants.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn constant_in_range() { - let generated_air = Test::new("tests/constant_in_range/constant_in_range.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../constant_in_range/constant_in_range.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn evaluators() { - let generated_air = Test::new("tests/evaluators/evaluators.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../evaluators/evaluators.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn fibonacci() { - let generated_air = Test::new("tests/fibonacci/fibonacci.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../fibonacci/fibonacci.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn functions_simple() { - let generated_air = Test::new("tests/functions/functions_simple.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../functions/functions_simple.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn functions_simple_inlined() { - // make sure that the constraints generated using inlined functions are the same as the ones - // generated using regular functions - let generated_air = Test::new("tests/functions/inlined_functions_simple.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../functions/functions_simple.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn functions_complex() { - let generated_air = Test::new("tests/functions/functions_complex.air".to_string()) - 
.transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../functions/functions_complex.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn variables() { - let generated_air = Test::new("tests/variables/variables.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../variables/variables.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn trace_col_groups() { - let generated_air = Test::new("tests/trace_col_groups/trace_col_groups.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../trace_col_groups/trace_col_groups.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn indexed_trace_access() { - let generated_air = - Test::new("tests/indexed_trace_access/indexed_trace_access.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../indexed_trace_access/indexed_trace_access.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn list_comprehension() { - let generated_air = Test::new("tests/list_comprehension/list_comprehension.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../list_comprehension/list_comprehension.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn list_folding() { - let generated_air = Test::new("tests/list_folding/list_folding.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../list_folding/list_folding.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn selectors() { - let generated_air = Test::new("tests/selectors/selectors.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../selectors/selectors.rs"]; - expected.assert_eq(&generated_air); - - let 
generated_air = Test::new("tests/selectors/selectors_with_evaluators.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../selectors/selectors.rs"]; - expected.assert_eq(&generated_air); -} - -#[test] -fn constraint_comprehension() { - let generated_air = - Test::new("tests/constraint_comprehension/constraint_comprehension.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../constraint_comprehension/constraint_comprehension.rs"]; - expected.assert_eq(&generated_air); - - let generated_air = - Test::new("tests/constraint_comprehension/cc_with_evaluators.air".to_string()) - .transpile(Target::Winterfell, Pipeline::WithoutMIR) - .unwrap(); - - let expected = expect_file!["../constraint_comprehension/constraint_comprehension.rs"]; - expected.assert_eq(&generated_air); -} diff --git a/air-script/tests/computed_indices/computed_indices_complex.air b/air-script/tests/computed_indices/computed_indices_complex.air new file mode 100644 index 000000000..1b1dbffbb --- /dev/null +++ b/air-script/tests/computed_indices/computed_indices_complex.air @@ -0,0 +1,61 @@ +def ComputedIndicesAir + +const MDS = [ + [1, 2], + [2, 3], + [3, 4] +]; + +trace_columns { + main: [a[2], s[2]], +} + +public_inputs { + input: [1], +} + +fn double(a: felt) -> felt { + let x = 3 * a; + let y = a; + return x - y; +} + +boundary_constraints { + enf a[0].first = 0; +} + +# Note: +# In this test, we aim to test that computed indices work well even if: +# - The value of the index can only be known late during the compilation process (during MIR's constant propagation) +# - The +integrity_constraints { + + # vec_1 is a list_comprehension that depends on the state + # vec_1 = [ + # 1 * s[0] + 2 * s[1], + # 2 * s[0] + 3 * s[1], + # 3 * s[0] + 4 * s[1] + # ]; + let vec_1 = apply_mds(s); + + let state_2 = [2, 0]; + # vec_2 is a list_comprehension that will not get 
constant-folded early, but will produce constant values + let vec_2 = apply_mds(state_2); + + # x will get the value 2 * 2 - 4 = 0 + let x = double(2) - vec_2[1]; + + # y will get the value vec_2[0] = 2 + let y = vec_2[x]; + + # z will then be vec_1[2] = 3 * s[0] + 4 * s[1] + let z = vec_1[y]; + + # we enforce 3 * s[0] + 4 * s[1] = 0 + enf z = 0; +} + +# We use apply_mds function to produce a list comprehension that will not get constant-folded during AST +fn apply_mds(state: felt[2]) -> felt[3] { + return [sum([s * m for (s, m) in (state, mds_row)]) for mds_row in MDS]; +} diff --git a/air-script/tests/list_comprehension/list_comprehension_with_mir.rs b/air-script/tests/computed_indices/computed_indices_complex.rs similarity index 67% rename from air-script/tests/list_comprehension/list_comprehension_with_mir.rs rename to air-script/tests/computed_indices/computed_indices_complex.rs index b8e8ecf69..b39fe04fa 100644 --- a/air-script/tests/list_comprehension/list_comprehension_with_mir.rs +++ b/air-script/tests/computed_indices/computed_indices_complex.rs @@ -4,41 +4,41 @@ use winter_math::{ExtensionOf, FieldElement, ToElements}; use winter_utils::{ByteWriter, Serializable}; pub struct PublicInputs { - stack_inputs: [Felt; 16], + input: [Felt; 1], } impl PublicInputs { - pub fn new(stack_inputs: [Felt; 16]) -> Self { - Self { stack_inputs } + pub fn new(input: [Felt; 1]) -> Self { + Self { input } } } impl Serializable for PublicInputs { fn write_into(&self, target: &mut W) { - self.stack_inputs.write_into(target); + self.input.write_into(target); } } impl ToElements for PublicInputs { fn to_elements(&self) -> Vec { let mut elements = Vec::new(); - elements.extend_from_slice(&self.stack_inputs); + elements.extend_from_slice(&self.input); elements } } -pub struct ListComprehensionAir { +pub struct ComputedIndicesAir { context: AirContext, - stack_inputs: [Felt; 16], + input: [Felt; 1], } -impl ListComprehensionAir { +impl ComputedIndicesAir { pub fn last_step(&self) 
-> usize { self.trace_length() - self.context().num_transition_exemptions() } } -impl Air for ListComprehensionAir { +impl Air for ComputedIndicesAir { type BaseField = Felt; type PublicInputs = PublicInputs; @@ -47,7 +47,7 @@ impl Air for ListComprehensionAir { } fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { - let main_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(1)]; + let main_degrees = vec![TransitionConstraintDegree::new(1)]; let aux_degrees = vec![]; let num_main_assertions = 1; let num_aux_assertions = 0; @@ -61,7 +61,7 @@ impl Air for ListComprehensionAir { options, ) .set_num_transition_exemptions(2); - Self { context, stack_inputs: public_inputs.stack_inputs } + Self { context, input: public_inputs.input } } fn get_periodic_column_values(&self) -> Vec> { @@ -70,7 +70,7 @@ impl Air for ListComprehensionAir { fn get_assertions(&self) -> Vec> { let mut result = Vec::new(); - result.push(Assertion::single(10, 0, Felt::ZERO)); + result.push(Assertion::single(0, 0, Felt::ZERO)); result } @@ -82,11 +82,7 @@ impl Air for ListComprehensionAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] - main_current[2]; - result[1] = main_current[4] - main_current[0] * E::from(Felt::new(2_u64)) * E::from(Felt::new(2_u64)) * E::from(Felt::new(2_u64)) * main_current[11]; - result[2] = main_current[4] - main_current[0] * (main_next[8] - main_next[12]); - result[3] = main_current[6] - main_current[0] * (main_current[9] - main_current[14]); - result[4] = main_current[1] - (E::ZERO + main_current[5] - main_current[8] - main_current[12] + E::ONE + main_current[6] - main_current[9] - main_current[13] + E::from(Felt::new(2_u64)) + 
main_current[7] - main_current[10] - main_current[14]); + result[0] = main_current[2] * E::from(Felt::new(3_u64)) + main_current[3] * E::from(Felt::new(4_u64)); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/computed_indices/computed_indices_simple.air b/air-script/tests/computed_indices/computed_indices_simple.air new file mode 100644 index 000000000..3dbaae038 --- /dev/null +++ b/air-script/tests/computed_indices/computed_indices_simple.air @@ -0,0 +1,34 @@ +def ComputedIndicesAir + +trace_columns { + main: [a, b, c, d, e, f, g, h], +} + +public_inputs { + stack_inputs: [16], +} + +boundary_constraints { + enf a.first = 0; +} + +integrity_constraints { + # vec = [0, 1, 2, 3, 4]; + let vec = [i for i in 0..5]; + + # x = [0 * 2, 1 * 2, 2 * 2, 3 * 2]; + # x = [0, 2, 4, 6]; + let x = [j * vec[1 + 1] for j in 0..4]; + enf a = x[0]; + enf b = x[1]; + enf c = x[2]; + enf d = x[3]; + + # y = [0 * 1, 1 * 2, 2 * 3, 3 * 4]; + # y = [0, 2, 6, 12]; + let y = [j * vec[j + 1] for j in 0..4]; + enf e' = y[0] * e; + enf f' = y[1] * f; + enf g' = y[2] * g; + enf h' = y[3] * h; +} diff --git a/air-script/tests/functions/functions_complex_with_mir.rs b/air-script/tests/computed_indices/computed_indices_simple.rs similarity index 75% rename from air-script/tests/functions/functions_complex_with_mir.rs rename to air-script/tests/computed_indices/computed_indices_simple.rs index f9ba361dd..bb55d3dd5 100644 --- a/air-script/tests/functions/functions_complex_with_mir.rs +++ b/air-script/tests/computed_indices/computed_indices_simple.rs @@ -27,18 +27,18 @@ impl ToElements for PublicInputs { } } -pub struct FunctionsAir { +pub struct ComputedIndicesAir { context: AirContext, stack_inputs: [Felt; 16], } -impl FunctionsAir { +impl ComputedIndicesAir { pub fn last_step(&self) -> usize { self.trace_length() - 
self.context().num_transition_exemptions() } } -impl Air for FunctionsAir { +impl Air for ComputedIndicesAir { type BaseField = Felt; type PublicInputs = PublicInputs; @@ -47,7 +47,7 @@ impl Air for FunctionsAir { } fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { - let main_degrees = vec![TransitionConstraintDegree::new(11), TransitionConstraintDegree::new(1)]; + let main_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; let aux_degrees = vec![]; let num_main_assertions = 1; let num_aux_assertions = 0; @@ -70,7 +70,7 @@ impl Air for FunctionsAir { fn get_assertions(&self) -> Vec> { let mut result = Vec::new(); - result.push(Assertion::single(3, 0, Felt::ZERO)); + result.push(Assertion::single(0, 0, Felt::ZERO)); result } @@ -82,8 +82,14 @@ impl Air for FunctionsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_next[16] - main_current[16] * ((main_current[3] * main_current[3] * main_current[3] * main_current[3] * main_current[3] * main_current[3] * main_current[3] * main_current[1] * main_current[2] + main_current[3] * main_current[3] * (E::ONE - main_current[1]) * main_current[2] + main_current[3] * main_current[1] * (E::ONE - main_current[2]) + (E::ONE - main_current[1]) * (E::ONE - main_current[2])) * main_current[0] - main_current[0] + E::ONE); - result[1] = main_next[3] - (E::ZERO + main_current[4] + main_current[5] + main_current[6] + main_current[7] + main_current[8] + main_current[9] + main_current[10] + main_current[11] + main_current[12] + main_current[13] + main_current[14] + main_current[15] + E::ONE) * 
E::from(Felt::new(2_u64)); + result[0] = main_current[0]; + result[1] = main_current[1] - E::from(Felt::new(2_u64)); + result[2] = main_current[2] - E::from(Felt::new(4_u64)); + result[3] = main_current[3] - E::from(Felt::new(6_u64)); + result[4] = main_next[4]; + result[5] = main_next[5] - E::from(Felt::new(2_u64)) * main_current[5]; + result[6] = main_next[6] - E::from(Felt::new(6_u64)) * main_current[6]; + result[7] = main_next[7] - E::from(Felt::new(12_u64)) * main_current[7]; } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/computed_indices/mod.rs b/air-script/tests/computed_indices/mod.rs new file mode 100644 index 000000000..5f582b136 --- /dev/null +++ b/air-script/tests/computed_indices/mod.rs @@ -0,0 +1,7 @@ +#[rustfmt::skip] +#[allow(clippy::all)] +mod computed_indices_complex; +#[rustfmt::skip] +#[allow(clippy::all)] +mod computed_indices_simple; +mod test_air; diff --git a/air-script/tests/computed_indices/test_air.rs b/air-script/tests/computed_indices/test_air.rs new file mode 100644 index 000000000..5a953f6e5 --- /dev/null +++ b/air-script/tests/computed_indices/test_air.rs @@ -0,0 +1,61 @@ +use winter_air::Air; +use winter_math::fields::f64::BaseElement as Felt; +use winterfell::{Trace, TraceTable}; + +use crate::{ + computed_indices::computed_indices_simple::{ComputedIndicesAir, PublicInputs}, + helpers::{AirTester, MyTraceTable}, +}; + +#[derive(Clone)] +struct ComputedIndicesAirTester {} + +impl AirTester for ComputedIndicesAirTester { + type PubInputs = PublicInputs; + + fn build_main_trace(&self, length: usize) -> MyTraceTable { + let trace_width = 8; + let mut trace = TraceTable::new(trace_width, length); + + trace.fill( + |state| { + state[0] = Felt::new(0); + state[1] = Felt::new(2); + state[2] = Felt::new(4); + state[3] = Felt::new(6); + state[4] = Felt::new(0); + state[5] = Felt::new(0); 
+ state[6] = Felt::new(0); + state[7] = Felt::new(0); + }, + |_, state| { + state[4] *= Felt::new(0); + state[5] *= Felt::new(2); + state[6] *= Felt::new(6); + state[7] *= Felt::new(12); + }, + ); + + MyTraceTable::new(trace, 0) + } + + fn public_inputs(&self) -> PublicInputs { + let zero = Felt::new(0); + PublicInputs::new([zero; 16]) + } +} + +#[test] +fn test_computed_indices_air() { + let air_tester = Box::new(ComputedIndicesAirTester {}); + let length = 1024; + + let main_trace = air_tester.build_main_trace(length); + let aux_trace = air_tester.build_aux_trace(length); + let pub_inputs = air_tester.public_inputs(); + let trace_info = air_tester.build_trace_info(length); + let options = air_tester.build_proof_options(); + + let air = ComputedIndicesAir::new(trace_info, pub_inputs, options); + main_trace.validate::(&air, aux_trace.as_ref()); +} diff --git a/air-script/tests/constant_in_range/constant_in_range.rs b/air-script/tests/constant_in_range/constant_in_range.rs index aff438e9f..2c53ec6ff 100644 --- a/air-script/tests/constant_in_range/constant_in_range.rs +++ b/air-script/tests/constant_in_range/constant_in_range.rs @@ -82,7 +82,7 @@ impl Air for ConstantInRangeAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] - (E::ZERO + main_current[1] - main_current[4] - main_current[8] + E::ONE + main_current[2] - main_current[5] - main_current[9] + E::from(Felt::new(2_u64)) + main_current[3] - main_current[6] - main_current[10]); + result[0] = main_current[0] - (main_current[1] - main_current[4] - main_current[8] + E::ONE + main_current[2] - main_current[5] - main_current[9] + E::from(Felt::new(2_u64)) + main_current[3] - main_current[6] - main_current[10]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, 
result: &mut [E]) diff --git a/air-script/tests/constant_in_range/test_air.rs b/air-script/tests/constant_in_range/test_air.rs index 3092705d4..064e5eb19 100644 --- a/air-script/tests/constant_in_range/test_air.rs +++ b/air-script/tests/constant_in_range/test_air.rs @@ -3,7 +3,8 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ - constant_in_range::constant_in_range::{ConstantInRangeAir, PublicInputs}, + constant_in_range::constant_in_range::PublicInputs, + generate_air_test, helpers::{AirTester, MyTraceTable}, }; @@ -45,17 +46,9 @@ impl AirTester for ConstantInRangeAirTester { } } -#[test] -fn test_constant_in_range_air() { - let air_tester = Box::new(ConstantInRangeAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = ConstantInRangeAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_constant_in_range_air, + crate::constant_in_range::constant_in_range::ConstantInRangeAir, + ConstantInRangeAirTester, + 1024 +); diff --git a/air-script/tests/constants/constants.rs b/air-script/tests/constants/constants.rs index 6dbae1db0..19b8849a1 100644 --- a/air-script/tests/constants/constants.rs +++ b/air-script/tests/constants/constants.rs @@ -85,7 +85,7 @@ impl Air for ConstantsAir { result.push(Assertion::single(0, 0, Felt::ONE)); result.push(Assertion::single(1, 0, Felt::ONE)); result.push(Assertion::single(2, 0, Felt::ZERO)); - result.push(Assertion::single(3, 0, Felt::ONE - Felt::new(2) + Felt::new(2) - Felt::ZERO)); + result.push(Assertion::single(3, 0, Felt::ONE)); result.push(Assertion::single(4, 0, Felt::ONE)); result.push(Assertion::single(6, self.last_step(), Felt::ZERO)); result @@ 
-100,9 +100,9 @@ impl Air for ConstantsAir { let main_current = frame.current(); let main_next = frame.next(); result[0] = main_next[0] - (main_current[0] + E::ONE); - result[1] = main_next[1] - E::ZERO * main_current[1]; - result[2] = main_next[2] - E::ONE * main_current[2]; - result[3] = main_next[5] - (main_current[5] + E::ONE + E::ZERO); + result[1] = main_next[1]; + result[2] = main_next[2] - main_current[2]; + result[3] = main_next[5] - (main_current[5] + E::ONE); result[4] = main_current[4] - E::ONE; } diff --git a/air-script/tests/constants/test_air.rs b/air-script/tests/constants/test_air.rs index bf8563ba4..97a048a31 100644 --- a/air-script/tests/constants/test_air.rs +++ b/air-script/tests/constants/test_air.rs @@ -3,7 +3,8 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ - constants::constants::{ConstantsAir, PublicInputs}, + constants::constants::PublicInputs, + generate_air_test, helpers::{AirTester, MyTraceTable}, }; @@ -43,17 +44,9 @@ impl AirTester for ConstantsAirTester { } } -#[test] -fn test_constants_air() { - let air_tester = Box::new(ConstantsAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = ConstantsAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_constants_air, + crate::constants::constants::ConstantsAir, + ConstantsAirTester, + 1024 +); diff --git a/air-script/tests/constraint_comprehension/test_air.rs b/air-script/tests/constraint_comprehension/test_air.rs index 37a804042..6bdb9cc8d 100644 --- a/air-script/tests/constraint_comprehension/test_air.rs +++ b/air-script/tests/constraint_comprehension/test_air.rs @@ -3,9 +3,8 @@ use 
winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ - constraint_comprehension::constraint_comprehension::{ - ConstraintComprehensionAir, PublicInputs, - }, + constraint_comprehension::constraint_comprehension::PublicInputs, + generate_air_test, helpers::{AirTester, MyTraceTable}, }; @@ -49,17 +48,9 @@ impl AirTester for ConstraintComprehensionAirTester { } } -#[test] -fn test_constraint_comprehension_air() { - let air_tester = Box::new(ConstraintComprehensionAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = ConstraintComprehensionAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_constraint_comprehension_air, + crate::constraint_comprehension::constraint_comprehension::ConstraintComprehensionAir, + ConstraintComprehensionAirTester, + 1024 +); diff --git a/air-script/tests/docs_sync.rs b/air-script/tests/docs_sync.rs new file mode 100644 index 000000000..c7fc12e72 --- /dev/null +++ b/air-script/tests/docs_sync.rs @@ -0,0 +1,64 @@ +use std::{path::Path, process::Command}; + +#[test] +fn docs_sync() { + let examples_dir = Path::new("../docs/examples"); + // Use CARGO_MANIFEST_DIR to build an absolute path to airc, needed on Windows to correctly use + // `current_dir`. 
+ let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); + let airc_path = Path::new(&manifest_dir).join("../target/release/airc"); + + // Build the CLI tool first + let build_output = Command::new("cargo") + .args(["build", "--release", "-p", "air-script"]) + .current_dir("../") + .output() + .expect("Failed to build airc CLI"); + + assert!( + build_output.status.success(), + "Failed to build airc CLI: {}", + String::from_utf8_lossy(&build_output.stderr) + ); + + // Find all .air files in the examples directory + let air_files: Vec<_> = std::fs::read_dir(examples_dir) + .expect("Failed to read examples directory") + .filter_map(|entry| { + let path = entry.expect("Failed to read directory entry").path(); + if path.extension().and_then(|ext| ext.to_str()) == Some("air") { + Some(path) + } else { + None + } + }) + .collect(); + + assert!(!air_files.is_empty(), "No .air files found in docs/examples"); + + // Compile each example + for air_file in air_files { + let file_name = air_file.file_name().unwrap().to_string_lossy(); + let output_path = air_file.with_extension("rs"); + + let output = Command::new(&airc_path) + .args(["transpile", air_file.to_str().unwrap(), "-o", output_path.to_str().unwrap()]) + .current_dir("../") + .output() + .unwrap_or_else(|_| panic!("Failed to transpile {file_name}")); + + assert!( + output.status.success(), + "Failed to transpile {}: {}", + file_name, + String::from_utf8_lossy(&output.stderr) + ); + + println!("Successfully transpiled: {file_name}"); + + // Clean up generated Rust files + let _ = std::fs::remove_file(output_path); + } + + println!("All documentation examples compiled successfully!"); +} diff --git a/air-script/tests/evaluators/evaluators.air b/air-script/tests/evaluators/evaluators.air index 9eb6270d6..dbce31dc6 100644 --- a/air-script/tests/evaluators/evaluators.air +++ b/air-script/tests/evaluators/evaluators.air @@ -14,6 +14,18 @@ ev are_all_binary([c[3]]) { enf is_binary([c]) 
for c in c; } +fn square_all(x: felt[3]) -> felt[3] { + return [x^2 for x in x]; +} + +ev are_squares_in_order([d[3]]) { + let base = [i for i in 0..3]; + let squares = square_all(base); + enf d[0] = squares[0]; + enf d[1] = squares[1]; + enf d[2] = squares[2]; +} + trace_columns { main: [b, c[3], d[3]], } @@ -30,4 +42,5 @@ integrity_constraints { enf are_unchanged([b, c[1], d[2]]); enf is_binary([b]); enf are_all_binary([c]); + enf are_squares_in_order([d]); } \ No newline at end of file diff --git a/air-script/tests/evaluators/evaluators.rs b/air-script/tests/evaluators/evaluators.rs index 78e001197..1c1a798e4 100644 --- a/air-script/tests/evaluators/evaluators.rs +++ b/air-script/tests/evaluators/evaluators.rs @@ -47,7 +47,7 @@ impl Air for EvaluatorsAir { } fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { - let main_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2)]; + let main_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; let aux_degrees = vec![]; let num_main_assertions = 1; let num_aux_assertions = 0; @@ -89,6 +89,9 @@ impl Air for EvaluatorsAir { result[4] = main_current[1] * main_current[1] - main_current[1]; result[5] = main_current[2] * main_current[2] - main_current[2]; result[6] = main_current[3] * main_current[3] - main_current[3]; + result[7] = main_current[4]; + result[8] = main_current[5] - E::ONE; + result[9] = main_current[6] - E::from(Felt::new(4_u64)); } fn 
evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/evaluators/evaluators_nested_slice_call.air b/air-script/tests/evaluators/evaluators_nested_slice_call.air new file mode 100644 index 000000000..539e708e9 --- /dev/null +++ b/air-script/tests/evaluators/evaluators_nested_slice_call.air @@ -0,0 +1,37 @@ +def EvaluatorsSliceAir + +ev is_binary([x]) { + enf x^2 = x; +} + +trace_columns { + main: [a[20]], +} + +public_inputs { + stack_inputs: [16], +} + +boundary_constraints { + enf a.first = 0 for a in a; +} + +integrity_constraints { + enf constraints([a[0..15]]); +} + +ev constraints([cols[15]]) { + enf chiplets_constraints([cols[5..15]]); +} + +ev chiplets_constraints([chiplets[10]]) { + enf chiplet_selectors([chiplets[0..5]]); +} + +ev chiplet_selectors([s[5]]) { # Expects 5 individual parameters + enf is_binary([s[0]]); + enf is_binary([s[1]]) when s[0]; + enf is_binary([s[2]]) when s[0] & s[1]; + enf is_binary([s[3]]) when s[0] & s[1] & s[2]; + enf is_binary([s[4]]) when s[0] & s[1] & s[2] & s[3]; +} diff --git a/air-script/tests/list_folding/list_folding_with_mir.rs b/air-script/tests/evaluators/evaluators_nested_slice_call.rs similarity index 59% rename from air-script/tests/list_folding/list_folding_with_mir.rs rename to air-script/tests/evaluators/evaluators_nested_slice_call.rs index 3ad7e54ea..04c9eb822 100644 --- a/air-script/tests/list_folding/list_folding_with_mir.rs +++ b/air-script/tests/evaluators/evaluators_nested_slice_call.rs @@ -27,18 +27,18 @@ impl ToElements for PublicInputs { } } -pub struct ListFoldingAir { +pub struct EvaluatorsSliceAir { context: AirContext, stack_inputs: [Felt; 16], } -impl ListFoldingAir { +impl EvaluatorsSliceAir { pub fn last_step(&self) -> usize { self.trace_length() - self.context().num_transition_exemptions() } } -impl Air for ListFoldingAir { +impl Air for 
EvaluatorsSliceAir { type BaseField = Felt; type PublicInputs = PublicInputs; @@ -47,9 +47,9 @@ impl Air for ListFoldingAir { } fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { - let main_degrees = vec![TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(2)]; + let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(6)]; let aux_degrees = vec![]; - let num_main_assertions = 1; + let num_main_assertions = 20; let num_aux_assertions = 0; let context = AirContext::new_multi_segment( @@ -70,7 +70,26 @@ impl Air for ListFoldingAir { fn get_assertions(&self) -> Vec> { let mut result = Vec::new(); + result.push(Assertion::single(0, 0, Felt::ZERO)); + result.push(Assertion::single(1, 0, Felt::ZERO)); + result.push(Assertion::single(2, 0, Felt::ZERO)); + result.push(Assertion::single(3, 0, Felt::ZERO)); + result.push(Assertion::single(4, 0, Felt::ZERO)); + result.push(Assertion::single(5, 0, Felt::ZERO)); + result.push(Assertion::single(6, 0, Felt::ZERO)); + result.push(Assertion::single(7, 0, Felt::ZERO)); + result.push(Assertion::single(8, 0, Felt::ZERO)); + result.push(Assertion::single(9, 0, Felt::ZERO)); + result.push(Assertion::single(10, 0, Felt::ZERO)); result.push(Assertion::single(11, 0, Felt::ZERO)); + result.push(Assertion::single(12, 0, Felt::ZERO)); + result.push(Assertion::single(13, 0, Felt::ZERO)); + result.push(Assertion::single(14, 0, Felt::ZERO)); + result.push(Assertion::single(15, 0, Felt::ZERO)); + result.push(Assertion::single(16, 0, Felt::ZERO)); + result.push(Assertion::single(17, 0, Felt::ZERO)); + result.push(Assertion::single(18, 0, Felt::ZERO)); + result.push(Assertion::single(19, 0, Felt::ZERO)); result } @@ -82,10 +101,11 @@ impl Air for ListFoldingAir { fn 
evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_next[5] - (E::ZERO + main_current[9] + main_current[10] + main_current[11] + main_current[12] + E::ONE * main_current[13] * main_current[14] * main_current[15] * main_current[16]); - result[1] = main_next[6] - (E::ZERO + main_current[9] + main_current[10] + main_current[11] + main_current[12] + E::ONE * main_current[13] * main_current[14] * main_current[15] * main_current[16]); - result[2] = main_next[7] - (E::ZERO + main_current[9] * main_current[13] + main_current[10] * main_current[14] + main_current[11] * main_current[15] + main_current[12] * main_current[16] + E::ONE * (main_current[9] + main_current[13]) * (main_current[10] + main_current[14]) * (main_current[11] + main_current[15]) * (main_current[12] + main_current[16])); - result[3] = main_next[8] - (main_current[1] + E::ZERO + main_current[9] * main_current[13] + main_current[10] * main_current[14] + main_current[11] * main_current[15] + main_current[12] * main_current[16] + E::ZERO + main_current[9] * main_current[13] + main_current[10] * main_current[14] + main_current[11] * main_current[15] + main_current[12] * main_current[16]); + result[0] = main_current[5] * main_current[5] - main_current[5]; + result[1] = main_current[5] * (main_current[6] * main_current[6] - main_current[6]); + result[2] = main_current[5] * main_current[6] * (main_current[7] * main_current[7] - main_current[7]); + result[3] = main_current[5] * main_current[6] * main_current[7] * (main_current[8] * main_current[8] - main_current[8]); + result[4] = main_current[5] * main_current[6] * main_current[7] * main_current[8] * (main_current[9] * main_current[9] - main_current[9]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git 
a/air-script/tests/evaluators/evaluators_slice.air b/air-script/tests/evaluators/evaluators_slice.air new file mode 100644 index 000000000..baeaeb45b --- /dev/null +++ b/air-script/tests/evaluators/evaluators_slice.air @@ -0,0 +1,33 @@ +def EvaluatorsSliceAir + +ev is_binary([x]) { + enf x^2 = x; +} + +trace_columns { + main: [a[20]], +} + +public_inputs { + stack_inputs: [16], +} + +boundary_constraints { + enf a.first = 0 for a in a; +} + +integrity_constraints { + enf chiplets_constraints([a]); +} + +ev chiplets_constraints([chiplets[20]]) { + enf chiplet_selectors([chiplets[0..5]]); # Should expand slice to 5 args +} + +ev chiplet_selectors([s[5]]) { # Expects 5 individual parameters + enf is_binary([s[0]]); + enf is_binary([s[1]]) when s[0]; + enf is_binary([s[2]]) when s[0] & s[1]; + enf is_binary([s[3]]) when s[0] & s[1] & s[2]; + enf is_binary([s[4]]) when s[0] & s[1] & s[2] & s[3]; +} diff --git a/air-script/tests/evaluators/evaluators_slice.rs b/air-script/tests/evaluators/evaluators_slice.rs new file mode 100644 index 000000000..4442a5e67 --- /dev/null +++ b/air-script/tests/evaluators/evaluators_slice.rs @@ -0,0 +1,120 @@ +use winter_air::{Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement, ToElements}; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + stack_inputs: [Felt; 16], +} + +impl PublicInputs { + pub fn new(stack_inputs: [Felt; 16]) -> Self { + Self { stack_inputs } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: &mut W) { + self.stack_inputs.write_into(target); + } +} + +impl ToElements for PublicInputs { + fn to_elements(&self) -> Vec { + let mut elements = Vec::new(); + elements.extend_from_slice(&self.stack_inputs); + elements + } +} + +pub struct EvaluatorsSliceAir { + context: AirContext, + 
stack_inputs: [Felt; 16], +} + +impl EvaluatorsSliceAir { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for EvaluatorsSliceAir { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(6)]; + let aux_degrees = vec![]; + let num_main_assertions = 20; + let num_aux_assertions = 0; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, stack_inputs: public_inputs.stack_inputs } + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(0, 0, Felt::ZERO)); + result.push(Assertion::single(1, 0, Felt::ZERO)); + result.push(Assertion::single(2, 0, Felt::ZERO)); + result.push(Assertion::single(3, 0, Felt::ZERO)); + result.push(Assertion::single(4, 0, Felt::ZERO)); + result.push(Assertion::single(5, 0, Felt::ZERO)); + result.push(Assertion::single(6, 0, Felt::ZERO)); + result.push(Assertion::single(7, 0, Felt::ZERO)); + result.push(Assertion::single(8, 0, Felt::ZERO)); + result.push(Assertion::single(9, 0, Felt::ZERO)); + result.push(Assertion::single(10, 0, Felt::ZERO)); + result.push(Assertion::single(11, 0, Felt::ZERO)); + result.push(Assertion::single(12, 0, Felt::ZERO)); + result.push(Assertion::single(13, 0, Felt::ZERO)); + result.push(Assertion::single(14, 0, Felt::ZERO)); + result.push(Assertion::single(15, 0, Felt::ZERO)); + result.push(Assertion::single(16, 0, 
Felt::ZERO)); + result.push(Assertion::single(17, 0, Felt::ZERO)); + result.push(Assertion::single(18, 0, Felt::ZERO)); + result.push(Assertion::single(19, 0, Felt::ZERO)); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxRandElements) -> Vec> { + let mut result = Vec::new(); + result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + result[0] = main_current[0] * main_current[0] - main_current[0]; + result[1] = main_current[0] * (main_current[1] * main_current[1] - main_current[1]); + result[2] = main_current[0] * main_current[1] * (main_current[2] * main_current[2] - main_current[2]); + result[3] = main_current[0] * main_current[1] * main_current[2] * (main_current[3] * main_current[3] - main_current[3]); + result[4] = main_current[0] * main_current[1] * main_current[2] * main_current[3] * (main_current[4] * main_current[4] - main_current[4]); + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + } +} \ No newline at end of file diff --git a/air-script/tests/evaluators/evaluators_slice_slicing.air b/air-script/tests/evaluators/evaluators_slice_slicing.air new file mode 100644 index 000000000..bb1245dcf --- /dev/null +++ b/air-script/tests/evaluators/evaluators_slice_slicing.air @@ -0,0 +1,34 @@ +def EvaluatorsSliceAir + +ev is_binary([x]) { + enf x^2 = x; +} + +trace_columns { + main: [a[20]], +} + +public_inputs { + stack_inputs: [16], +} + +boundary_constraints { + enf a.first = 0 for a in a; +} + +integrity_constraints { + enf chiplets_constraints([a]); +} + +ev 
chiplets_constraints([chiplets[20]]) { + let sub_slice = chiplets[0..5]; + enf chiplet_selectors([sub_slice[0..5]]); # Should expand slice to 5 args +} + +ev chiplet_selectors([s[5]]) { # Expects 5 individual parameters + enf is_binary([s[0]]); + enf is_binary([s[1]]) when s[0]; + enf is_binary([s[2]]) when s[0] & s[1]; + enf is_binary([s[3]]) when s[0] & s[1] & s[2]; + enf is_binary([s[4]]) when s[0] & s[1] & s[2] & s[3]; +} diff --git a/air-script/tests/evaluators/test_air.rs b/air-script/tests/evaluators/test_air.rs index 14a683fe9..6bf42ed83 100644 --- a/air-script/tests/evaluators/test_air.rs +++ b/air-script/tests/evaluators/test_air.rs @@ -1,9 +1,10 @@ use winter_air::Air; -use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{FieldElement, fields::f64::BaseElement as Felt}; use winterfell::{Trace, TraceTable}; use crate::{ - evaluators::evaluators::{EvaluatorsAir, PublicInputs}, + evaluators::evaluators::PublicInputs, + generate_air_test, helpers::{AirTester, MyTraceTable}, }; @@ -22,10 +23,11 @@ impl AirTester for EvaluatorsAirTester { |state| { state[0] = start; state[1] = start; + state[2] = start; state[3] = start; - state[4] = start; - state[5] = start; - state[6] = start; + state[4] = Felt::ZERO; + state[5] = Felt::new(1); + state[6] = Felt::new(4); }, |_, state| {}, ); @@ -39,17 +41,9 @@ impl AirTester for EvaluatorsAirTester { } } -#[test] -fn test_evaluators_air() { - let air_tester = Box::new(EvaluatorsAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = EvaluatorsAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_evaluators_air, + crate::evaluators::evaluators::EvaluatorsAir, + EvaluatorsAirTester, + 
1024 +); diff --git a/air-script/tests/fibonacci/test_air.rs b/air-script/tests/fibonacci/test_air.rs index 4ba4674e2..e71a8dc1f 100644 --- a/air-script/tests/fibonacci/test_air.rs +++ b/air-script/tests/fibonacci/test_air.rs @@ -5,7 +5,8 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{AuxTraceWithMetadata, Trace, TraceTable, matrix::ColMatrix}; use crate::{ - fibonacci::fibonacci::{FibonacciAir, PublicInputs}, + fibonacci::fibonacci::PublicInputs, + generate_air_test, helpers::{AirTester, MyTraceTable}, }; @@ -44,17 +45,9 @@ impl AirTester for FibonacciAirTester { } } -#[test] -fn test_fibonacci_air() { - let air_tester = Box::new(FibonacciAirTester {}); - let length = 32; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = FibonacciAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_fibonacci_air, + crate::fibonacci::fibonacci::FibonacciAir, + FibonacciAirTester, + 32 +); diff --git a/air-script/tests/function_import/function_import.air b/air-script/tests/function_import/function_import.air new file mode 100644 index 000000000..7bf22591e --- /dev/null +++ b/air-script/tests/function_import/function_import.air @@ -0,0 +1,22 @@ +def FunctionImportTest + +use utils::add; +use utils::multiply; + +trace_columns { + main: [a, b, c], +} + +public_inputs { + x: [1], +} + +boundary_constraints { + enf a.first = 0; +} + +integrity_constraints { + # Use imported pure functions + enf a = add(b, c); + enf b = multiply(a, c); +} diff --git a/air-script/tests/function_import/function_import.rs b/air-script/tests/function_import/function_import.rs new file mode 100644 index 000000000..bb6044837 --- /dev/null +++ 
b/air-script/tests/function_import/function_import.rs @@ -0,0 +1,98 @@ +use winter_air::{Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement, ToElements}; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + x: [Felt; 1], +} + +impl PublicInputs { + pub fn new(x: [Felt; 1]) -> Self { + Self { x } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: &mut W) { + self.x.write_into(target); + } +} + +impl ToElements for PublicInputs { + fn to_elements(&self) -> Vec { + let mut elements = Vec::new(); + elements.extend_from_slice(&self.x); + elements + } +} + +pub struct FunctionImportTest { + context: AirContext, + x: [Felt; 1], +} + +impl FunctionImportTest { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for FunctionImportTest { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(2)]; + let aux_degrees = vec![]; + let num_main_assertions = 1; + let num_aux_assertions = 0; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, x: public_inputs.x } + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(0, 0, Felt::ZERO)); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxRandElements) -> Vec> { + let mut result = Vec::new(); 
+ result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + result[0] = main_current[0] - (main_current[1] + main_current[2]); + result[1] = main_current[1] - main_current[0] * main_current[2]; + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + } +} \ No newline at end of file diff --git a/air-script/tests/function_import/mod.rs b/air-script/tests/function_import/mod.rs new file mode 100644 index 000000000..77e95c2a2 --- /dev/null +++ b/air-script/tests/function_import/mod.rs @@ -0,0 +1,3 @@ +#[rustfmt::skip] +#[allow(clippy::all)] +mod function_import; diff --git a/air-script/tests/function_import/utils.air b/air-script/tests/function_import/utils.air new file mode 100644 index 000000000..501983952 --- /dev/null +++ b/air-script/tests/function_import/utils.air @@ -0,0 +1,9 @@ +mod utils + +fn add(a: felt, b: felt) -> felt { + return a + b; +} + +fn multiply(a: felt, b: felt) -> felt { + return a * b; +} diff --git a/air-script/tests/functions/functions_simple.air b/air-script/tests/functions/functions_simple.air index 2188e84bc..b8aff0342 100644 --- a/air-script/tests/functions/functions_simple.air +++ b/air-script/tests/functions/functions_simple.air @@ -47,7 +47,7 @@ boundary_constraints { integrity_constraints { # -------- function call is assigned to a variable and used in a binary expression ------------ - # binary expression invloving scalar expressions + # binary expression involving scalar expressions let simple_expression = t * v; enf simple_expression = 1; diff --git 
a/air-script/tests/functions/functions_simple.rs b/air-script/tests/functions/functions_simple.rs index 6bce04e38..d1ca9052f 100644 --- a/air-script/tests/functions/functions_simple.rs +++ b/air-script/tests/functions/functions_simple.rs @@ -47,7 +47,7 @@ impl Air for FunctionsAir { } fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { - let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(1)]; + let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(5), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; let aux_degrees = vec![]; let num_main_assertions = 1; let num_aux_assertions = 0; @@ -89,8 +89,7 @@ impl Air for FunctionsAir { result[4] = main_current[0] * main_current[4] * main_current[5] * main_current[6] * main_current[7] - E::ONE; result[5] = main_current[1] + (main_current[4] + main_current[5] + main_current[6] + main_current[7]) * main_current[4] * main_current[5] * main_current[6] * main_current[7] - E::ONE; result[6] = main_current[4] + main_current[5] + main_current[6] + main_current[7] - E::ONE; - result[7] = (main_current[4] + main_current[5] + main_current[6] + main_current[7]) * main_current[4] * main_current[5] * main_current[6] * main_current[7] - E::ONE; - result[8] = (main_current[4] + main_current[5] + main_current[6] + main_current[7]) * E::from(Felt::new(4_u64)) - E::ONE; + result[7] = (main_current[4] + main_current[5] + main_current[6] + main_current[7]) * E::from(Felt::new(4_u64)) - E::ONE; } fn 
evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/functions/inlined_functions_simple.air b/air-script/tests/functions/inlined_functions_simple.air index f6df245a8..a4b847270 100644 --- a/air-script/tests/functions/inlined_functions_simple.air +++ b/air-script/tests/functions/inlined_functions_simple.air @@ -17,7 +17,7 @@ boundary_constraints { integrity_constraints { # -------- function call is assigned to a variable and used in a binary expression ------------ - # binary expression invloving scalar expressions + # binary expression involving scalar expressions let simple_expression = t * v; enf simple_expression = 1; diff --git a/air-script/tests/functions/mod.rs b/air-script/tests/functions/mod.rs index 9c6b2ea40..4a85b16cc 100644 --- a/air-script/tests/functions/mod.rs +++ b/air-script/tests/functions/mod.rs @@ -3,8 +3,5 @@ mod functions_complex; #[rustfmt::skip] #[allow(clippy::all)] -mod functions_complex_with_mir; -#[rustfmt::skip] -#[allow(clippy::all)] mod functions_simple; mod test_air; diff --git a/air-script/tests/functions/test_air.rs b/air-script/tests/functions/test_air.rs index 105216c6c..ae578bf0d 100644 --- a/air-script/tests/functions/test_air.rs +++ b/air-script/tests/functions/test_air.rs @@ -3,7 +3,8 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ - functions::functions_complex_with_mir::{FunctionsAir, PublicInputs}, + functions::functions_complex::PublicInputs, + generate_air_test, helpers::{AirTester, MyTraceTable}, }; @@ -58,25 +59,18 @@ impl AirTester for FunctionsAirTester { 16, // blowup factor 0, // grinding factor FieldExtension::None, - 8, // FRI folding factor - 31, // FRI max remainder polynomial degree - BatchingMethod::Linear, // method of batching used in computing constraint composition polynomial + 8, // FRI folding factor + 
31, // FRI max remainder polynomial degree + BatchingMethod::Linear, /* method of batching used in computing constraint + * composition polynomial */ BatchingMethod::Linear, // method of batching used in computing DEEP polynomial ) } } -#[test] -fn test_functions_complex_air() { - let air_tester = Box::new(FunctionsAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = FunctionsAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_functions_complex_air, + crate::functions::functions_complex::FunctionsAir, + FunctionsAirTester, + 1024 +); diff --git a/air-script/tests/helpers/macros.rs b/air-script/tests/helpers/macros.rs new file mode 100644 index 000000000..e134b8d48 --- /dev/null +++ b/air-script/tests/helpers/macros.rs @@ -0,0 +1,29 @@ +// Helper macros for test generation + +/// Generates an AIR test function with the standard boilerplate +/// +/// # Arguments +/// * `test_name` - The identifier for the test function (e.g., `test_binary_air`) +/// * `air_name` - The identifier for the AIR struct (e.g., `BinaryAir`) +/// * `tester_name` - The identifier for the `AirTester` struct (e.g., `BinaryAirTester`) +/// * `trace_length` - The length of the trace for the test (e.g., `32` or `1024`) +#[macro_export] +macro_rules! 
generate_air_test { + ($test_name:ident, $air_name:path, $tester_name:ident, $trace_length:expr) => { + #[test] + fn $test_name() { + use winter_math::fields::f64::BaseElement as Felt; + let air_tester = Box::new($tester_name {}); + let length = $trace_length; + + let main_trace = air_tester.build_main_trace(length); + let aux_trace = air_tester.build_aux_trace(length); + let pub_inputs = air_tester.public_inputs(); + let trace_info = air_tester.build_trace_info(length); + let options = air_tester.build_proof_options(); + + let air = <$air_name>::new(trace_info, pub_inputs, options); + main_trace.validate::<$air_name, Felt>(&air, aux_trace.as_ref()); + } + }; +} diff --git a/air-script/tests/helpers/mod.rs b/air-script/tests/helpers/mod.rs index 8406a3389..dbabc0709 100644 --- a/air-script/tests/helpers/mod.rs +++ b/air-script/tests/helpers/mod.rs @@ -2,8 +2,11 @@ use winter_air::{BatchingMethod, EvaluationFrame, FieldExtension, ProofOptions, use winter_math::fields::f64::BaseElement as Felt; use winterfell::{AuxTraceWithMetadata, Trace, TraceTable, matrix::ColMatrix}; -/// We need to encapsulate the trace table in a struct to manually implement the `aux_trace_width` method of the `Table` trait. -/// Otherwise, using only a TraceTable will return an `aux_trace_width` of 0 even if we provide a non-empty aux trace in `Trace::validate`, +pub mod macros; + +/// We need to encapsulate the trace table in a struct to manually implement the `aux_trace_width` +/// method of the `Table` trait. Otherwise, using only a TraceTable will return an +/// `aux_trace_width` of 0 even if we provide a non-empty aux trace in `Trace::validate`, /// and it fails the tests. 
pub struct MyTraceTable { pub trace: TraceTable, @@ -62,9 +65,10 @@ pub trait AirTester { 8, // blowup factor 0, // grinding factor FieldExtension::None, - 8, // FRI folding factor - 31, // FRI max remainder polynomial degree - BatchingMethod::Linear, // method of batching used in computing constraint composition polynomial + 8, // FRI folding factor + 31, // FRI max remainder polynomial degree + BatchingMethod::Linear, /* method of batching used in computing constraint + * composition polynomial */ BatchingMethod::Linear, // method of batching used in computing DEEP polynomial ) } diff --git a/air-script/tests/indexed_trace_access/test_air.rs b/air-script/tests/indexed_trace_access/test_air.rs index 63a95da4e..707dbba10 100644 --- a/air-script/tests/indexed_trace_access/test_air.rs +++ b/air-script/tests/indexed_trace_access/test_air.rs @@ -3,8 +3,9 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ + generate_air_test, helpers::{AirTester, MyTraceTable}, - indexed_trace_access::indexed_trace_access::{PublicInputs, TraceAccessAir}, + indexed_trace_access::indexed_trace_access::PublicInputs, }; #[derive(Clone)] @@ -39,17 +40,9 @@ impl AirTester for TraceAccessAirTester { } } -#[test] -fn test_indexed_trace_access_air() { - let air_tester = Box::new(TraceAccessAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = TraceAccessAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_indexed_trace_access_air, + crate::indexed_trace_access::indexed_trace_access::TraceAccessAir, + TraceAccessAirTester, + 1024 +); diff --git a/air-script/tests/list_comprehension/list_comprehension.air 
b/air-script/tests/list_comprehension/list_comprehension.air index d35ed4728..d85ed88ff 100644 --- a/air-script/tests/list_comprehension/list_comprehension.air +++ b/air-script/tests/list_comprehension/list_comprehension.air @@ -12,6 +12,10 @@ boundary_constraints { enf c[2].first = 0; } +fn double(a: felt) -> felt { + return 2*a; +} + integrity_constraints { let x = [fmp for fmp in fmp]; enf clk = x[1]; @@ -25,6 +29,11 @@ integrity_constraints { let diff_slice_iterables = [x - y for (x, y) in (c[0..2], d[1..3])]; enf b[1] = clk * diff_slice_iterables[1]; - let m = [w + x - y - z for (w, x, y, z) in (0..3, b, c[0..3], d[0..3])]; + let m = [10 * w + x - y - z for (w, x, y, z) in (0..3, b, c[0..3], d[0..3])]; + enf fmp[0] = m[0] + m[1] + m[2]; + + let alpha = [double(i) for i in 0..3]; + let beta = [double(i) for i in 0..4]; + enf d[2] = alpha[2] + beta[3]; } \ No newline at end of file diff --git a/air-script/tests/list_comprehension/list_comprehension.rs b/air-script/tests/list_comprehension/list_comprehension.rs index d7617e424..1c9cfcf0a 100644 --- a/air-script/tests/list_comprehension/list_comprehension.rs +++ b/air-script/tests/list_comprehension/list_comprehension.rs @@ -47,7 +47,7 @@ impl Air for ListComprehensionAir { } fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { - let main_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(1)]; + let main_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; let aux_degrees = vec![]; let num_main_assertions = 1; let num_aux_assertions = 0; @@ -86,7 +86,8 @@ impl Air for ListComprehensionAir { result[1] = main_current[4] - main_current[0] * 
E::from(Felt::new(8_u64)) * main_current[11]; result[2] = main_current[4] - main_current[0] * (main_next[8] - main_next[12]); result[3] = main_current[6] - main_current[0] * (main_current[9] - main_current[14]); - result[4] = main_current[1] - (E::ZERO + main_current[5] - main_current[8] - main_current[12] + E::ONE + main_current[6] - main_current[9] - main_current[13] + E::from(Felt::new(2_u64)) + main_current[7] - main_current[10] - main_current[14]); + result[4] = main_current[1] - (main_current[5] - main_current[8] - main_current[12] + E::from(Felt::new(10_u64)) + main_current[6] - main_current[9] - main_current[13] + E::from(Felt::new(20_u64)) + main_current[7] - main_current[10] - main_current[14]); + result[5] = main_current[14] - E::from(Felt::new(10_u64)); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/list_comprehension/list_comprehension_nested.air b/air-script/tests/list_comprehension/list_comprehension_nested.air index 4eb73007b..1e0d413b1 100644 --- a/air-script/tests/list_comprehension/list_comprehension_nested.air +++ b/air-script/tests/list_comprehension/list_comprehension_nested.air @@ -21,7 +21,8 @@ boundary_constraints { integrity_constraints { let state = a; let expected = [3, 5, 7]; - enf apply_mds(state) = expected; + let result = apply_mds(state); + enf result = expected for (expected, result) in (expected, result); } fn apply_mds(state: felt[2]) -> felt[3] { diff --git a/air-script/tests/list_comprehension/list_comprehension_nested.rs b/air-script/tests/list_comprehension/list_comprehension_nested.rs index ab9703069..dfdb535f9 100644 --- a/air-script/tests/list_comprehension/list_comprehension_nested.rs +++ b/air-script/tests/list_comprehension/list_comprehension_nested.rs @@ -82,9 +82,9 @@ impl Air for ListComprehensionAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, 
periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = E::ZERO + main_current[0] * E::ONE + main_current[1] * E::from(Felt::new(2_u64)) - E::from(Felt::new(3_u64)); - result[1] = E::ZERO + main_current[0] * E::from(Felt::new(2_u64)) + main_current[1] * E::from(Felt::new(3_u64)) - E::from(Felt::new(5_u64)); - result[2] = E::ZERO + main_current[0] * E::from(Felt::new(3_u64)) + main_current[1] * E::from(Felt::new(4_u64)) - E::from(Felt::new(7_u64)); + result[0] = main_current[0] + main_current[1] * E::from(Felt::new(2_u64)) - E::from(Felt::new(3_u64)); + result[1] = main_current[0] * E::from(Felt::new(2_u64)) + main_current[1] * E::from(Felt::new(3_u64)) - E::from(Felt::new(5_u64)); + result[2] = main_current[0] * E::from(Felt::new(3_u64)) + main_current[1] * E::from(Felt::new(4_u64)) - E::from(Felt::new(7_u64)); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/list_comprehension/mod.rs b/air-script/tests/list_comprehension/mod.rs index f34af1f5f..5f090bea2 100644 --- a/air-script/tests/list_comprehension/mod.rs +++ b/air-script/tests/list_comprehension/mod.rs @@ -1,7 +1,4 @@ #[rustfmt::skip] #[allow(clippy::all)] mod list_comprehension; -#[rustfmt::skip] -#[allow(clippy::all)] -mod list_comprehension_with_mir; mod test_air; diff --git a/air-script/tests/list_comprehension/test_air.rs b/air-script/tests/list_comprehension/test_air.rs index 51e851be8..c3e9b396f 100644 --- a/air-script/tests/list_comprehension/test_air.rs +++ b/air-script/tests/list_comprehension/test_air.rs @@ -3,8 +3,9 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ + generate_air_test, helpers::{AirTester, MyTraceTable}, - list_comprehension::list_comprehension_with_mir::{ListComprehensionAir, PublicInputs}, + 
list_comprehension::list_comprehension::PublicInputs, }; #[derive(Clone)] @@ -21,7 +22,7 @@ impl AirTester for ListComprehensionAirTester { trace.fill( |state| { state[0] = start; - state[1] = Felt::new(3); + state[1] = Felt::new(20); state[2] = start; state[3] = start; state[4] = start; @@ -34,7 +35,7 @@ impl AirTester for ListComprehensionAirTester { state[11] = start; state[12] = start; state[13] = start; - state[14] = start; + state[14] = Felt::new(10); state[15] = start; }, |_, state| { @@ -51,17 +52,9 @@ impl AirTester for ListComprehensionAirTester { } } -#[test] -fn test_list_comprehension_air() { - let air_tester = Box::new(ListComprehensionAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = ListComprehensionAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_list_comprehension_air, + crate::list_comprehension::list_comprehension::ListComprehensionAir, + ListComprehensionAirTester, + 1024 +); diff --git a/air-script/tests/list_folding/mod.rs b/air-script/tests/list_folding/mod.rs index 7d3abea94..5e992bc22 100644 --- a/air-script/tests/list_folding/mod.rs +++ b/air-script/tests/list_folding/mod.rs @@ -1,7 +1,4 @@ #[rustfmt::skip] #[allow(clippy::all)] mod list_folding; -#[rustfmt::skip] -#[allow(clippy::all)] -mod list_folding_with_mir; mod test_air; diff --git a/air-script/tests/list_folding/test_air.rs b/air-script/tests/list_folding/test_air.rs index 7fbd23ba5..9d79fad28 100644 --- a/air-script/tests/list_folding/test_air.rs +++ b/air-script/tests/list_folding/test_air.rs @@ -3,8 +3,9 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ + generate_air_test, 
helpers::{AirTester, MyTraceTable}, - list_folding::list_folding_with_mir::{ListFoldingAir, PublicInputs}, + list_folding::list_folding::PublicInputs, }; #[derive(Clone)] @@ -52,17 +53,9 @@ impl AirTester for ListFoldingAirTester { } } -#[test] -fn test_list_folding_air() { - let air_tester = Box::new(ListFoldingAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = ListFoldingAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_list_folding_air, + crate::list_folding::list_folding::ListFoldingAir, + ListFoldingAirTester, + 1024 +); diff --git a/air-script/tests/mod.rs b/air-script/tests/mod.rs index 4733eeff3..e96e603d8 100644 --- a/air-script/tests/mod.rs +++ b/air-script/tests/mod.rs @@ -9,6 +9,8 @@ mod bitwise; #[allow(unused_variables, dead_code, unused_mut)] mod buses; #[allow(unused_variables, dead_code, unused_mut)] +mod computed_indices; +#[allow(unused_variables, dead_code, unused_mut)] mod constant_in_range; #[allow(unused_variables, dead_code, unused_mut)] mod constants; @@ -19,6 +21,8 @@ mod evaluators; #[allow(unused_variables, dead_code, unused_mut)] mod fibonacci; #[allow(unused_variables, dead_code, unused_mut)] +mod function_import; +#[allow(unused_variables, dead_code, unused_mut)] mod functions; #[allow(unused_variables, dead_code, unused_mut)] mod indexed_trace_access; @@ -38,3 +42,5 @@ mod system; mod trace_col_groups; #[allow(unused_variables, dead_code, unused_mut)] mod variables; + +mod docs_sync; diff --git a/air-script/tests/periodic_columns/periodic_columns.rs b/air-script/tests/periodic_columns/periodic_columns.rs index eb1c592da..db71fbe06 100644 --- a/air-script/tests/periodic_columns/periodic_columns.rs +++ 
b/air-script/tests/periodic_columns/periodic_columns.rs @@ -82,8 +82,8 @@ impl Air for PeriodicColumnsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = periodic_values[0] * (main_current[1] + main_current[2]) - E::ZERO; - result[1] = periodic_values[1] * (main_next[0] - main_current[0]) - E::ZERO; + result[0] = periodic_values[0] * (main_current[1] + main_current[2]); + result[1] = periodic_values[1] * (main_next[0] - main_current[0]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/periodic_columns/test_air.rs b/air-script/tests/periodic_columns/test_air.rs index 3a0b31ed4..f157a98df 100644 --- a/air-script/tests/periodic_columns/test_air.rs +++ b/air-script/tests/periodic_columns/test_air.rs @@ -3,8 +3,9 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ + generate_air_test, helpers::{AirTester, MyTraceTable}, - periodic_columns::periodic_columns::{PeriodicColumnsAir, PublicInputs}, + periodic_columns::periodic_columns::PublicInputs, }; #[derive(Clone)] @@ -36,17 +37,9 @@ impl AirTester for PeriodicColumnsAirTester { } } -#[test] -fn test_periodic_columns_air() { - let air_tester = Box::new(PeriodicColumnsAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = PeriodicColumnsAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_periodic_columns_air, + 
crate::periodic_columns::periodic_columns::PeriodicColumnsAir, + PeriodicColumnsAirTester, + 1024 +); diff --git a/air-script/tests/pub_inputs/test_air.rs b/air-script/tests/pub_inputs/test_air.rs index 14a02b2e4..ad929a857 100644 --- a/air-script/tests/pub_inputs/test_air.rs +++ b/air-script/tests/pub_inputs/test_air.rs @@ -3,8 +3,9 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ + generate_air_test, helpers::{AirTester, MyTraceTable}, - pub_inputs::pub_inputs::{PubInputsAir, PublicInputs}, + pub_inputs::pub_inputs::PublicInputs, }; #[derive(Clone)] @@ -37,17 +38,9 @@ impl AirTester for PubInputsAirTester { } } -#[test] -fn test_pub_inputs_air() { - let air_tester = Box::new(PubInputsAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = PubInputsAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_pub_inputs_air, + crate::pub_inputs::pub_inputs::PubInputsAir, + PubInputsAirTester, + 1024 +); diff --git a/air-script/tests/selectors/mod.rs b/air-script/tests/selectors/mod.rs index 60136267f..9e7f0b7ff 100644 --- a/air-script/tests/selectors/mod.rs +++ b/air-script/tests/selectors/mod.rs @@ -3,11 +3,14 @@ mod selectors; #[rustfmt::skip] #[allow(clippy::all)] -mod selectors_with_evaluators; +mod selectors_combine_simple; +#[rustfmt::skip] +#[allow(clippy::all)] +mod selectors_combine_complex; #[rustfmt::skip] #[allow(clippy::all)] -mod selectors_with_evaluators_with_mir; +mod selectors_with_evaluators; #[rustfmt::skip] #[allow(clippy::all)] -mod selectors_with_mir; +mod selectors_combine_with_list_comprehensions; mod test_air; diff --git a/air-script/tests/selectors/selectors.rs 
b/air-script/tests/selectors/selectors.rs index c7b0c1202..6bbdee036 100644 --- a/air-script/tests/selectors/selectors.rs +++ b/air-script/tests/selectors/selectors.rs @@ -47,7 +47,7 @@ impl Air for SelectorsAir { } fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { - let main_degrees = vec![TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(3)]; + let main_degrees = vec![TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(4)]; let aux_degrees = vec![]; let num_main_assertions = 1; let num_aux_assertions = 0; @@ -82,9 +82,8 @@ impl Air for SelectorsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = (main_next[3] - E::ZERO) * main_current[0] * (E::ONE - main_current[1]); - result[1] = (main_next[3] - main_current[3]) * main_current[0] * main_current[1] * main_current[2]; - result[2] = (main_next[3] - E::ONE) * (E::ONE - main_current[1]) * (E::ONE - main_current[2]); + result[0] = main_current[0] * (E::ONE - main_current[1]) * main_next[3]; + result[1] = main_current[0] * main_current[1] * main_current[2] * (main_next[3] - main_current[3]) + (E::ONE - main_current[1]) * (E::ONE - main_current[2]) * (main_next[3] - E::ONE); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/selectors/selectors_combine_complex.air b/air-script/tests/selectors/selectors_combine_complex.air new file mode 100644 index 000000000..e79ac9dbb --- /dev/null +++ b/air-script/tests/selectors/selectors_combine_complex.air @@ -0,0 +1,79 @@ +def SelectorsAir + +const R = [8, 4, 20, 31, 5, 15]; + +trace_columns { + main: [s[3], a, b, c], +} + +buses { + multiset p, +} + +public_inputs { + stack_inputs: 
[1], +} + +boundary_constraints { + enf c.first = 0; + enf p.first = null; + enf p.last = null; +} + +fn double_with_mult(a: felt) -> felt { + return 2*a; +} + +fn double_with_add(a: felt) -> felt { + return a+a; +} + +# Note: If the constraints `a, b, ..., i` verify the following equivalency: +# a = c = e +# g = i +# Then the match statement: +# enf match { +# case s0: {a, b, c, d}, +# case s1: {e, f, g}, +# case s2: {h, i}, +# }; +# Is expected to get be reduced to only three constraints: +# a * (s0 + s1) + s2 * h = 0 +# g * (s1 + s2) + s0 * b = 0 +# s0 * d + s1 * f = 0 + +ev ev_s0([a, b, c]) { + enf a = double_with_mult(R[0]); #a + enf b = R[1]; #b + enf a = double_with_add(R[0]); #c + enf c = R[2]; #d + + p.insert(1, 2) when c; +} + +ev ev_s1([a, b, c]) { + enf a = 16; #e + enf b = R[3]; #f + enf c = R[4]; #g + + p.insert(1, 2) when c; +} + +ev ev_s2([a, b, c]) { + enf b = R[4]; #h + enf c = R[4]; #i + + p.remove(3, 4) when b; +} + +integrity_constraints { + let s0 = s[0]; + let s1 = !s[0] & s[1]; + let s2 = !s[0] & !s[1]; + + enf match { + case s0: ev_s0([a, b, c]), + case s1: ev_s1([a, b, c]), + case s2: ev_s2([a, b, c]), + }; +} \ No newline at end of file diff --git a/air-script/tests/selectors/selectors_combine_complex.rs b/air-script/tests/selectors/selectors_combine_complex.rs new file mode 100644 index 000000000..a1f1fa791 --- /dev/null +++ b/air-script/tests/selectors/selectors_combine_complex.rs @@ -0,0 +1,102 @@ +use winter_air::{Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement, ToElements}; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + stack_inputs: [Felt; 1], +} + +impl PublicInputs { + pub fn new(stack_inputs: [Felt; 1]) -> Self { + Self { stack_inputs } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: 
&mut W) { + self.stack_inputs.write_into(target); + } +} + +impl ToElements for PublicInputs { + fn to_elements(&self) -> Vec { + let mut elements = Vec::new(); + elements.extend_from_slice(&self.stack_inputs); + elements + } +} + +pub struct SelectorsAir { + context: AirContext, + stack_inputs: [Felt; 1], +} + +impl SelectorsAir { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for SelectorsAir { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(3)]; + let aux_degrees = vec![TransitionConstraintDegree::new(6)]; + let num_main_assertions = 1; + let num_aux_assertions = 2; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, stack_inputs: public_inputs.stack_inputs } + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(5, 0, Felt::ZERO)); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxRandElements) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(0, 0, E::ONE)); + result.push(Assertion::single(0, self.last_step(), E::ONE)); + result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + result[0] = (main_current[0] + (E::ONE - main_current[0]) * main_current[1]) * (main_current[3] - E::from(Felt::new(16_u64))) + (E::ONE - main_current[0]) * (E::ONE - 
main_current[1]) * (main_current[4] - E::from(Felt::new(5_u64))); + result[1] = (E::ONE - main_current[0]) * (main_current[5] - E::from(Felt::new(5_u64))) + main_current[0] * (main_current[4] - E::from(Felt::new(4_u64))); + result[2] = main_current[0] * (main_current[5] - E::from(Felt::new(20_u64))) + (E::ONE - main_current[0]) * main_current[1] * (main_current[4] - E::from(Felt::new(31_u64))); + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + result[0] = ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * E::from(main_current[0]) * E::from(main_current[5]) + E::ONE - E::from(main_current[0]) * E::from(main_current[5])) * ((aux_rand_elements.rand_elements()[0] + aux_rand_elements.rand_elements()[1] + E::from(Felt::new(2_u64)) * aux_rand_elements.rand_elements()[2]) * (E::ONE - E::from(main_current[0])) * E::from(main_current[1]) * E::from(main_current[5]) + E::ONE - (E::ONE - E::from(main_current[0])) * E::from(main_current[1]) * E::from(main_current[5])) * aux_current[0] - ((aux_rand_elements.rand_elements()[0] + E::from(Felt::new(3_u64)) * aux_rand_elements.rand_elements()[1] + E::from(Felt::new(4_u64)) * aux_rand_elements.rand_elements()[2]) * (E::ONE - E::from(main_current[0])) * (E::ONE - E::from(main_current[1])) * E::from(main_current[4]) + E::ONE - (E::ONE - E::from(main_current[0])) * (E::ONE - E::from(main_current[1])) * E::from(main_current[4])) * aux_next[0]; + } +} \ No newline at end of file diff --git a/air-script/tests/selectors/selectors_combine_simple.air 
b/air-script/tests/selectors/selectors_combine_simple.air new file mode 100644 index 000000000..4db36e41b --- /dev/null +++ b/air-script/tests/selectors/selectors_combine_simple.air @@ -0,0 +1,30 @@ +def SelectorsAir + +trace_columns { + main: [s[3], c], +} + +public_inputs { + stack_inputs: [1], +} + +boundary_constraints { + enf c.first = 0; +} + +ev op_add([s[3]]) { + enf s[0]' = s[0] + s[1]; + enf s[1]' = s[2]; +} + +ev op_mul([s[3]]) { + enf s[1]' = s[2]; + enf s[0]' = s[0] * s[1]; +} + +integrity_constraints { + enf match { + case c: op_add([s[0..3]]), + case !c: op_mul([s[0..3]]), + }; +} \ No newline at end of file diff --git a/air-script/tests/selectors/selectors_with_mir.rs b/air-script/tests/selectors/selectors_combine_simple.rs similarity index 84% rename from air-script/tests/selectors/selectors_with_mir.rs rename to air-script/tests/selectors/selectors_combine_simple.rs index 13bab0ff8..0df19ebce 100644 --- a/air-script/tests/selectors/selectors_with_mir.rs +++ b/air-script/tests/selectors/selectors_combine_simple.rs @@ -4,11 +4,11 @@ use winter_math::{ExtensionOf, FieldElement, ToElements}; use winter_utils::{ByteWriter, Serializable}; pub struct PublicInputs { - stack_inputs: [Felt; 16], + stack_inputs: [Felt; 1], } impl PublicInputs { - pub fn new(stack_inputs: [Felt; 16]) -> Self { + pub fn new(stack_inputs: [Felt; 1]) -> Self { Self { stack_inputs } } } @@ -29,7 +29,7 @@ impl ToElements for PublicInputs { pub struct SelectorsAir { context: AirContext, - stack_inputs: [Felt; 16], + stack_inputs: [Felt; 1], } impl SelectorsAir { @@ -47,7 +47,7 @@ impl Air for SelectorsAir { } fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { - let main_degrees = vec![TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(3)]; + let main_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(3)]; let aux_degrees = vec![]; let 
num_main_assertions = 1; let num_aux_assertions = 0; @@ -82,9 +82,8 @@ impl Air for SelectorsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] * (E::ONE - main_current[1]) * (main_next[3] - E::ZERO) - E::ZERO; - result[1] = main_current[0] * main_current[1] * main_current[2] * (main_next[3] - main_current[3]) - E::ZERO; - result[2] = (E::ONE - main_current[1]) * (E::ONE - main_current[2]) * (main_next[3] - E::ONE) - E::ZERO; + result[0] = main_next[1] - main_current[2]; + result[1] = main_current[3] * (main_next[0] - (main_current[0] + main_current[1])) + (E::ONE - main_current[3]) * (main_next[0] - main_current[0] * main_current[1]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/selectors/selectors_combine_with_list_comprehensions.air b/air-script/tests/selectors/selectors_combine_with_list_comprehensions.air new file mode 100644 index 000000000..980d9e5cf --- /dev/null +++ b/air-script/tests/selectors/selectors_combine_with_list_comprehensions.air @@ -0,0 +1,76 @@ +def SelectorsAir + +trace_columns { + main: [s[3], a, b, c], +} + +public_inputs { + stack_inputs: [1], +} + +boundary_constraints { + enf c.first = 0; +} + +fn double_mul(a: felt) -> felt { + return 2*a; +} + +fn double_add(a: felt) -> felt { + return a+a; +} + +fn triple(a: felt) -> felt { + return 3*a; +} + +# Note: If the constraints `a, b, ..., i` verify the following equivalency: +# a = c = e +# g = i +# Then the match statement: +# enf match { +# case s0: {a, b, c, d}, +# case s1: {e, f, g}, +# case s2: {h, i}, +# }; +# Is expected to get be reduced to only three constraints: +# a * (s0 + s1) + s2 * h = 0 +# g * (s1 + s2) + s0 * b = 0 +# s0 * d + s1 * f = 0 + +ev ev_s0([a, b, c]) { + let vec1 
= [double_add(i) for i in 0..5]; + + enf a = vec1[0]; #a + enf b = vec1[1]; #b + enf a = vec1[1] - 2; #c + enf c = vec1[2]; #d +} + +ev ev_s1([a, b, c]) { + let vec2 = [double_mul(i) for i in 0..5]; + + enf a = vec2[2] - 4; #e + enf b = vec2[3]; #f + enf c = vec2[4]; #g +} + +ev ev_s2([a, b, c]) { + let vec_3i = [triple(i) for i in 0..5]; + let vec_i = [i for i in 0..5]; + let vec = [m-n for (m,n) in (vec_3i, vec_i)]; + enf b = vec[4]; #h + enf c = vec[4]; #i +} + +integrity_constraints { + let s0 = s[0]; + let s1 = !s[0] & s[1]; + let s2 = !s[0] & !s[1]; + + enf match { + case s0: ev_s0([a, b, c]), + case s1: ev_s1([a, b, c]), + case s2: ev_s2([a, b, c]), + }; +} \ No newline at end of file diff --git a/air-script/tests/selectors/selectors_with_evaluators_with_mir.rs b/air-script/tests/selectors/selectors_combine_with_list_comprehensions.rs similarity index 78% rename from air-script/tests/selectors/selectors_with_evaluators_with_mir.rs rename to air-script/tests/selectors/selectors_combine_with_list_comprehensions.rs index 52126dfc7..b873a1155 100644 --- a/air-script/tests/selectors/selectors_with_evaluators_with_mir.rs +++ b/air-script/tests/selectors/selectors_combine_with_list_comprehensions.rs @@ -4,11 +4,11 @@ use winter_math::{ExtensionOf, FieldElement, ToElements}; use winter_utils::{ByteWriter, Serializable}; pub struct PublicInputs { - stack_inputs: [Felt; 16], + stack_inputs: [Felt; 1], } impl PublicInputs { - pub fn new(stack_inputs: [Felt; 16]) -> Self { + pub fn new(stack_inputs: [Felt; 1]) -> Self { Self { stack_inputs } } } @@ -29,7 +29,7 @@ impl ToElements for PublicInputs { pub struct SelectorsAir { context: AirContext, - stack_inputs: [Felt; 16], + stack_inputs: [Felt; 1], } impl SelectorsAir { @@ -47,7 +47,7 @@ impl Air for SelectorsAir { } fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { - let main_degrees = vec![TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(4), 
TransitionConstraintDegree::new(3)]; + let main_degrees = vec![TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(3)]; let aux_degrees = vec![]; let num_main_assertions = 1; let num_aux_assertions = 0; @@ -70,7 +70,7 @@ impl Air for SelectorsAir { fn get_assertions(&self) -> Vec> { let mut result = Vec::new(); - result.push(Assertion::single(3, 0, Felt::ZERO)); + result.push(Assertion::single(5, 0, Felt::ZERO)); result } @@ -82,9 +82,9 @@ impl Air for SelectorsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = main_current[0] * (E::ONE - main_current[1]) * (main_next[3] - E::ZERO) - E::ZERO; - result[1] = main_current[1] * main_current[2] * (main_current[0] * (main_next[3] - main_current[3]) - E::ZERO) - E::ZERO; - result[2] = (E::ONE - main_current[1]) * (E::ONE - main_current[2]) * (main_next[3] - E::ONE) - E::ZERO; + result[0] = (main_current[0] + (E::ONE - main_current[0]) * main_current[1]) * main_current[3] + (E::ONE - main_current[0]) * (E::ONE - main_current[1]) * (main_current[4] - E::from(Felt::new(8_u64))); + result[1] = (E::ONE - main_current[0]) * (main_current[5] - E::from(Felt::new(8_u64))) + main_current[0] * (main_current[4] - E::from(Felt::new(2_u64))); + result[2] = main_current[0] * (main_current[5] - E::from(Felt::new(4_u64))) + (E::ONE - main_current[0]) * main_current[1] * (main_current[4] - E::from(Felt::new(6_u64))); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/selectors/selectors_with_evaluators.rs b/air-script/tests/selectors/selectors_with_evaluators.rs index 8af67cbc4..4902ac774 100644 --- a/air-script/tests/selectors/selectors_with_evaluators.rs +++ 
b/air-script/tests/selectors/selectors_with_evaluators.rs @@ -15,7 +15,7 @@ impl PublicInputs { impl Serializable for PublicInputs { fn write_into(&self, target: &mut W) { - target.write(self.stack_inputs); + self.stack_inputs.write_into(target); } } @@ -47,7 +47,7 @@ impl Air for SelectorsAir { } fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { - let main_degrees = vec![TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(4), TransitionConstraintDegree::new(3)]; + let main_degrees = vec![TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(4)]; let aux_degrees = vec![]; let num_main_assertions = 1; let num_aux_assertions = 0; @@ -82,9 +82,8 @@ impl Air for SelectorsAir { fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { let main_current = frame.current(); let main_next = frame.next(); - result[0] = (main_next[3] - E::ZERO) * main_current[0] * (E::ONE - main_current[1]); - result[1] = (main_next[3] - main_current[3]) * main_current[0] * main_current[1] * main_current[2]; - result[2] = (main_next[3] - E::ONE) * (E::ONE - main_current[1]) * (E::ONE - main_current[2]); + result[0] = main_current[0] * (E::ONE - main_current[1]) * main_next[3]; + result[1] = main_current[1] * main_current[2] * main_current[0] * (main_next[3] - main_current[3]) + (E::ONE - main_current[1]) * (E::ONE - main_current[2]) * (main_next[3] - E::ONE); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air-script/tests/selectors/test_air.rs b/air-script/tests/selectors/test_air.rs index be27cf8c1..4f2f3378b 100644 --- a/air-script/tests/selectors/test_air.rs +++ b/air-script/tests/selectors/test_air.rs @@ -3,8 +3,9 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ + 
generate_air_test, helpers::{AirTester, MyTraceTable}, - selectors::selectors_with_evaluators_with_mir::{PublicInputs, SelectorsAir}, + selectors::selectors_with_evaluators::PublicInputs, }; #[derive(Clone)] @@ -39,17 +40,9 @@ impl AirTester for SelectorsAirTester { } } -#[test] -fn test_selectors_with_evaluators_air() { - let air_tester = Box::new(SelectorsAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = SelectorsAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_selectors_with_evaluators_air, + crate::selectors::selectors_with_evaluators::SelectorsAir, + SelectorsAirTester, + 1024 +); diff --git a/air-script/tests/system/test_air.rs b/air-script/tests/system/test_air.rs index f35361b38..3aa6815c6 100644 --- a/air-script/tests/system/test_air.rs +++ b/air-script/tests/system/test_air.rs @@ -3,8 +3,9 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ + generate_air_test, helpers::{AirTester, MyTraceTable}, - system::system::{PublicInputs, SystemAir}, + system::system::PublicInputs, }; #[derive(Clone)] @@ -38,17 +39,4 @@ impl AirTester for SystemAirTester { } } -#[test] -fn test_system_air() { - let air_tester = Box::new(SystemAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = SystemAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!(test_system_air, 
crate::system::system::SystemAir, SystemAirTester, 1024); diff --git a/air-script/tests/trace_col_groups/test_air.rs b/air-script/tests/trace_col_groups/test_air.rs index 7872ba6bb..cb62434e9 100644 --- a/air-script/tests/trace_col_groups/test_air.rs +++ b/air-script/tests/trace_col_groups/test_air.rs @@ -3,8 +3,9 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ + generate_air_test, helpers::{AirTester, MyTraceTable}, - trace_col_groups::trace_col_groups::{PublicInputs, TraceColGroupAir}, + trace_col_groups::trace_col_groups::PublicInputs, }; #[derive(Clone)] @@ -45,17 +46,9 @@ impl AirTester for TraceColGroupAirTester { } } -#[test] -fn test_trace_col_groups_air() { - let air_tester = Box::new(TraceColGroupAirTester {}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = TraceColGroupAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_trace_col_groups_air, + crate::trace_col_groups::trace_col_groups::TraceColGroupAir, + TraceColGroupAirTester, + 1024 +); diff --git a/air-script/tests/variables/test_air.rs b/air-script/tests/variables/test_air.rs index 675eea4a7..c45e18516 100644 --- a/air-script/tests/variables/test_air.rs +++ b/air-script/tests/variables/test_air.rs @@ -3,8 +3,9 @@ use winter_math::fields::f64::BaseElement as Felt; use winterfell::{Trace, TraceTable}; use crate::{ + generate_air_test, helpers::{AirTester, MyTraceTable}, - variables::variables::{PublicInputs, VariablesAir}, + variables::variables::PublicInputs, }; #[derive(Clone)] @@ -39,17 +40,9 @@ impl AirTester for VariablesAirTester { } } -#[test] -fn test_variables_air() { - let air_tester = Box::new(VariablesAirTester 
{}); - let length = 1024; - - let main_trace = air_tester.build_main_trace(length); - let aux_trace = air_tester.build_aux_trace(length); - let pub_inputs = air_tester.public_inputs(); - let trace_info = air_tester.build_trace_info(length); - let options = air_tester.build_proof_options(); - - let air = VariablesAir::new(trace_info, pub_inputs, options); - main_trace.validate::(&air, aux_trace.as_ref()); -} +generate_air_test!( + test_variables_air, + crate::variables::variables::VariablesAir, + VariablesAirTester, + 1024 +); diff --git a/air-script/tests/variables/variables.rs b/air-script/tests/variables/variables.rs index 22127d0b8..9acea8ee2 100644 --- a/air-script/tests/variables/variables.rs +++ b/air-script/tests/variables/variables.rs @@ -88,9 +88,9 @@ impl Air for VariablesAir { let main_current = frame.current(); let main_next = frame.next(); result[0] = main_current[0] * main_current[0] - main_current[0]; - result[1] = periodic_values[0] * (main_next[0] - main_current[0]) - E::ZERO; + result[1] = periodic_values[0] * (main_next[0] - main_current[0]); result[2] = (E::ONE - main_current[0]) * (main_current[3] - main_current[1] - main_current[2]) - (E::from(Felt::new(6_u64)) - (E::from(Felt::new(7_u64)) - main_current[0])); - result[3] = main_current[0] * (main_current[3] - main_current[1] * main_current[2]) - (E::from(Felt::new(4_u64)) - E::from(Felt::new(3_u64)) - main_next[0]); + result[3] = main_current[0] * (main_current[3] - main_current[1] * main_current[2]) - (E::ONE - main_next[0]); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxRandElements, result: &mut [E]) diff --git a/air/Cargo.toml b/air/Cargo.toml index 953743e75..304671d07 100644 --- a/air/Cargo.toml +++ b/air/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "air-ir" -version = "0.4.0" +version = "0.5.0" description = "Intermediate representation for the AirScript language" authors.workspace = true 
readme = "README.md" @@ -12,9 +12,11 @@ rust-version.workspace = true edition.workspace = true [dependencies] -air-parser = { package = "air-parser", path = "../parser", version = "0.4" } -air-pass = { package = "air-pass", path = "../pass", version = "0.4" } +air-parser = { package = "air-parser", path = "../parser", version = "0.5" } +air-pass = { package = "air-pass", path = "../pass", version = "0.5" } anyhow = { workspace = true } miden-diagnostics = { workspace = true } -mir = { package = "air-mir", path = "../mir", version = "0.4" } +mir = { package = "air-mir", path = "../mir", version = "0.5" } thiserror = { workspace = true } +rand = "0.9" +winter-math = { package = "winter-math", version = "0.12", default-features = false } diff --git a/air/README.md b/air/README.md index 60b70d0d6..c5576a6ed 100644 --- a/air/README.md +++ b/air/README.md @@ -6,7 +6,7 @@ The purpose of the `AirIR` is to provide a simple and accurate representation of ## Generating the AirIR -Generate an `AirIR` from either an AirScript AST (the output of the AirScript parser) or a MIR (the Middle Intermediate Representation for AirScript). +Generate an `AirIR` from a MIR (the Middle Intermediate Representation for AirScript). 
Example usage: @@ -14,21 +14,8 @@ Example usage: // parse the source string to a Result containing the AST or an Error let ast = parse(source.as_str()).expect("Parsing failed"); -// Create the compilation pipeline needed to translate the AST to AIR -let pipeline_with_mir = air_parser::transforms::ConstantPropagation::new(&diagnostics) - .chain(mir::passes::AstToMir::new(&diagnostics)) - .chain(mir::passes::Inlining::new(&diagnostics)) - .chain(mir::passes::Unrolling::new(&diagnostics)) - .chain(air_ir::passes::MirToAir::new(&diagnostics)) - .chain(air_ir::passes::BusOpExpand::new(&diagnostics)); - -let pipeline_without_mir = air_parser::transforms::ConstantPropagation::new(&diagnostics) - .chain(air_parser::transforms::Inlining::new(&diagnostics)) - .chain(air_ir::passes::AstToAir::new(&diagnostics)); - -// process the AST to get a Result containing the AIR or a CompileError -let air_from_ast = pipeline_without_mir.run(ast) -let air_from_mir = pipeline_with_mir.run(ast) +// Compile AST into AIR +let air = compile(&diagnostics, ast).expect("compilation failed"); ``` ## AirIR diff --git a/air/src/graph/cse.rs b/air/src/graph/cse.rs new file mode 100644 index 000000000..9a2d1d5c1 --- /dev/null +++ b/air/src/graph/cse.rs @@ -0,0 +1,85 @@ +extern crate alloc; +use alloc::collections::BTreeMap; + +use mir::ir::QuadFelt; + +use crate::graph::{AlgebraicGraph, Node, NodeIndex, Operation, RandomInputs}; + +impl AlgebraicGraph { + /// Evaluates all the nodes in the graph at random points and returns a map of node indices to + /// their evaluations. + fn evaluate_on_random_inputs(&self) -> Vec { + let mut random_inputs = RandomInputs::default(); + for index in 0..self.num_nodes() { + let node_index = NodeIndex(index); + random_inputs.eval(self, &node_index); + } + random_inputs.into_evaluations() + } + + /// Given the evaluations of all nodes in the graph, ordered by their node indices, eliminates + /// common subexpressions by replacing nodes with identical evaluations. 
+ /// + /// In the process, as some nodes will be removed from the graph, node indices need to be + /// remapped to keep consistent values (from `NodeIndex(0)` to `NodeIndex(self.num_nodes())`). + /// This function returns the node indices remapping map. + pub fn eliminate_common_subexpressions(&mut self) -> BTreeMap { + let evals = self.evaluate_on_random_inputs(); + + let mut new_nodes = Vec::new(); + + // 1. Keep track of evaluations, indices rewrites + let mut evals_vec: Vec = Vec::with_capacity(evals.len()); + let mut renumbering_map: BTreeMap = BTreeMap::new(); + + for (index, eval) in evals.iter().enumerate() { + let node_index = NodeIndex(index); + + // 2. For each node, check if its evaluation already exists in the map. + if let Some(existing_index) = evals_vec.iter().position(|e| e == eval) { + // 3. If it does, we will replace the node with the existing one + renumbering_map.insert(node_index, NodeIndex(existing_index)); + } else { + // 4. If it doesn't, rewrite the node indices if needed and add the node to the new + // graph + evals_vec.push(*eval); + + let op = self.node(&node_index).op(); + let new_op = match op { + Operation::Value(_) => { + // Values do not need renumbering, they are leaf nodes + op.clone() + }, + // Note: for Add, Sub, and Mul operations, we assume it's children have already + // been handled. This holds because when building the graph, we always insert + // children before parents, so their indices will always be lower and thus have + // been already processed and added to the `renumbering_map`. 
+ Operation::Add(lhs, rhs) => { + let new_lhs = *renumbering_map.get(lhs).expect("Child of an operation not found in renumbering_map, but we should have already processed it"); + let new_rhs = *renumbering_map.get(rhs).expect("Child of an operation not found in renumbering_map, but we should have already processed it"); + Operation::Add(new_lhs, new_rhs) + }, + Operation::Sub(lhs, rhs) => { + let new_lhs = *renumbering_map.get(lhs).expect("Child of an operation not found in renumbering_map, but we should have already processed it"); + let new_rhs = *renumbering_map.get(rhs).expect("Child of an operation not found in renumbering_map, but we should have already processed it"); + Operation::Sub(new_lhs, new_rhs) + }, + Operation::Mul(lhs, rhs) => { + let new_lhs = *renumbering_map.get(lhs).expect("Child of an operation not found in renumbering_map, but we should have already processed it"); + let new_rhs = *renumbering_map.get(rhs).expect("Child of an operation not found in renumbering_map, but we should have already processed it"); + Operation::Mul(new_lhs, new_rhs) + }, + }; + let new_node = Node { op: new_op }; + let new_index = new_nodes.len(); + new_nodes.push(new_node); + renumbering_map.insert(node_index, NodeIndex(new_index)); + } + } + + // Replace the nodes in the graph with the new nodes + self.nodes = new_nodes; + + renumbering_map + } +} diff --git a/air/src/graph/mod.rs b/air/src/graph/mod.rs index 3a2292b3f..2d77b0ecd 100644 --- a/air/src/graph/mod.rs +++ b/air/src/graph/mod.rs @@ -2,11 +2,13 @@ use std::collections::BTreeMap; use crate::ir::*; +mod cse; + /// A unique identifier for a node in an [AlgebraicGraph] /// /// The raw value of this identifier is an index in the `nodes` vector /// of the [AlgebraicGraph] struct. 
-#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct NodeIndex(usize); impl core::ops::Add for NodeIndex { type Output = NodeIndex; @@ -94,7 +96,43 @@ impl AlgebraicGraph { } } - /// TODO: docs + /// Recursively analyzes a subgraph starting from the specified node and infers the trace + /// segment and constraint domain that the subgraph should be applied to. + /// + /// This function performs a bottom-up traversal of the constraint expression graph to + /// determine: + /// - The **trace segment** that this constraint expression operates on (main trace vs auxiliary + /// trace) + /// - The **constraint domain** that specifies which rows the constraint should be applied to + /// + /// # Arguments + /// + /// * `index` - The index of the node to analyze + /// * `default_domain` - The default constraint domain to use for leaf nodes that don't specify + /// their own domain (e.g., constants, public inputs) + /// + /// # Returns + /// + /// A tuple containing: + /// - `TraceSegmentId`: The trace segment this expression should be applied to (0 for main, 1+ + /// for auxiliary) + /// - `ConstraintDomain`: The domain specifying which rows to apply the constraint to + /// + /// # Inference Rules + /// + /// - **Constants**: Use default segment and provided default domain + /// - **Periodic columns**: Use default segment with `EveryRow` domain (invalid in boundary + /// constraints) + /// - **Public inputs**: Use default segment with provided domain (invalid in integrity + /// constraints) + /// - **Random values**: Use auxiliary segment with provided domain + /// - **Trace access**: Use the access's segment with domain inferred from row offset + /// - **Binary operations**: Use the maximum segment and merged domain from both operands + /// + /// # Errors + /// + /// Returns `ConstraintError::IncompatibleConstraintDomains` if operands of a binary operation + /// have 
incompatible constraint domains that cannot be merged. pub fn node_details( &self, index: &NodeIndex, @@ -103,23 +141,30 @@ impl AlgebraicGraph { // recursively walk the subgraph and infer the trace segment and domain match self.node(index).op() { Operation::Value(value) => match value { - Value::Constant(_) => Ok((DEFAULT_SEGMENT, default_domain)), - Value::RandomValue(_) => Ok((AUX_SEGMENT, default_domain)), + Value::Constant(_) => Ok((TraceSegmentId::Main, default_domain)), + Value::RandomValue(_) => Ok((TraceSegmentId::Aux, default_domain)), Value::PeriodicColumn(_) => { assert!( !default_domain.is_boundary(), "unexpected access to periodic column in boundary constraint" ); // the default domain for [IntegrityConstraints] is `EveryRow` - Ok((DEFAULT_SEGMENT, ConstraintDomain::EveryRow)) - } + Ok((TraceSegmentId::Main, ConstraintDomain::EveryRow)) + }, Value::PublicInput(_) => { assert!( !default_domain.is_integrity(), "unexpected access to public input in integrity constraint" ); - Ok((DEFAULT_SEGMENT, default_domain)) - } + Ok((TraceSegmentId::Main, default_domain)) + }, + Value::PublicInputTable(_) => { + assert!( + !default_domain.is_integrity(), + "unexpected access to public input table in integrity constraint" + ); + Ok((TraceSegmentId::Main, default_domain)) + }, Value::TraceAccess(trace_access) => { let domain = if default_domain.is_boundary() { assert_eq!( @@ -132,7 +177,7 @@ impl AlgebraicGraph { }; Ok((trace_access.segment, domain)) - } + }, }, Operation::Add(lhs, rhs) | Operation::Sub(lhs, rhs) | Operation::Mul(lhs, rhs) => { let (lhs_segment, lhs_domain) = self.node_details(lhs, default_domain)?; @@ -142,7 +187,7 @@ impl AlgebraicGraph { let domain = lhs_domain.merge(rhs_domain)?; Ok((trace_segment, domain)) - } + }, } } @@ -172,28 +217,31 @@ impl AlgebraicGraph { // recursively walk the subgraph and compute the degree from the operation and child nodes match self.node(index).op() { Operation::Value(value) => match value { - Value::Constant(_) | 
Value::PublicInput(_) | Value::RandomValue(_) => 0, + Value::Constant(_) + | Value::PublicInput(_) + | Value::PublicInputTable(_) + | Value::RandomValue(_) => 0, Value::TraceAccess(_) => 1, Value::PeriodicColumn(pc) => { - cycles.insert(pc.name, pc.cycle); + cycles.insert(pc.name.clone(), pc.cycle); 0 - } + }, }, Operation::Add(lhs, rhs) => { let lhs_base = self.accumulate_degree(cycles, lhs); let rhs_base = self.accumulate_degree(cycles, rhs); lhs_base.max(rhs_base) - } + }, Operation::Sub(lhs, rhs) => { let lhs_base = self.accumulate_degree(cycles, lhs); let rhs_base = self.accumulate_degree(cycles, rhs); lhs_base.max(rhs_base) - } + }, Operation::Mul(lhs, rhs) => { let lhs_base = self.accumulate_degree(cycles, lhs); let rhs_base = self.accumulate_degree(cycles, rhs); lhs_base + rhs_base - } + }, } } } diff --git a/air/src/ir/bus.rs b/air/src/ir/bus.rs index 969128ada..90c202be8 100644 --- a/air/src/ir/bus.rs +++ b/air/src/ir/bus.rs @@ -20,7 +20,7 @@ pub struct Bus { } /// Represents the boundaries of a bus, which can be either a public input table or an empty bus. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum BusBoundary { /// A reference to a public input table. PublicInputTable(PublicInputTableAccess), @@ -30,25 +30,21 @@ pub enum BusBoundary { Unconstrained, } -/// Represents an access of a public input table, similar in nature to [TraceAccess]. +/// Represents an access of a public input table. /// /// It can only be bound to a [Bus]'s .first or .last boundary constraints. 
-#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct PublicInputTableAccess { /// The name of the public input to bind pub table_name: Identifier, - /// The name of the bus - pub bus_name: Identifier, /// The number of columns in the public input table pub num_cols: usize, + /// The type of the bus + pub bus_type: BusType, } impl PublicInputTableAccess { - pub const fn new(table_name: Identifier, bus_name: Identifier, num_cols: usize) -> Self { - Self { - table_name, - num_cols, - bus_name, - } + pub const fn new(table_name: Identifier, num_cols: usize, bus_type: BusType) -> Self { + Self { table_name, num_cols, bus_type } } } @@ -65,11 +61,7 @@ pub struct BusOp { impl BusOp { pub fn new(columns: Vec, latch: NodeIndex, op_kind: BusOpKind) -> Self { - Self { - columns, - latch, - op_kind, - } + Self { columns, latch, op_kind } } } @@ -81,12 +73,6 @@ impl Bus { last: BusBoundary, bus_ops: Vec, ) -> Self { - Self { - name, - bus_type, - first, - last, - bus_ops, - } + Self { name, bus_type, first, last, bus_ops } } } diff --git a/air/src/ir/constraints.rs b/air/src/ir/constraints.rs index 90f46de5c..34cac3498 100644 --- a/air/src/ir/constraints.rs +++ b/air/src/ir/constraints.rs @@ -1,8 +1,9 @@ +extern crate alloc; +use alloc::collections::BTreeSet; use core::fmt; -use crate::graph::{AlgebraicGraph, NodeIndex}; - use super::*; +use crate::graph::{AlgebraicGraph, NodeIndex}; #[derive(Debug, thiserror::Error)] pub enum ConstraintError { @@ -10,29 +11,28 @@ pub enum ConstraintError { IncompatibleConstraintDomains(ConstraintDomain, ConstraintDomain), } -/// [Constraints] is the algebraic graph representation of all the constraints -/// in an [AirScript]. 
The graph contains all of the constraints, each of which -/// is a subgraph consisting of all the expressions involved in evaluating the constraint, -/// including constants, references to the trace, public inputs, random values, and -/// periodic columns. +/// [Constraints] is the algebraic graph representation of all the constraints in an AirScript. The +/// graph contains all of the constraints, each of which is a subgraph consisting of all the +/// expressions involved in evaluating the constraint, including constants, references to the trace, +/// public inputs, random values, and periodic columns. /// -/// Internally, this struct also holds a matrix for each constraint type (boundary, -/// integrity), where each row corresponds to a trace segment (in the same order) -/// and contains a vector of [ConstraintRoot] for all of the constraints of that type -/// to be applied to that trace segment. +/// Internally, this struct also holds a matrix for each constraint type (boundary, integrity), +/// where each row corresponds to a trace segment (in the same order) and contains a vector of +/// [ConstraintRoot] for all of the constraints of that type to be applied to that trace segment. /// -/// For example, integrity constraints for the main execution trace, which has a trace segment -/// id of 0, will be specified by the vector of constraint roots found at index 0 of the +/// For example, integrity constraints for the main execution trace, which has a trace segment id of +/// 0, will be specified by the vector of constraint roots found at index 0 of the /// `integrity_constraints` matrix. #[derive(Default, Debug)] pub struct Constraints { - /// Constraint roots for all boundary constraints against the execution trace, by trace segment, - /// where boundary constraints are any constraints that apply to either the first or the last - /// row of the trace. 
- boundary_constraints: Vec>, - /// Constraint roots for all integrity constraints against the execution trace, by trace segment, - /// where integrity constraints are any constraints that apply to every row or every frame. - integrity_constraints: Vec>, + /// Constraint roots for all boundary constraints against the execution trace, by trace + /// segment, where boundary constraints are any constraints that apply to either the first + /// or the last row of the trace. + boundary_constraints: TraceShape>, + /// Constraint roots for all integrity constraints against the execution trace, by trace + /// segment, where integrity constraints are any constraints that apply to every row or + /// every frame. + integrity_constraints: TraceShape>, /// A directed acyclic graph which represents all of the constraints and their subexpressions. graph: AlgebraicGraph, } @@ -40,8 +40,8 @@ impl Constraints { /// Constructs a new [Constraints] graph from the given parts pub const fn new( graph: AlgebraicGraph, - boundary_constraints: Vec>, - integrity_constraints: Vec>, + boundary_constraints: TraceShape>, + integrity_constraints: TraceShape>, ) -> Self { Self { graph, @@ -50,12 +50,41 @@ impl Constraints { } } - /// Returns the number of boundary constraints applied against the specified trace segment. - pub fn num_boundary_constraints(&self, trace_segment: TraceSegmentId) -> usize { - if self.boundary_constraints.len() <= trace_segment { - return 0; + /// Updates the root boundary and integrity constraints to use the new node indices + /// values, given in the `renumbering_map`. + /// + /// This functions also removes duplicate constraints (that share the same root and domain). + /// + /// # Panics + /// Panics if a constraint's node index is not found in the renumbering map. 
+ pub fn renumber_and_deduplicate_constraints( + &mut self, + renumbering_map: &BTreeMap, + ) { + // Iterate over all boundary and integrity constraints + for (_, segment_constraints) in self + .boundary_constraints + .iter_mut() + .chain(self.integrity_constraints.iter_mut()) + { + let mut added_indices = BTreeSet::new(); + segment_constraints.retain_mut(|constraint| { + let new_index = *renumbering_map + .get(constraint.node_index()) + .expect("Error: cannot find constraint index in renumbering map"); + // Don't keep duplicate constraints + if !added_indices.insert((new_index, constraint.domain)) { + return false; + } + // If this constraint is new, we update its node index and keep it + constraint.update_node_index(new_index); + true + }); } + } + /// Returns the number of boundary constraints applied against the specified trace segment. + pub fn num_boundary_constraints(&self, trace_segment: TraceSegmentId) -> usize { self.boundary_constraints[trace_segment].len() } @@ -64,22 +93,15 @@ impl Constraints { /// Each boundary constraint is represented by a [ConstraintRoot] which is /// the root of the subgraph representing the constraint within the [AlgebraicGraph] pub fn boundary_constraints(&self, trace_segment: TraceSegmentId) -> &[ConstraintRoot] { - if self.boundary_constraints.len() <= trace_segment { - return &[]; - } - - &self.boundary_constraints[trace_segment] + self.boundary_constraints[trace_segment].as_slice() } - /// Returns a vector of the degrees of the integrity constraints for the specified trace segment. + /// Returns a vector of the degrees of the integrity constraints for the specified trace + /// segment. 
pub fn integrity_constraint_degrees( &self, trace_segment: TraceSegmentId, ) -> Vec { - if self.integrity_constraints.len() <= trace_segment { - return vec![]; - } - self.integrity_constraints[trace_segment] .iter() .map(|entry_index| self.graph.degree(entry_index.node_index())) @@ -91,11 +113,7 @@ impl Constraints { /// Each integrity constraint is represented by a [ConstraintRoot] which is /// the root of the subgraph representing the constraint within the [AlgebraicGraph] pub fn integrity_constraints(&self, trace_segment: TraceSegmentId) -> &[ConstraintRoot] { - if self.integrity_constraints.len() <= trace_segment { - return &[]; - } - - &self.integrity_constraints[trace_segment] + self.integrity_constraints[trace_segment].as_slice() } /// Inserts a new constraint against `trace_segment`, using the provided `root` and `domain` @@ -107,25 +125,21 @@ impl Constraints { ) { let root = ConstraintRoot::new(root, domain); if domain.is_boundary() { - if self.boundary_constraints.len() <= trace_segment { - self.boundary_constraints.resize(trace_segment + 1, vec![]); - } self.boundary_constraints[trace_segment].push(root); } else { - if self.integrity_constraints.len() <= trace_segment { - self.integrity_constraints.resize(trace_segment + 1, vec![]); - } self.integrity_constraints[trace_segment].push(root); } } - /// Returns the underlying [AlgebraicGraph] representing all constraints and their sub-expressions. + /// Returns the underlying [AlgebraicGraph] representing all constraints and their + /// sub-expressions. #[inline] pub const fn graph(&self) -> &AlgebraicGraph { &self.graph } - /// Returns a mutable reference to the underlying [AlgebraicGraph] representing all constraints and their sub-expressions. + /// Returns a mutable reference to the underlying [AlgebraicGraph] representing all constraints + /// and their sub-expressions. 
#[inline] pub fn graph_mut(&mut self) -> &mut AlgebraicGraph { &mut self.graph @@ -151,6 +165,12 @@ impl ConstraintRoot { &self.index } + /// Updates the node index this constraint refers to. This should be called if the graph is + /// updated after its initial construction, such as during common subexpression elimination. + pub fn update_node_index(&mut self, new_index: NodeIndex) { + self.index = new_index; + } + /// Returns the [ConstraintDomain] for this constraint, which specifies the rows against which /// the constraint should be applied. pub const fn domain(&self) -> ConstraintDomain { @@ -161,7 +181,7 @@ impl ConstraintRoot { /// [ConstraintDomain] corresponds to the domain over which a constraint is applied. /// /// See the docs on each variant for more details. -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ConstraintDomain { /// For boundary constraints which apply to the first row FirstRow, @@ -209,9 +229,11 @@ impl ConstraintDomain { /// that represents the maximum of the two. /// /// For example, if one domain is [ConstraintDomain::EveryFrame(2)] and the other - /// is [ConstraintDomain::EveryFrame(3)], then the result will be [ConstraintDomain::EveryFrame(3)]. + /// is [ConstraintDomain::EveryFrame(3)], then the result will be + /// [ConstraintDomain::EveryFrame(3)]. /// - /// NOTE: Domains for boundary constraints (FirstRow and LastRow) cannot be merged with other domains. + /// NOTE: Domains for boundary constraints (FirstRow and LastRow) cannot be merged with other + /// domains. 
pub fn merge(self, other: Self) -> Result { if self == other { return Ok(other); @@ -241,7 +263,7 @@ impl fmt::Display for ConstraintDomain { Self::EveryRow => write!(f, "every row"), Self::EveryFrame(size) => { write!(f, "every frame of {size} consecutive rows") - } + }, } } } diff --git a/air/src/ir/degree.rs b/air/src/ir/degree.rs index d23e5a554..d6bbccb06 100644 --- a/air/src/ir/degree.rs +++ b/air/src/ir/degree.rs @@ -1,10 +1,10 @@ //! The [IntegrityConstraintDegree] struct and documentation contained in this file is a duplicate //! of the [TransitionConstraintDegree] struct defined in the Winterfell STARK prover library -//! (https://github.com/novifinancial/winterfell), which is licensed under the MIT license. The +//! (), which is licensed under the MIT license. The //! implementation in this file is a subset of the Winterfell code. //! //! The original code is available in the Winterfell library in the `air` crate: -//! https://github.com/novifinancial/winterfell/blob/main/air/src/air/transition/degree.rs +//! use super::MIN_CYCLE_LENGTH; @@ -37,14 +37,8 @@ impl IntegrityConstraintDegree { /// should be set to 2. If a constraint involves multiplication of three trace columns, /// `degree` should be set to 3 etc. 
pub fn new(degree: usize) -> Self { - assert!( - degree > 0, - "integrity constraint degree must be at least one, but was zero" - ); - Self { - base: degree, - cycles: vec![], - } + assert!(degree > 0, "integrity constraint degree must be at least one, but was zero"); + Self { base: degree, cycles: vec![] } } /// Creates a new integrity degree descriptor for constraints which involve multiplication @@ -72,9 +66,6 @@ impl IntegrityConstraintDegree { "cycle length must be a power of two, but was {cycle} for cycle {i}" ); } - Self { - base: base_degree, - cycles, - } + Self { base: base_degree, cycles } } } diff --git a/air/src/ir/mod.rs b/air/src/ir/mod.rs index 50a264c29..0393558ba 100644 --- a/air/src/ir/mod.rs +++ b/air/src/ir/mod.rs @@ -2,16 +2,10 @@ mod bus; mod constraints; mod degree; mod operation; +mod random_inputs; mod trace; mod value; -pub use self::bus::{Bus, BusBoundary, BusOp, BusOpKind, BusType, PublicInputTableAccess}; -pub use self::constraints::{ConstraintDomain, ConstraintError, ConstraintRoot, Constraints}; -pub use self::degree::IntegrityConstraintDegree; -pub use self::operation::Operation; -pub use self::trace::TraceAccess; -pub use self::value::{PeriodicColumnAccess, PublicInputAccess, Value}; - pub use air_parser::{ Symbol, ast::{ @@ -20,16 +14,151 @@ pub use air_parser::{ }, }; -/// The default segment against which a constraint is applied is the main trace segment. -pub const DEFAULT_SEGMENT: TraceSegmentId = 0; -/// The auxiliary trace segment. -pub const AUX_SEGMENT: TraceSegmentId = 1; +pub use self::{ + bus::{Bus, BusBoundary, BusOp, BusOpKind, BusType, PublicInputTableAccess}, + constraints::{ConstraintDomain, ConstraintError, ConstraintRoot, Constraints}, + degree::IntegrityConstraintDegree, + operation::Operation, + random_inputs::RandomInputs, + trace::TraceAccess, + value::{PeriodicColumnAccess, PublicInputAccess, Value}, +}; + +/// A fixed two segment trace shape containing values for the main and aux segments. 
+#[derive(Clone, Debug, PartialEq, Eq, Default)] +pub struct TraceShape { + pub main: T, + pub aux: T, +} + +impl TraceShape { + pub fn new(main: T, aux: T) -> Self { + Self { main, aux } + } + + pub fn map U>(&self, mut f: F) -> TraceShape { + TraceShape { main: f(&self.main), aux: f(&self.aux) } + } + + /// Returns an iterator over mutable references to `(TraceSegmentId, T)` in segment order. + pub fn iter_mut(&mut self) -> impl Iterator { + let (main, aux) = (&mut self.main, &mut self.aux); + [(TraceSegmentId::Main, main), (TraceSegmentId::Aux, aux)].into_iter() + } +} + +impl core::ops::Index for TraceShape { + type Output = T; + fn index(&self, index: TraceSegmentId) -> &Self::Output { + match index { + TraceSegmentId::Main => &self.main, + TraceSegmentId::Aux => &self.aux, + } + } +} + +impl core::ops::IndexMut for TraceShape { + fn index_mut(&mut self, index: TraceSegmentId) -> &mut Self::Output { + match index { + TraceSegmentId::Main => &mut self.main, + TraceSegmentId::Aux => &mut self.aux, + } + } +} + +impl core::ops::Index for TraceShape { + type Output = T; + fn index(&self, index: usize) -> &Self::Output { + match index { + 0 => &self.main, + 1 => &self.aux, + _ => panic!("invalid segment index"), + } + } +} + +impl core::ops::IndexMut for TraceShape { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + match index { + 0 => &mut self.main, + 1 => &mut self.aux, + _ => panic!("invalid segment index"), + } + } +} + +/// A fixed three segment trace shape containing values for the main, aux, and quotient segments. +/// +/// This wraps a two segment `TraceShape` for the witness traces, and adds a separate +/// `quotient` segment which is not addressable via `TraceSegmentId`. 
+#[derive(Clone, Debug, PartialEq, Eq, Default)] +pub struct FullTraceShape { + pub segments: TraceShape, + pub quotient: T, +} + +impl FullTraceShape { + pub fn new(main: T, aux: T, quotient: T) -> Self { + Self { + segments: TraceShape::new(main, aux), + quotient, + } + } + + #[inline] + pub fn segments(&self) -> &TraceShape { + &self.segments + } + + #[inline] + pub fn segments_mut(&mut self) -> &mut TraceShape { + &mut self.segments + } +} + +impl core::ops::Index for FullTraceShape { + type Output = T; + fn index(&self, index: TraceSegmentId) -> &Self::Output { + &self.segments[index] + } +} + +impl core::ops::IndexMut for FullTraceShape { + fn index_mut(&mut self, index: TraceSegmentId) -> &mut Self::Output { + &mut self.segments[index] + } +} + +impl core::ops::Index for FullTraceShape { + type Output = T; + fn index(&self, index: usize) -> &Self::Output { + match index { + 0 => &self.segments.main, + 1 => &self.segments.aux, + 2 => &self.quotient, + _ => panic!("invalid segment index"), + } + } +} + +impl core::ops::IndexMut for FullTraceShape { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + match index { + 0 => &mut self.segments.main, + 1 => &mut self.segments.aux, + 2 => &mut self.quotient, + _ => panic!("invalid segment index"), + } + } +} + /// The offset of the "current" row during constraint evaluation. 
pub const CURRENT_ROW: usize = 0; /// The minimum cycle length of a periodic column pub const MIN_CYCLE_LENGTH: usize = 2; -use std::collections::BTreeMap; +extern crate alloc; +use alloc::collections::BTreeMap; use miden_diagnostics::{SourceSpan, Spanned}; @@ -69,10 +198,7 @@ pub struct Air { } impl Default for Air { fn default() -> Self { - Self::new(Identifier::new( - SourceSpan::UNKNOWN, - Symbol::intern("unnamed"), - )) + Self::new(Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("unnamed"))) } } impl Air { @@ -80,7 +206,7 @@ impl Air { /// /// An empty [Air] is meaningless until it has been populated with /// constraints and associated metadata. This is typically done by converting - /// an [air_parser::ast::Program] to this struct using the [crate::passes::AstToAir] + /// an [air_parser::ast::Program] to this struct using the [crate::passes::MirToAir] /// translation pass. pub fn new(name: Identifier) -> Self { Self { @@ -104,6 +230,29 @@ impl Air { self.public_inputs.values() } + /// Returns a list of all accesses to reduced public input tables in canonical order. + pub fn reduced_public_input_table_accesses(&self) -> Vec { + let mut accesses: Vec<_> = self + .buses + .values() + .flat_map(|bus| { + [bus.first, bus.last] + .iter() + .filter_map(|boundary| { + if let BusBoundary::PublicInputTable(access) = boundary { + Some(*access) + } else { + None + } + }) + .collect::>() + }) + .collect(); + accesses.sort(); + accesses.dedup(); + accesses + } + pub fn periodic_columns(&self) -> impl Iterator + '_ { self.periodic_columns.values() } diff --git a/air/src/ir/operation.rs b/air/src/ir/operation.rs index 0b9ab3be7..765f1f6d0 100644 --- a/air/src/ir/operation.rs +++ b/air/src/ir/operation.rs @@ -1,10 +1,9 @@ -use crate::graph::NodeIndex; - use super::*; +use crate::graph::NodeIndex; /// [Operation] defines the various node types represented /// in the [AlgebraicGraph]. 
-#[derive(Debug, PartialEq, Eq, Copy, Clone, PartialOrd, Ord)] +#[derive(Debug, PartialEq, Eq, Clone, PartialOrd, Ord)] pub enum Operation { /// Evaluates to a [Value] /// @@ -25,9 +24,9 @@ impl Operation { /// precedence are evaluated left-to-right. pub fn precedence(&self) -> usize { match self { - Self::Add(_, _) => 1, - Self::Sub(_, _) => 2, - Self::Mul(_, _) => 3, + Self::Add(..) => 1, + Self::Sub(..) => 2, + Self::Mul(..) => 3, _ => 4, } } diff --git a/air/src/ir/random_inputs.rs b/air/src/ir/random_inputs.rs new file mode 100644 index 000000000..c5de4578e --- /dev/null +++ b/air/src/ir/random_inputs.rs @@ -0,0 +1,122 @@ +extern crate alloc; +use alloc::collections::BTreeMap; + +use air_parser::ast::TraceSegmentId; +use mir::ir::{QuadFelt, const_quad_felt, query_indexed_eval, query_mapped_eval}; +use rand::prelude::*; +use winter_math::fields::f64::BaseElement as Felt; + +use crate::{ + AlgebraicGraph, NodeIndex, Operation, PeriodicColumnAccess, PublicInputAccess, + PublicInputTableAccess, Value, +}; + +/// Holds both: +/// - the random inputs taken by leaf nodes, in order to persist them across different node +/// evaluations. +/// - the evaluations of all the nodes in the graph +#[derive(Debug, Clone, Default)] +pub struct RandomInputs { + rng: ThreadRng, + // A vector to hold the random values taken for the main trace, indexed in the following way: + // $main[0], $main[0]', $main[1], $main[1]', $main[2], ... + main_trace: Vec, + aux_trace: Vec, + rand_values: Vec, + public_inputs: BTreeMap, + periodic_columns: BTreeMap, + public_inputs_tables: BTreeMap, + // A map to hold the the current evaluations of nodes at random points + evals_map: BTreeMap, +} + +impl RandomInputs { + /// Evaluates a given algebraic graph node at random points. 
+ pub fn eval(&mut self, graph: &AlgebraicGraph, node_index: &NodeIndex) -> QuadFelt { + let op = graph.node(node_index).op(); + match op { + Operation::Add(lhs, rhs) => { + let lhs_eval = self.eval(graph, lhs); + let rhs_eval = self.eval(graph, rhs); + let add_eval = lhs_eval + rhs_eval; + self.evals_map.insert(*node_index, add_eval); + add_eval + }, + Operation::Sub(lhs, rhs) => { + let lhs_eval = self.eval(graph, lhs); + let rhs_eval = self.eval(graph, rhs); + let sub_eval = lhs_eval - rhs_eval; + self.evals_map.insert(*node_index, sub_eval); + sub_eval + }, + Operation::Mul(lhs, rhs) => { + let lhs_eval = self.eval(graph, lhs); + let rhs_eval = self.eval(graph, rhs); + let mul_eval = lhs_eval * rhs_eval; + self.evals_map.insert(*node_index, mul_eval); + mul_eval + }, + Operation::Value(value) => match value { + Value::Constant(c) => { + let felt = Felt::new(*c); + let eval = const_quad_felt(felt); + self.evals_map.insert(*node_index, eval); + eval + }, + // For each trace segment, we associate a random value to each trace access, + // indexed in the following way, each column having two + // distinct evaluations to account for the two possible row offsets: + // $main[0], $main[0]', $main[1], $main[1]', $main[2], ... + // Note: if we encounter a trace access corresponding to an index we have not + // yet evaluated, we will randomly generate values for + // this trace access, but also for all previous indices. 
+ Value::TraceAccess(trace_access) => match trace_access.segment { + TraceSegmentId::Main => { + let index = trace_access.column * 2 + trace_access.row_offset; + let eval = query_indexed_eval(&mut self.rng, &mut self.main_trace, index); + self.evals_map.insert(*node_index, eval); + eval + }, + TraceSegmentId::Aux => { + let index = trace_access.column * 2 + trace_access.row_offset; + let eval = query_indexed_eval(&mut self.rng, &mut self.aux_trace, index); + self.evals_map.insert(*node_index, eval); + eval + }, + }, + Value::RandomValue(u) => { + let eval = query_indexed_eval(&mut self.rng, &mut self.rand_values, *u); + self.evals_map.insert(*node_index, eval); + eval + }, + // For PublicInput, PeriodicColumn and PublicInputTable, we use the Hash of the + // element to associate a unique random value or each public input + // and each periodic column access + Value::PublicInput(pi) => { + let eval = query_mapped_eval(&mut self.rng, &mut self.public_inputs, pi); + self.evals_map.insert(*node_index, eval); + eval + }, + Value::PeriodicColumn(pc) => { + let eval = query_mapped_eval(&mut self.rng, &mut self.periodic_columns, pc); + self.evals_map.insert(*node_index, eval); + eval + }, + Value::PublicInputTable(public_input_table_access) => { + let eval = query_mapped_eval( + &mut self.rng, + &mut self.public_inputs_tables, + public_input_table_access, + ); + self.evals_map.insert(*node_index, eval); + eval + }, + }, + } + } + + /// Consumes self and returns all the evaluations, ordered by `NodeIndex`. + pub fn into_evaluations(self) -> Vec { + self.evals_map.into_values().collect() + } +} diff --git a/air/src/ir/trace.rs b/air/src/ir/trace.rs index 8c118f088..43bb534be 100644 --- a/air/src/ir/trace.rs +++ b/air/src/ir/trace.rs @@ -1,6 +1,6 @@ use air_parser::ast::{TraceColumnIndex, TraceSegmentId}; -/// [TraceAccess] is like [SymbolAccess], but is used to describe an access to a specific trace column or columns. 
+/// [TraceAccess] is used to describe an access to a specific trace column or columns. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct TraceAccess { /// The trace segment being accessed @@ -18,19 +18,6 @@ pub struct TraceAccess { impl TraceAccess { /// Creates a new [TraceAccess]. pub const fn new(segment: TraceSegmentId, column: TraceColumnIndex, row_offset: usize) -> Self { - Self { - segment, - column, - row_offset, - } - } - - /// Creates a new [TraceAccess] with a new column index that is updated according to the - /// provided offsets. All other data is left unchanged. - pub fn clone_with_offsets(&self, offsets: &[Vec]) -> Self { - Self { - column: offsets[self.segment][self.column], - ..*self - } + Self { segment, column, row_offset } } } diff --git a/air/src/ir/value.rs b/air/src/ir/value.rs index dbfbee14a..b9ff43789 100644 --- a/air/src/ir/value.rs +++ b/air/src/ir/value.rs @@ -4,7 +4,7 @@ use super::*; /// /// Values are either constant, or evaluated at runtime using the context /// provided to an AirScript program (i.e. random values, public inputs, etc.). -#[derive(Debug, Eq, PartialEq, Copy, Clone, PartialOrd, Ord)] +#[derive(Debug, Eq, PartialEq, Clone, PartialOrd, Ord)] pub enum Value { /// A constant value. Constant(u64), @@ -16,12 +16,14 @@ pub enum Value { PeriodicColumn(PeriodicColumnAccess), /// A reference to a specific element of a given public input PublicInput(PublicInputAccess), + /// A reference to a specific public input table used as a boundary for one of the buses. 
+ PublicInputTable(PublicInputTableAccess), /// A reference to the `random_values` array, specifically the element at the given index RandomValue(usize), } /// Represents an access of a [PeriodicColumn], similar in nature to [TraceAccess] -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct PeriodicColumnAccess { pub name: QualifiedIdentifier, pub cycle: usize, @@ -33,7 +35,7 @@ impl PeriodicColumnAccess { } /// Represents an access of a [PublicInput], similar in nature to [TraceAccess] -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct PublicInputAccess { /// The name of the public input to access pub name: Identifier, diff --git a/air/src/lib.rs b/air/src/lib.rs index 73b99e6c0..94a2ad57c 100644 --- a/air/src/lib.rs +++ b/air/src/lib.rs @@ -5,11 +5,57 @@ pub mod passes; #[cfg(test)] mod tests; -pub use self::codegen::CodeGenerator; -pub use self::graph::{AlgebraicGraph, Node, NodeIndex}; -pub use self::ir::*; +use air_parser::ast::Program; +use air_pass::Pass; +use miden_diagnostics::{Diagnostic, DiagnosticsHandler, ToDiagnostic}; +use mir::ir::Mir; -use miden_diagnostics::{Diagnostic, ToDiagnostic}; +pub use self::{ + codegen::CodeGenerator, + graph::{AlgebraicGraph, Node, NodeIndex}, + ir::*, +}; + +/// Compiles an AirScript program from the the parsed AST to the AIR +pub fn compile(diagnostics: &DiagnosticsHandler, program: Program) -> Result { + let mut pipeline = ast_to_air_pipeline(diagnostics); + pipeline.run(program) +} + +/// Creates a pipeline of passes that transforms an AST into AIR. 
+pub fn ast_to_air_pipeline<'a>( + diagnostics: &DiagnosticsHandler, +) -> impl Pass = Program, Output<'a> = Air, Error = CompileError> { + let ast_passes = air_parser::AstPasses::new(diagnostics); + let mir_passes = mir::MirPasses::new(diagnostics); + let air_ir_passes = crate::AirPasses::new(diagnostics); + + ast_passes.chain(mir_passes).chain(air_ir_passes) +} + +/// Abstracts the various passes done on the AIR representation of the program. +pub struct AirPasses<'a> { + diagnostics: &'a DiagnosticsHandler, +} + +impl<'a> AirPasses<'a> { + pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { + Self { diagnostics } + } +} + +impl Pass for AirPasses<'_> { + type Input<'a> = Mir; + type Output<'a> = Air; + type Error = CompileError; + + fn run<'a>(&mut self, input: Self::Input<'a>) -> Result, Self::Error> { + let mut passes = passes::MirToAir::new(self.diagnostics) + .chain(passes::BusOpExpand::new(self.diagnostics)) + .chain(passes::CommonSubexpressionElimination::new(self.diagnostics)); + passes.run(input) + } +} #[derive(Debug, thiserror::Error)] pub enum CompileError { diff --git a/air/src/passes/common_subexpression_elimination.rs b/air/src/passes/common_subexpression_elimination.rs new file mode 100644 index 000000000..bc162b0da --- /dev/null +++ b/air/src/passes/common_subexpression_elimination.rs @@ -0,0 +1,44 @@ +use air_pass::Pass; +use miden_diagnostics::DiagnosticsHandler; + +use crate::{Air, CompileError}; + +/// This pass aims to remove duplicate nodes in the Algebraic Graph by evaluating +/// each node at random inputs. 
The process relies on: +/// - Iterating over and evaluating all the nodes in order of their NodeIndex +/// - Do not insert nodes that evaluate to the same value as an existing node +/// - Update the indices of the nodes in the graph to reflect the changes +/// +/// Note: This pass requires that boundary constraint should be inserted in the graph before +/// integrity constraints or buses, to keep the nodes consistent with Winterfell codegen's +/// expectation. +pub struct CommonSubexpressionElimination<'a> { + #[allow(unused)] + diagnostics: &'a DiagnosticsHandler, +} + +impl Pass for CommonSubexpressionElimination<'_> { + type Input<'a> = Air; + type Output<'a> = Air; + type Error = CompileError; + + fn run<'a>(&mut self, mut ir: Self::Input<'a>) -> Result, Self::Error> { + // Evaluate the nodes at random points, and eliminate common subexpressions in the graph + // based on the evaluations. This will both: + // - Remove nodes that have the same evaluation, keeping only one instance. + // - Update the indices of the nodes to reflect the changes in the graph. 
+ let renumbering_map = ir.constraint_graph_mut().eliminate_common_subexpressions(); + + // Update constraints with the new node indices + ir.constraints.renumber_and_deduplicate_constraints(&renumbering_map); + + Ok(ir) + } +} + +impl<'a> CommonSubexpressionElimination<'a> { + #[allow(unused)] + pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { + Self { diagnostics } + } +} diff --git a/air/src/passes/expand_buses.rs b/air/src/passes/expand_buses.rs index 2ef9b2011..ff02487dd 100644 --- a/air/src/passes/expand_buses.rs +++ b/air/src/passes/expand_buses.rs @@ -1,11 +1,10 @@ -use air_parser::ast::{Boundary, BusType}; +use air_parser::ast::{Boundary, BusType, TraceSegmentId}; use air_pass::Pass; use miden_diagnostics::DiagnosticsHandler; use mir::ir::BusOpKind; use crate::{ - AUX_SEGMENT, Air, BusBoundary, BusOp, CompileError, ConstraintDomain, NodeIndex, Operation, - TraceAccess, + Air, BusBoundary, BusOp, CompileError, ConstraintDomain, NodeIndex, Operation, TraceAccess, }; pub struct BusOpExpand<'a> { @@ -42,50 +41,45 @@ impl Pass for BusOpExpand<'_> { let bus_ops = bus.bus_ops.clone(); - let bus_trace_access = TraceAccess::new(AUX_SEGMENT, bus_index, 0); - let bus_trace_access_with_offset = TraceAccess::new(AUX_SEGMENT, bus_index, 1); + let bus_trace_access = TraceAccess::new(TraceSegmentId::Aux, bus_index, 0); + let bus_trace_access_with_offset = TraceAccess::new(TraceSegmentId::Aux, bus_index, 1); - let bus_access = - ir.constraint_graph_mut() - .insert_node(Operation::Value(crate::Value::TraceAccess( - bus_trace_access, - ))); + let bus_access = ir + .constraint_graph_mut() + .insert_node(Operation::Value(crate::Value::TraceAccess(bus_trace_access))); let bus_access_with_offset = ir.constraint_graph_mut() .insert_node(Operation::Value(crate::Value::TraceAccess( bus_trace_access_with_offset, ))); - // Then, depending on the bus type, expand the integrity constraint - match bus_type { - BusType::Multiset => { - self.expand_multiset_constraint( - &mut ir, 
- bus_ops, - bus_access, - bus_access_with_offset, - ); - } - BusType::Logup => { - self.expand_logup_constraint( - &mut ir, - bus_ops, - bus_access, - bus_access_with_offset, - ); + // Then, depending on the bus type, expand the integrity constraint if + // the bus is constrained + if !bus_ops.is_empty() { + match bus_type { + BusType::Multiset => { + self.expand_multiset_constraint( + &mut ir, + bus_ops, + bus_access, + bus_access_with_offset, + ); + }, + BusType::Logup => { + self.expand_logup_constraint( + &mut ir, + bus_ops, + bus_access, + bus_access_with_offset, + ); + }, } } } ir.num_random_values = buses .values() - .map(|bus| { - bus.bus_ops - .iter() - .map(|a| a.columns.len() + 1) - .max() - .unwrap_or_default() - }) + .map(|bus| bus.bus_ops.iter().map(|a| a.columns.len() + 1).max().unwrap_or_default()) .max() .unwrap_or_default() as u16; @@ -108,40 +102,43 @@ impl<'a> BusOpExpand<'a> { boundary: Boundary, bus_index: usize, ) { - match bus_boundary { - // Boundaries to PublicInputTable should be handled later during codegen, as we cannot - // know at this point the length of the table, so we cannot generate the resulting constraint - BusBoundary::PublicInputTable(_public_input_table_access) => {} - // Unconstrained boundaries do not require any constraints - BusBoundary::Unconstrained => {} + let value = match bus_boundary { + // Boundaries to PublicInputTable reference a value corresponding to the random + // reduction of a public input table for a given bus type (multiset or logUp) + BusBoundary::PublicInputTable(public_input_table_access) => { + ir.constraint_graph_mut().insert_node(Operation::Value( + crate::Value::PublicInputTable(*public_input_table_access), + )) + }, BusBoundary::Null => { - // The value of the constraint for an empty bus depends on the bus types (1 for multiset, 0 for logup) - let value = match bus_type { + // The value of the constraint for an empty bus depends on the bus types (1 for + // multiset, 0 for logup) + match 
bus_type { BusType::Multiset => ir .constraint_graph_mut() .insert_node(Operation::Value(crate::Value::Constant(1))), BusType::Logup => ir .constraint_graph_mut() .insert_node(Operation::Value(crate::Value::Constant(0))), - }; - - let bus_trace_access = TraceAccess::new(AUX_SEGMENT, bus_index, 0); - let bus_access = ir.constraint_graph_mut().insert_node(Operation::Value( - crate::Value::TraceAccess(bus_trace_access), - )); - - // Then, we enforce for instance the constraint `p.first = 1` or `q.first = 0` to have an empty bus initially - let root = ir - .constraint_graph_mut() - .insert_node(Operation::Sub(bus_access, value)); - let domain = match boundary { - Boundary::First => ConstraintDomain::FirstRow, - Boundary::Last => ConstraintDomain::LastRow, - }; - // Store the generated constraint - ir.constraints.insert_constraint(AUX_SEGMENT, root, domain); - } - } + } + }, + // Unconstrained boundaries do not require any constraints + BusBoundary::Unconstrained => return, + }; + let bus_trace_access = TraceAccess::new(TraceSegmentId::Aux, bus_index, 0); + let bus_access = ir + .constraint_graph_mut() + .insert_node(Operation::Value(crate::Value::TraceAccess(bus_trace_access))); + + // Then, we enforce for instance the constraint `p.first = 0/1` or `q.first = value` to + // have an empty bus initially or equal to the values given in a public input table. 
+ let root = ir.constraint_graph_mut().insert_node(Operation::Sub(bus_access, value)); + let domain = match boundary { + Boundary::First => ConstraintDomain::FirstRow, + Boundary::Last => ConstraintDomain::LastRow, + }; + // Store the generated constraint + ir.constraints.insert_constraint(TraceSegmentId::Aux, root, domain); } /// Helper function to expand the integrity constraint of a multiset bus @@ -168,7 +165,8 @@ impl<'a> BusOpExpand<'a> { // p.remove(c, d) when (1 - s) // => p' * (( A0 + A1 c + A2 d ) ( 1 - s ) + s) = p * ( A0 + A1 a + A2 b ) s + 1 - s - // p' * ( columns removed combined with alphas ) = p * ( columns inserted combined with alphas ) + // p' * ( columns removed combined with alphas ) = p * ( columns inserted combined with + // alphas ) let mut args_combined = graph.insert_node(Operation::Value(crate::Value::RandomValue(0))); @@ -196,7 +194,8 @@ impl<'a> BusOpExpand<'a> { let args_combined_with_latch_and_latch_inverse = graph.insert_node(Operation::Add(args_combined_with_latch, inverse_latch)); - // 4. Multiply them to p_factor or p_prime_factor (depending on bus_op_kind: insert: p, remove: p_prime) + // 4. Multiply them to p_factor or p_prime_factor (depending on bus_op_kind: insert: p, + // remove: p_prime) match bus_op_kind { BusOpKind::Insert => { p_factor = match p_factor { @@ -206,7 +205,7 @@ impl<'a> BusOpExpand<'a> { ))), None => Some(args_combined_with_latch_and_latch_inverse), }; - } + }, BusOpKind::Remove => { p_prime_factor = match p_prime_factor { Some(p_prime_factor) => Some(graph.insert_node(Operation::Mul( @@ -215,11 +214,12 @@ impl<'a> BusOpExpand<'a> { ))), None => Some(args_combined_with_latch_and_latch_inverse), }; - } + }, } } - // 5. Multiply the factors with the bus column (with and without offset for p' and p respectively) + // 5. 
Multiply the factors with the bus column (with and without offset for p' and p + // respectively) let p_prod = match p_factor { Some(p_factor) => graph.insert_node(Operation::Mul(p_factor, bus_access)), None => bus_access, @@ -227,7 +227,7 @@ impl<'a> BusOpExpand<'a> { let p_prime_prod = match p_prime_factor { Some(p_prime_factor) => { graph.insert_node(Operation::Mul(p_prime_factor, bus_access_with_offset)) - } + }, None => bus_access_with_offset, }; @@ -235,7 +235,7 @@ impl<'a> BusOpExpand<'a> { let root = graph.insert_node(Operation::Sub(p_prod, p_prime_prod)); ir.constraints - .insert_constraint(AUX_SEGMENT, root, ConstraintDomain::EveryRow); + .insert_constraint(TraceSegmentId::Aux, root, ConstraintDomain::EveryRow); } /// Helper function to expand the integrity constraint of a logup bus @@ -252,8 +252,9 @@ impl<'a> BusOpExpand<'a> { // q.remove(e, f, g) when s // => q' + s / ( A0 + A1 e + A2 f + A3 g ) = q + d / ( A0 + A1 a + A2 b + A3 c ) - // q' + s / ( columns removed combined with alphas ) = q + d / ( columns inserted combined with alphas ) - // PROD * q' + s * ( columns inserted combined with alphas ) = PROD * q + d * ( columns removed combined with alphas ) + // q' + s / ( columns removed combined with alphas ) = q + d / ( columns inserted combined + // with alphas ) PROD * q' + s * ( columns inserted combined with alphas ) = PROD * + // q + d * ( columns removed combined with alphas ) // 1. Compute all the factors @@ -287,12 +288,13 @@ impl<'a> BusOpExpand<'a> { total_factors = match total_factors { Some(total_factors) => { Some(graph.insert_node(Operation::Mul(total_factors, *factor))) - } + }, None => Some(*factor), }; } - // 3. For each column, compute the product of all factors except the one of the current column, and multiply it with the latch + // 3. 
For each column, compute the product of all factors except the one of the current + // column, and multiply it with the latch let mut terms_added_to_bus = None; let mut terms_removed_from_bus = None; @@ -317,7 +319,7 @@ impl<'a> BusOpExpand<'a> { let factors_without_current_with_latch = match factors_without_current { Some(factors_without_current) => { graph.insert_node(Operation::Mul(factors_without_current, latch)) - } + }, None => latch, }; @@ -331,7 +333,7 @@ impl<'a> BusOpExpand<'a> { ))), None => Some(factors_without_current_with_latch), }; - } + }, BusOpKind::Remove => { terms_removed_from_bus = match terms_removed_from_bus { Some(terms_removed_from_bus) => Some(graph.insert_node(Operation::Add( @@ -340,7 +342,7 @@ impl<'a> BusOpExpand<'a> { ))), None => Some(factors_without_current_with_latch), }; - } + }, } } @@ -352,25 +354,25 @@ impl<'a> BusOpExpand<'a> { let q_prime_prod = match total_factors { Some(total_factors) => { graph.insert_node(Operation::Mul(total_factors, bus_access_with_offset)) - } + }, None => bus_access_with_offset, }; let q_term = match terms_added_to_bus { Some(terms_added_to_bus) => { graph.insert_node(Operation::Add(q_prod, terms_added_to_bus)) - } + }, None => q_prod, }; let q_prime_term = match terms_removed_from_bus { Some(terms_removed_from_bus) => { graph.insert_node(Operation::Add(q_prime_prod, terms_removed_from_bus)) - } + }, None => q_prime_prod, }; // 5. 
Create the resulting constraint let root = graph.insert_node(Operation::Sub(q_term, q_prime_term)); ir.constraints - .insert_constraint(AUX_SEGMENT, root, ConstraintDomain::EveryRow); + .insert_constraint(TraceSegmentId::Aux, root, ConstraintDomain::EveryRow); } } diff --git a/air/src/passes/mod.rs b/air/src/passes/mod.rs index 87e3a7b64..c02ad8b82 100644 --- a/air/src/passes/mod.rs +++ b/air/src/passes/mod.rs @@ -1,7 +1,8 @@ +mod common_subexpression_elimination; mod expand_buses; -mod translate_from_ast; mod translate_from_mir; -pub use self::expand_buses::BusOpExpand; -pub use self::translate_from_ast::AstToAir; -pub use self::translate_from_mir::MirToAir; +pub use self::{ + common_subexpression_elimination::CommonSubexpressionElimination, expand_buses::BusOpExpand, + translate_from_mir::MirToAir, +}; diff --git a/air/src/passes/translate_from_ast.rs b/air/src/passes/translate_from_ast.rs deleted file mode 100644 index 1e56ffcec..000000000 --- a/air/src/passes/translate_from_ast.rs +++ /dev/null @@ -1,650 +0,0 @@ -use air_parser::{LexicalScope, ast}; -use air_pass::Pass; - -use miden_diagnostics::{DiagnosticsHandler, Severity, Span, Spanned}; - -use crate::{CompileError, graph::NodeIndex, ir::*}; - -/// This pass creates the [Air] from the [ast::Program]. -/// -/// It should be deprecated once the compilation pipeline uses the [Mir] construct. 
-/// -pub struct AstToAir<'a> { - diagnostics: &'a DiagnosticsHandler, -} -impl<'a> AstToAir<'a> { - /// Create a new instance of this pass - #[inline] - pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { - Self { diagnostics } - } -} -impl Pass for AstToAir<'_> { - type Input<'a> = ast::Program; - type Output<'a> = Air; - type Error = CompileError; - - fn run<'a>(&mut self, program: Self::Input<'a>) -> Result, Self::Error> { - let mut air = Air::new(program.name); - - let trace_columns = program.trace_columns; - let boundary_constraints = program.boundary_constraints; - let integrity_constraints = program.integrity_constraints; - - air.trace_segment_widths = trace_columns.iter().map(|ts| ts.size as u16).collect(); - air.periodic_columns = program.periodic_columns; - air.public_inputs = program.public_inputs; - - let mut builder = AirBuilder { - diagnostics: self.diagnostics, - air: &mut air, - trace_columns, - bindings: Default::default(), - }; - - for bc in boundary_constraints.iter() { - builder.build_boundary_constraint(bc)?; - } - - for ic in integrity_constraints.iter() { - builder.build_integrity_constraint(ic)?; - } - - Ok(air) - } -} - -#[derive(Debug, Clone)] -enum MemoizedBinding { - /// The binding was reduced to a node in the graph - Scalar(NodeIndex), - /// The binding represents a vector of nodes in the graph - Vector(Vec), - /// The binding represents a matrix of nodes in the graph - Matrix(Vec>), -} - -struct AirBuilder<'a> { - diagnostics: &'a DiagnosticsHandler, - air: &'a mut Air, - trace_columns: Vec, - bindings: LexicalScope, -} -impl AirBuilder<'_> { - fn build_boundary_constraint(&mut self, bc: &ast::Statement) -> Result<(), CompileError> { - match bc { - ast::Statement::Enforce(ast::ScalarExpr::Binary(ast::BinaryExpr { - op: ast::BinaryOp::Eq, - lhs, - rhs, - .. 
- })) => self.build_boundary_equality(lhs, rhs), - ast::Statement::Let(expr) => { - self.build_let(expr, |bldr, stmt| bldr.build_boundary_constraint(stmt)) - } - invalid => { - self.diagnostics - .diagnostic(Severity::Bug) - .with_message("invalid boundary constraint") - .with_primary_label( - invalid.span(), - "expected this to have been reduced to an equality", - ) - .emit(); - Err(CompileError::Failed) - } - } - } - - fn build_integrity_constraint(&mut self, ic: &ast::Statement) -> Result<(), CompileError> { - match ic { - ast::Statement::Enforce(ast::ScalarExpr::Binary(ast::BinaryExpr { - op: ast::BinaryOp::Eq, - lhs, - rhs, - .. - })) => self.build_integrity_equality(lhs, rhs, None), - ast::Statement::EnforceIf( - ast::ScalarExpr::Binary(ast::BinaryExpr { - op: ast::BinaryOp::Eq, - lhs, - rhs, - .. - }), - condition, - ) => self.build_integrity_equality(lhs, rhs, Some(condition)), - ast::Statement::Let(expr) => { - self.build_let(expr, |bldr, stmt| bldr.build_integrity_constraint(stmt)) - } - invalid => { - self.diagnostics - .diagnostic(Severity::Bug) - .with_message("invalid integrity constraint") - .with_primary_label( - invalid.span(), - "expected this to have been reduced to an equality", - ) - .emit(); - Err(CompileError::Failed) - } - } - } - - fn build_let( - &mut self, - expr: &ast::Let, - mut statement_builder: F, - ) -> Result<(), CompileError> - where - F: FnMut(&mut AirBuilder, &ast::Statement) -> Result<(), CompileError>, - { - let bound = self.eval_expr(&expr.value)?; - self.bindings.enter(); - self.bindings.insert(expr.name, bound); - for stmt in expr.body.iter() { - statement_builder(self, stmt)?; - } - self.bindings.exit(); - Ok(()) - } - - fn build_boundary_equality( - &mut self, - lhs: &ast::ScalarExpr, - rhs: &ast::ScalarExpr, - ) -> Result<(), CompileError> { - let lhs_span = lhs.span(); - let rhs_span = rhs.span(); - - // The left-hand side of a boundary constraint equality expression is always a bounded symbol access - // against a 
trace column. It is fine to panic here if that is ever violated. - let ast::ScalarExpr::BoundedSymbolAccess(access) = lhs else { - self.diagnostics - .diagnostic(Severity::Bug) - .with_message("invalid boundary constraint") - .with_primary_label( - lhs_span, - "expected bounded trace column access here, e.g. 'main[0].first'", - ) - .emit(); - return Err(CompileError::Failed); - }; - // Insert the trace access into the graph - let trace_access = self.trace_access(&access.column).unwrap(); - - // Raise a validation error if this column boundary has already been constrained - if let Some(prev) = self.trace_columns[trace_access.segment].mark_constrained( - lhs_span, - trace_access.column, - access.boundary, - ) { - self.diagnostics - .diagnostic(Severity::Error) - .with_message("overlapping boundary constraints") - .with_primary_label( - lhs_span, - "this constrains a column and boundary that has already been constrained", - ) - .with_secondary_label(prev, "previous constraint occurs here") - .emit(); - return Err(CompileError::Failed); - } - - let lhs = self.insert_op(Operation::Value(Value::TraceAccess(trace_access))); - // Insert the right-hand expression into the graph - let rhs = self.insert_scalar_expr(rhs)?; - // Compare the inferred trace segment and domain of the operands - let domain = access.boundary.into(); - { - let graph = self.air.constraint_graph(); - let (lhs_segment, lhs_domain) = graph.node_details(&lhs, domain)?; - let (rhs_segment, rhs_domain) = graph.node_details(&rhs, domain)?; - if lhs_segment < rhs_segment { - // trace segment inference defaults to the lowest segment (the main trace) and is - // adjusted according to the use of random values and trace columns. 
- let lhs_segment_name = self.trace_columns[lhs_segment].name; - let rhs_segment_name = self.trace_columns[rhs_segment].name; - self.diagnostics.diagnostic(Severity::Error) - .with_message("invalid boundary constraint") - .with_primary_label(lhs_span, format!("this constrains a column in the '{lhs_segment_name}' trace segment")) - .with_secondary_label(rhs_span, format!("but this expression implies the '{rhs_segment_name}' trace segment")) - .with_note("Boundary constraints require both sides of the constraint to apply to the same trace segment.") - .emit(); - return Err(CompileError::Failed); - } - if lhs_domain != rhs_domain { - self.diagnostics.diagnostic(Severity::Error) - .with_message("invalid boundary constraint") - .with_primary_label(lhs_span, format!("this has a constraint domain of {lhs_domain}")) - .with_secondary_label(rhs_span, format!("this has a constraint domain of {rhs_domain}")) - .with_note("Boundary constraints require both sides of the constraint to be in the same domain.") - .emit(); - return Err(CompileError::Failed); - } - } - // Merge the expressions into a single constraint - let root = self.merge_equal_exprs(lhs, rhs, None); - // Store the generated constraint - self.air - .constraints - .insert_constraint(trace_access.segment, root, domain); - - Ok(()) - } - - fn build_integrity_equality( - &mut self, - lhs: &ast::ScalarExpr, - rhs: &ast::ScalarExpr, - condition: Option<&ast::ScalarExpr>, - ) -> Result<(), CompileError> { - let lhs = self.insert_scalar_expr(lhs)?; - let rhs = self.insert_scalar_expr(rhs)?; - let condition = match condition { - Some(cond) => Some(self.insert_scalar_expr(cond)?), - None => None, - }; - let root = self.merge_equal_exprs(lhs, rhs, condition); - // Get the trace segment and domain of the constraint. 
- // - // The default domain for integrity constraints is `EveryRow` - let (trace_segment, domain) = self - .air - .constraint_graph() - .node_details(&root, ConstraintDomain::EveryRow)?; - // Save the constraint information - self.air - .constraints - .insert_constraint(trace_segment, root, domain); - - Ok(()) - } - - fn merge_equal_exprs( - &mut self, - lhs: NodeIndex, - rhs: NodeIndex, - selector: Option, - ) -> NodeIndex { - if let Some(selector) = selector { - let constraint = self.insert_op(Operation::Sub(lhs, rhs)); - self.insert_op(Operation::Mul(constraint, selector)) - } else { - self.insert_op(Operation::Sub(lhs, rhs)) - } - } - - fn eval_let_expr(&mut self, expr: &ast::Let) -> Result { - let mut next_let = Some(expr); - let snapshot = self.bindings.clone(); - loop { - let let_expr = next_let.take().expect("invalid empty let body"); - let bound = self.eval_expr(&let_expr.value)?; - self.bindings.enter(); - self.bindings.insert(let_expr.name, bound); - match let_expr.body.last().unwrap() { - ast::Statement::Let(inner_let) => { - next_let = Some(inner_let); - } - ast::Statement::Expr(expr) => { - let value = self.eval_expr(expr); - self.bindings = snapshot; - break value; - } - ast::Statement::Enforce(_) - | ast::Statement::EnforceIf(_, _) - | ast::Statement::EnforceAll(_) - | ast::Statement::BusEnforce(_) => { - unreachable!() - } - } - } - } - - fn eval_expr(&mut self, expr: &ast::Expr) -> Result { - match expr { - ast::Expr::Const(constant) => match &constant.item { - ast::ConstantExpr::Scalar(value) => { - let value = self.insert_constant(*value); - Ok(MemoizedBinding::Scalar(value)) - } - ast::ConstantExpr::Vector(values) => { - let values = self.insert_constants(values.as_slice()); - Ok(MemoizedBinding::Vector(values)) - } - ast::ConstantExpr::Matrix(values) => { - let values = values - .iter() - .map(|vs| self.insert_constants(vs.as_slice())) - .collect(); - Ok(MemoizedBinding::Matrix(values)) - } - }, - ast::Expr::Range(values) => { - let values = 
values - .to_slice_range() - .map(|v| self.insert_constant(v as u64)) - .collect(); - Ok(MemoizedBinding::Vector(values)) - } - ast::Expr::Vector(values) => match values[0].ty().unwrap() { - ast::Type::Felt => { - let mut nodes = vec![]; - for value in values.iter().cloned() { - let value = value.try_into().unwrap(); - nodes.push(self.insert_scalar_expr(&value)?); - } - Ok(MemoizedBinding::Vector(nodes)) - } - ast::Type::Vector(n) => { - let mut nodes = vec![]; - for row in values.iter().cloned() { - match row { - ast::Expr::Const(Span { - item: ast::ConstantExpr::Vector(vs), - .. - }) => { - nodes.push(self.insert_constants(vs.as_slice())); - } - ast::Expr::SymbolAccess(access) => { - let mut cols = vec![]; - for i in 0..n { - let access = ast::ScalarExpr::SymbolAccess( - access.access(AccessType::Index(i)).unwrap(), - ); - let node = self.insert_scalar_expr(&access)?; - cols.push(node); - } - nodes.push(cols); - } - ast::Expr::Vector(elems) => { - let mut cols = vec![]; - for elem in elems.iter().cloned() { - let elem: ast::ScalarExpr = elem.try_into().unwrap(); - let node = self.insert_scalar_expr(&elem)?; - cols.push(node); - } - nodes.push(cols); - } - _ => unreachable!(), - } - } - Ok(MemoizedBinding::Matrix(nodes)) - } - _ => unreachable!(), - }, - ast::Expr::Matrix(values) => { - let mut rows = Vec::with_capacity(values.len()); - for vs in values.iter() { - let mut cols = Vec::with_capacity(vs.len()); - for value in vs { - cols.push(self.insert_scalar_expr(value)?); - } - rows.push(cols); - } - Ok(MemoizedBinding::Matrix(rows)) - } - ast::Expr::Binary(bexpr) => { - let value = self.insert_binary_expr(bexpr)?; - Ok(MemoizedBinding::Scalar(value)) - } - ast::Expr::SymbolAccess(access) => { - match self.bindings.get(access.name.as_ref()) { - None => { - // Must be a reference to a declaration - let value = self.insert_symbol_access(access); - Ok(MemoizedBinding::Scalar(value)) - } - Some(MemoizedBinding::Scalar(node)) => { - assert_eq!(access.access_type, 
AccessType::Default); - Ok(MemoizedBinding::Scalar(*node)) - } - Some(MemoizedBinding::Vector(nodes)) => { - let value = match &access.access_type { - AccessType::Default => MemoizedBinding::Vector(nodes.clone()), - AccessType::Index(idx) => MemoizedBinding::Scalar(nodes[*idx]), - AccessType::Slice(range) => { - MemoizedBinding::Vector(nodes[range.to_slice_range()].to_vec()) - } - AccessType::Matrix(_, _) => unreachable!(), - }; - Ok(value) - } - Some(MemoizedBinding::Matrix(nodes)) => { - let value = match &access.access_type { - AccessType::Default => MemoizedBinding::Matrix(nodes.clone()), - AccessType::Index(idx) => MemoizedBinding::Vector(nodes[*idx].clone()), - AccessType::Slice(range) => { - MemoizedBinding::Matrix(nodes[range.to_slice_range()].to_vec()) - } - AccessType::Matrix(row, col) => { - MemoizedBinding::Scalar(nodes[*row][*col]) - } - }; - Ok(value) - } - } - } - ast::Expr::Let(let_expr) => self.eval_let_expr(let_expr), - // These node types should not exist at this point - ast::Expr::Call(_) | ast::Expr::ListComprehension(_) => unreachable!(), - ast::Expr::BusOperation(_) | ast::Expr::Null(_) | ast::Expr::Unconstrained(_) => { - self.diagnostics - .diagnostic(Severity::Error) - .with_message("buses are not implemented for this Pipeline") - .emit(); - Err(CompileError::Failed) - } - } - } - - fn insert_scalar_expr(&mut self, expr: &ast::ScalarExpr) -> Result { - match expr { - ast::ScalarExpr::Const(value) => { - Ok(self.insert_op(Operation::Value(Value::Constant(value.item)))) - } - ast::ScalarExpr::SymbolAccess(access) => Ok(self.insert_symbol_access(access)), - ast::ScalarExpr::Binary(expr) => self.insert_binary_expr(expr), - ast::ScalarExpr::Let(let_expr) => match self.eval_let_expr(let_expr)? 
{ - MemoizedBinding::Scalar(node) => Ok(node), - invalid => { - panic!("expected scalar expression to produce scalar value, got: {invalid:?}") - } - }, - ast::ScalarExpr::Call(_) - | ast::ScalarExpr::BoundedSymbolAccess(_) - | ast::ScalarExpr::BusOperation(_) - | ast::ScalarExpr::Null(_) - | ast::ScalarExpr::Unconstrained(_) => unreachable!(), - } - } - - // Use square and multiply algorithm to expand the exp into a series of multiplications - fn expand_exp(&mut self, lhs: NodeIndex, rhs: u64) -> NodeIndex { - match rhs { - 0 => self.insert_constant(1), - 1 => lhs, - n if n % 2 == 0 => { - let square = self.insert_op(Operation::Mul(lhs, lhs)); - self.expand_exp(square, n / 2) - } - n => { - let square = self.insert_op(Operation::Mul(lhs, lhs)); - let rec = self.expand_exp(square, (n - 1) / 2); - self.insert_op(Operation::Mul(lhs, rec)) - } - } - } - - fn insert_binary_expr(&mut self, expr: &ast::BinaryExpr) -> Result { - if expr.op == ast::BinaryOp::Exp { - let lhs = self.insert_scalar_expr(expr.lhs.as_ref())?; - let ast::ScalarExpr::Const(rhs) = expr.rhs.as_ref() else { - unreachable!(); - }; - return Ok(self.expand_exp(lhs, rhs.item)); - } - - let lhs = self.insert_scalar_expr(expr.lhs.as_ref())?; - let rhs = self.insert_scalar_expr(expr.rhs.as_ref())?; - Ok(match expr.op { - ast::BinaryOp::Add => self.insert_op(Operation::Add(lhs, rhs)), - ast::BinaryOp::Sub => self.insert_op(Operation::Sub(lhs, rhs)), - ast::BinaryOp::Mul => self.insert_op(Operation::Mul(lhs, rhs)), - _ => unreachable!(), - }) - } - - fn insert_symbol_access(&mut self, access: &ast::SymbolAccess) -> NodeIndex { - use air_parser::ast::ResolvableIdentifier; - match access.name { - // At this point during compilation, fully-qualified identifiers can only possibly refer - // to a periodic column, as all functions have been inlined, and constants propagated. 
- ResolvableIdentifier::Resolved(ref qid) => { - if let Some(pc) = self.air.periodic_columns.get(qid) { - self.insert_op(Operation::Value(Value::PeriodicColumn( - PeriodicColumnAccess::new(*qid, pc.period()), - ))) - } else { - // This is a qualified reference that should have been eliminated - // during inlining or constant propagation, but somehow slipped through. - unreachable!( - "expected reference to periodic column, got `{:?}` instead", - qid - ); - } - } - // This must be one of public inputs, random values, or trace columns - ResolvableIdentifier::Global(id) | ResolvableIdentifier::Local(id) => { - // Special identifiers are those which are `$`-prefixed, and must refer to the names of trace segments (e.g. `$main`) - if id.is_special() { - // Must be a trace segment name - if let Some(ta) = self.trace_access(access) { - return self.insert_op(Operation::Value(Value::TraceAccess(ta))); - } - - // It should never be possible to reach this point - semantic analysis - // would have caught that this identifier is undefined. 
- unreachable!( - "expected reference to random values array or trace segment: {:#?}", - access - ); - } - - // Otherwise, we check the trace bindings and public inputs, in that order - if let Some(trace_access) = self.trace_access(access) { - return self.insert_op(Operation::Value(Value::TraceAccess(trace_access))); - } - - if let Some(public_input) = self.public_input_access(access) { - return self.insert_op(Operation::Value(Value::PublicInput(public_input))); - } - - // If we reach here, this must be a let-bound variable - match self - .bindings - .get(access.name.as_ref()) - .expect("undefined variable") - { - MemoizedBinding::Scalar(node) => { - assert_eq!(access.access_type, AccessType::Default); - *node - } - MemoizedBinding::Vector(nodes) => { - if let AccessType::Index(idx) = &access.access_type { - return nodes[*idx]; - } - unreachable!("impossible vector access: {:?}", access) - } - MemoizedBinding::Matrix(nodes) => { - if let AccessType::Matrix(row, col) = &access.access_type { - return nodes[*row][*col]; - } - unreachable!("impossible matrix access: {:?}", access) - } - } - } - // These should have been eliminated by previous compiler passes - ResolvableIdentifier::Unresolved(_) => { - unreachable!( - "expected fully-qualified or global reference, got `{:?}` instead", - &access.name - ); - } - } - } - - fn public_input_access(&self, access: &ast::SymbolAccess) -> Option { - let public_input = self.air.public_inputs.get(access.name.as_ref())?; - if let AccessType::Index(index) = access.access_type { - Some(PublicInputAccess::new(public_input.name(), index)) - } else { - // This should have been caught earlier during compilation - unreachable!( - "unexpected public input access type encountered during lowering: {:#?}", - access - ) - } - } - - fn trace_access(&self, access: &ast::SymbolAccess) -> Option { - let id = access.name.as_ref(); - for (i, segment) in self.trace_columns.iter().enumerate() { - if segment.name == id { - if let 
AccessType::Index(column) = access.access_type { - return Some(TraceAccess::new(i, column, access.offset)); - } else { - // This should have been caught earlier during compilation - unreachable!( - "unexpected trace access type encountered during lowering: {:#?}", - &access - ); - } - } - - if let Some(binding) = segment - .bindings - .iter() - .find(|tb| tb.name.as_ref() == Some(id)) - { - return match access.access_type { - AccessType::Default if binding.size == 1 => Some(TraceAccess::new( - binding.segment, - binding.offset, - access.offset, - )), - AccessType::Index(extra_offset) if binding.size > 1 => Some(TraceAccess::new( - binding.segment, - binding.offset + extra_offset, - access.offset, - )), - // This should have been caught earlier during compilation - _ => unreachable!( - "unexpected trace access type encountered during lowering: {:#?}", - access - ), - }; - } - } - - None - } - - /// Adds the specified operation to the graph and returns the index of its node. - #[inline] - fn insert_op(&mut self, op: Operation) -> NodeIndex { - self.air.constraint_graph_mut().insert_node(op) - } - - fn insert_constant(&mut self, value: u64) -> NodeIndex { - self.insert_op(Operation::Value(Value::Constant(value))) - } - - fn insert_constants(&mut self, values: &[u64]) -> Vec { - values - .iter() - .copied() - .map(|v| self.insert_constant(v)) - .collect() - } -} diff --git a/air/src/passes/translate_from_mir.rs b/air/src/passes/translate_from_mir.rs index 596b0d0d0..48379fd54 100644 --- a/air/src/passes/translate_from_mir.rs +++ b/air/src/passes/translate_from_mir.rs @@ -5,9 +5,14 @@ use air_parser::{ ast::{self, TraceSegment}, }; use air_pass::Pass; - use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; -use mir::ir::{ConstantValue, Link, Mir, MirValue, Op, Parent, SpannedMirValue}; +use mir::{ + ir::{ + Boundary as MirBoundary, ConstantValue, Link, Mir, MirAccessType, MirValue, Op, Parent, + SpannedMirValue, TraceAccess as MirTraceAccess, 
+ }, + passes::get_inner_const, +}; use crate::{CompileError, graph::NodeIndex, ir::*}; @@ -35,33 +40,56 @@ impl Pass for MirToAir<'_> { let buses = mir.constraint_graph().buses.clone(); - let mut trace_columns = mir.trace_columns.clone(); + assert!( + mir.trace_columns.len() == 1, + "Expected one trace segment, but found: {:?}", + mir.trace_columns + ); + let main_trace_segment = mir.trace_columns.first().unwrap(); + + assert_eq!( + main_trace_segment.id, + TraceSegmentId::Main, + "Expected trace segment to be the main segment, but found: {:?}", + main_trace_segment.id + ); + + // Build trace segments shape: always include main; aux may be empty + let trace_columns_main = main_trace_segment.clone(); let mut bus_bindings_map = BTreeMap::new(); - if !buses.is_empty() { + let trace_columns_aux = if buses.is_empty() { + TraceSegment::new( + SourceSpan::default(), + TraceSegmentId::Aux, + Identifier::new(SourceSpan::default(), Symbol::intern("$aux")), + vec![], + ) + } else { let bus_raw_bindings: Vec<_> = buses .keys() - .map(|k| Span::new(k.span(), (Identifier::new(k.span(), k.name()), AUX_SEGMENT))) + .map(|k| Span::new(k.span(), (Identifier::new(k.span(), k.name()), 1))) .collect(); // Add buses as `aux` trace columns let aux_trace_segment = TraceSegment::new( SourceSpan::default(), - AUX_SEGMENT, - Identifier::new(SourceSpan::default(), Symbol::new(AUX_SEGMENT as u32)), + TraceSegmentId::Aux, + Identifier::new(SourceSpan::default(), Symbol::intern("$aux")), bus_raw_bindings, ); for binding in aux_trace_segment.bindings.iter() { bus_bindings_map.insert(binding.name.unwrap(), binding.offset); } - if trace_columns.len() == 1 { - trace_columns.push(aux_trace_segment); - } else { - panic!("Expected only one trace segment, but found multiple: {trace_columns:?}",); - } - } + aux_trace_segment + }; + + let trace_columns = TraceShape::new(trace_columns_main, trace_columns_aux); - air.trace_segment_widths = trace_columns.iter().map(|ts| ts.size as u16).collect(); + 
air.trace_segment_widths = vec![ + trace_columns[TraceSegmentId::Main].size as u16, + trace_columns[TraceSegmentId::Aux].size as u16, + ]; air.num_random_values = mir.num_random_values; air.periodic_columns = mir.periodic_columns.clone(); air.public_inputs = mir.public_inputs.clone(); @@ -75,10 +103,10 @@ impl Pass for MirToAir<'_> { let graph = mir.constraint_graph(); - for bus in buses.values() { - builder.build_bus(bus)?; - } - + // We insert all the constraints into the AIR graph. + // Note: We need to insert the boundary constraints before the integrity constraints + // as it's a requirement for the CommonSubexpressionElimination pass to work with the + // winterfell codegen. for bc in graph.boundary_constraints_roots.borrow().deref().iter() { builder.build_boundary_constraint(bc)?; } @@ -86,6 +114,15 @@ impl Pass for MirToAir<'_> { for ic in graph.integrity_constraints_roots.borrow().deref().iter() { builder.build_integrity_constraint(ic)?; } + + // Note: In the MIR, buses operations are kept in integrity constraints to + // allow them to be handled in the graph (e.g. inlined via evaluators). This is why + // we need to first visit the integrity constraints, update the corresponding bus + // when encountering a `BusOp`, and then visit the buses to build them. + for bus in buses.values() { + builder.build_bus(bus)?; + } + Ok(air) } } @@ -93,34 +130,63 @@ impl Pass for MirToAir<'_> { struct AirBuilder<'a> { diagnostics: &'a DiagnosticsHandler, air: &'a mut Air, - trace_columns: Vec, + trace_columns: TraceShape, bus_bindings_map: BTreeMap, } /// In case of nested list comprehension, we may not have entirely unrolled outer loops iterators /// so we need to ensure these cases are properly indexed. 
-fn indexed_accessor(mir_node: &Link) -> Link { +fn accessor_to_scalar(mir_node: &Link) -> Link { if let Some(accessor) = mir_node.as_accessor() { - if let AccessType::Index(index) = accessor.access_type { - if let Some(vec) = accessor.indexable.as_vector() { - let children = vec.elements.borrow().deref().clone(); - if index >= children.len() { - panic!( - "Index out of bounds during indexed accessor translation from MIR to AIR: {index}", - ); + match accessor.access_type.clone() { + MirAccessType::Index(index) => { + if let Some(vec) = accessor.indexable.as_vector() { + let children = vec.elements.borrow().deref().clone(); + let index = get_inner_const(&index) + .expect("Index should be a constant value after constant propagation") + as usize; + if index >= children.len() { + panic!( + "Index out of bounds during indexed accessor translation from MIR to AIR: {index}", + ); + } + children[index].clone() + } else { + mir_node.clone() } - children[index].clone() - } else { - mir_node.clone() - } - } else { - mir_node.clone() + }, + MirAccessType::Default => { + add_row_offset_if_trace_access(&accessor.indexable, accessor.offset) + }, + _ => mir_node.clone(), } } else { mir_node.clone() } } +/// Helper function to add a row offset to a TraceAccess value, and return the node unchanged +/// otherwise. 
+fn add_row_offset_if_trace_access(node: &Link, offset: usize) -> Link { + if let Some(value) = node.clone().as_value() { + let mir_value = value.value.value.clone(); + if let MirValue::TraceAccess(trace_access) = mir_value { + mir::ir::Value::create(SpannedMirValue { + span: value.value.span(), + value: MirValue::TraceAccess(mir::ir::TraceAccess { + segment: trace_access.segment, + column: trace_access.column, + row_offset: trace_access.row_offset + offset, + }), + }) + } else { + node.clone() + } + } else { + node.clone() + } +} + /// Helper function to remove the vector wrapper from a scalar operation /// Will panic if the node is a vector of size > 1 (should not happen after unrolling) fn vec_to_scalar(mir_node: &Link) -> Link { @@ -131,8 +197,8 @@ fn vec_to_scalar(mir_node: &Link) -> Link { panic!("Vector of len >1 after unrolling: {mir_node:?}"); } let child = children.first().unwrap(); - let child = indexed_accessor(child); - let child = vec_to_scalar(&child); + let child = vec_to_scalar(child); + let child = accessor_to_scalar(&child); child.clone() } else { mir_node.clone() @@ -156,15 +222,15 @@ impl AirBuilder<'_> { match rhs { 0 => self.insert_op(Operation::Value(Value::Constant(1))), 1 => lhs, - n if n % 2 == 0 => { + n if n.is_multiple_of(2) => { let square = self.insert_op(Operation::Mul(lhs, lhs)); self.expand_exp(square, n / 2) - } + }, n => { let square = self.insert_op(Operation::Mul(lhs, lhs)); let rec = self.expand_exp(square, (n - 1) / 2); self.insert_op(Operation::Mul(lhs, rec)) - } + }, } } @@ -172,8 +238,15 @@ impl AirBuilder<'_> { /// Will panic when encountering an unexpected operation /// (i.e. that is not a binary operation, a value, enf node or an accessor) fn insert_mir_operation(&mut self, mir_node: &Link) -> Result { - let mir_node = indexed_accessor(mir_node); + // First, we need to remove accessors and vector wrappers to get the actual scalar operation + // to insert. 
Notes: + // - at this point, we expect trivial `Accessor` (with either constant index or default + // access type) or `Vector` with size 1. + // - in case of nested list comprehensions, we may need to unwrap two accessors, so we + // unwrap them multiple times. + let mir_node = accessor_to_scalar(mir_node); let mir_node = vec_to_scalar(&mir_node); + let mir_node = accessor_to_scalar(&mir_node); let mir_node_ref = mir_node.borrow(); match mir_node_ref.deref() { Op::Add(add) => { @@ -182,21 +255,21 @@ impl AirBuilder<'_> { let lhs_node_index = self.insert_mir_operation(&lhs)?; let rhs_node_index = self.insert_mir_operation(&rhs)?; Ok(self.insert_op(Operation::Add(lhs_node_index, rhs_node_index))) - } + }, Op::Sub(sub) => { let lhs = sub.lhs.clone(); let rhs = sub.rhs.clone(); let lhs_node_index = self.insert_mir_operation(&lhs)?; let rhs_node_index = self.insert_mir_operation(&rhs)?; Ok(self.insert_op(Operation::Sub(lhs_node_index, rhs_node_index))) - } + }, Op::Mul(mul) => { let lhs = mul.lhs.clone(); let rhs = mul.rhs.clone(); let lhs_node_index = self.insert_mir_operation(&lhs)?; let rhs_node_index = self.insert_mir_operation(&rhs)?; Ok(self.insert_op(Operation::Mul(lhs_node_index, rhs_node_index))) - } + }, Op::Exp(exp) => { let lhs = exp.lhs.clone(); let rhs = exp.rhs.clone(); @@ -236,7 +309,7 @@ impl AirBuilder<'_> { }; Ok(self.expand_exp(lhs_node_index, rhs_value)) - } + }, Op::Value(value) => { let mir_value = &value.value.value; @@ -247,53 +320,50 @@ impl AirBuilder<'_> { } else { unreachable!() } - } + }, MirValue::TraceAccess(trace_access) => { crate::ir::Value::TraceAccess(crate::ir::TraceAccess { segment: trace_access.segment, column: trace_access.column, row_offset: trace_access.row_offset, }) - } + }, MirValue::BusAccess(bus_access) => { let name = bus_access.bus.borrow().deref().name(); let column = self.bus_bindings_map.get(&name).unwrap(); crate::ir::Value::TraceAccess(crate::ir::TraceAccess { - segment: AUX_SEGMENT, + segment: TraceSegmentId::Aux, 
column: *column, row_offset: bus_access.row_offset, }) - } + }, MirValue::PeriodicColumn(periodic_column_access) => { crate::ir::Value::PeriodicColumn(crate::ir::PeriodicColumnAccess { - name: periodic_column_access.name, + name: periodic_column_access.name.clone(), cycle: periodic_column_access.cycle, }) - } + }, MirValue::PublicInput(public_input_access) => { crate::ir::Value::PublicInput(crate::ir::PublicInputAccess { name: public_input_access.name, index: public_input_access.index, }) - } + }, _ => unreachable!("Unexpected MirValue: {:#?}", mir_value), }; Ok(self.insert_op(Operation::Value(value))) - } + }, Op::Enf(enf) => { let child = enf.expr.clone(); self.insert_mir_operation(&child) - } + }, Op::Accessor(accessor) => { let offset = accessor.offset; let child = accessor.indexable.clone(); - let child = indexed_accessor(&child); - - let Some(value) = child.as_value() else { - unreachable!("Expected value in accessor, found: {:?}", child); - }; + let child = accessor_to_scalar(&child); + let value = child.as_value().expect("Expected value in accessor"); let mir_value = &value.value.value; let value = match mir_value { @@ -303,40 +373,40 @@ impl AirBuilder<'_> { } else { unreachable!() } - } + }, MirValue::TraceAccess(trace_access) => { crate::ir::Value::TraceAccess(crate::ir::TraceAccess { segment: trace_access.segment, column: trace_access.column, row_offset: offset, }) - } + }, MirValue::BusAccess(bus_access) => { let name = bus_access.bus.borrow().deref().name(); let column = self.bus_bindings_map.get(&name).unwrap(); crate::ir::Value::TraceAccess(crate::ir::TraceAccess { - segment: AUX_SEGMENT, + segment: TraceSegmentId::Aux, column: *column, row_offset: offset, }) - } + }, MirValue::PeriodicColumn(periodic_column_access) => { crate::ir::Value::PeriodicColumn(crate::ir::PeriodicColumnAccess { - name: periodic_column_access.name, + name: periodic_column_access.name.clone(), cycle: periodic_column_access.cycle, }) - } + }, 
MirValue::PublicInput(public_input_access) => { crate::ir::Value::PublicInput(crate::ir::PublicInputAccess { name: public_input_access.name, index: public_input_access.index, }) - } + }, _ => unreachable!(), }; Ok(self.insert_op(Operation::Value(value))) - } + }, _ => panic!("Should not have Mir op in graph: {mir_node:?}"), } } @@ -349,7 +419,7 @@ impl AirBuilder<'_> { self.build_boundary_constraint(node)?; } Ok(()) - } + }, Op::Matrix(matrix) => { let rows = matrix.elements.borrow().deref().clone(); for row in rows.iter() { @@ -359,101 +429,39 @@ impl AirBuilder<'_> { } } Ok(()) - } + }, Op::Enf(enf) => { let child_op = enf.expr.clone(); - let child_op = indexed_accessor(&child_op); + let child_op = accessor_to_scalar(&child_op); let child_op = vec_to_scalar(&child_op); self.build_boundary_constraint(&child_op)?; Ok(()) - } + }, Op::Sub(sub) => { // Check that lhs is a Bounded trace access let lhs = sub.lhs.clone(); - let lhs = indexed_accessor(&lhs); + let lhs = accessor_to_scalar(&lhs); let lhs = vec_to_scalar(&lhs); let rhs = sub.rhs.clone(); - let rhs = indexed_accessor(&rhs); + let rhs = accessor_to_scalar(&rhs); let rhs = vec_to_scalar(&rhs); let lhs_span = lhs.span(); let rhs_span = rhs.span(); let boundary = lhs.as_boundary().unwrap().clone(); - let expected_trace_access_expr = boundary.expr.clone(); - let Op::Value(value) = expected_trace_access_expr.borrow().deref().clone() else { - unreachable!(); // Raise diag - }; + let trace_access = self.extract_trace_from_boundary(boundary.clone())?; - let (trace_access, _) = match value.value.clone() { - SpannedMirValue { - value: MirValue::TraceAccess(trace_access), - span: lhs_span, - } => (trace_access, lhs_span), - SpannedMirValue { - value: MirValue::TraceAccessBinding(trace_access_binding), - span: lhs_span, - } => { - if trace_access_binding.size != 1 { - self.diagnostics.diagnostic(Severity::Error) - .with_message("invalid boundary constraint") - .with_primary_label(lhs_span, "this has a trace access 
binding with a size greater than 1") - .with_note("Boundary constraints require both sides of the constraint to be single columns.") - .emit(); - return Err(CompileError::Failed); - } - let trace_access = mir::ir::TraceAccess { - segment: trace_access_binding.segment, - column: trace_access_binding.offset, - row_offset: 0, - }; - (trace_access, lhs_span) - } - SpannedMirValue { - value: MirValue::BusAccess(bus_access), - span: lhs_span, - } => { - let bus = bus_access.bus; - let name = bus.borrow().deref().name(); - let column = self.bus_bindings_map.get(&name).unwrap(); - let trace_access = - mir::ir::TraceAccess::new(AUX_SEGMENT, *column, bus_access.row_offset); - (trace_access, lhs_span) - } - _ => unreachable!( - "Expected TraceAccess or BusAccess, received {:?}", - value.value - ), // Raise diag - }; + self.mark_constrained_boundary(trace_access, &boundary)?; - if let Some(prev) = self.trace_columns[trace_access.segment].mark_constrained( - lhs_span, - trace_access.column, - boundary.kind, - ) { - self.diagnostics - .diagnostic(Severity::Error) - .with_message("overlapping boundary constraints") - .with_primary_label( - lhs_span, - "this constrains a column and boundary that has already been constrained", - ) - .with_secondary_label(prev, "previous constraint occurs here") - .emit(); - return Err(CompileError::Failed); - } - - let lhs = self - .air - .constraint_graph_mut() - .insert_node(Operation::Value(crate::ir::Value::TraceAccess( - crate::ir::TraceAccess { - segment: trace_access.segment, - column: trace_access.column, - row_offset: trace_access.row_offset, - }, - ))); + let lhs = self.air.constraint_graph_mut().insert_node(Operation::Value( + crate::ir::Value::TraceAccess(crate::ir::TraceAccess { + segment: trace_access.segment, + column: trace_access.column, + row_offset: trace_access.row_offset, + }), + )); let rhs = self.insert_mir_operation(&rhs)?; // Compare the inferred trace segment and domain of the operands @@ -463,8 +471,9 @@ impl 
AirBuilder<'_> { let (lhs_segment, lhs_domain) = graph.node_details(&lhs, domain)?; let (rhs_segment, rhs_domain) = graph.node_details(&rhs, domain)?; if lhs_segment < rhs_segment { - // trace segment inference defaults to the lowest segment (the main trace) and is - // adjusted according to the use of random values and trace columns. + // trace segment inference defaults to the lowest segment (the main trace) + // and is adjusted according to the use of random + // values and trace columns. let lhs_segment_name = self.trace_columns[lhs_segment].name; let rhs_segment_name = self.trace_columns[rhs_segment].name; self.diagnostics.diagnostic(Severity::Error) @@ -490,11 +499,28 @@ impl AirBuilder<'_> { let root = self.insert_op(Operation::Sub(lhs, rhs)); // Store the generated constraint - self.air - .constraints - .insert_constraint(trace_access.segment, root, domain); + self.air.constraints.insert_constraint(trace_access.segment, root, domain); Ok(()) - } + }, + Op::Boundary(boundary) => { + let trace_access = self.extract_trace_from_boundary(boundary.clone())?; + + self.mark_constrained_boundary(trace_access, boundary)?; + + let root = self.air.constraint_graph_mut().insert_node(Operation::Value( + crate::ir::Value::TraceAccess(crate::ir::TraceAccess { + segment: trace_access.segment, + column: trace_access.column, + row_offset: trace_access.row_offset, + }), + )); + + let domain = boundary.kind.into(); + + // Store the generated constraint + self.air.constraints.insert_constraint(trace_access.segment, root, domain); + Ok(()) + }, _ => unreachable!(), } } @@ -506,7 +532,7 @@ impl AirBuilder<'_> { for node in vec.iter() { self.build_integrity_constraint(node)?; } - } + }, Op::Matrix(matrix) => { let rows = matrix.elements.borrow().deref().clone(); for row in rows.iter() { @@ -515,34 +541,44 @@ impl AirBuilder<'_> { self.build_integrity_constraint(node)?; } } - } + }, Op::Enf(enf) => { let child_op = enf.expr.clone(); - let child_op = indexed_accessor(&child_op); + 
let child_op = accessor_to_scalar(&child_op); let child_op = vec_to_scalar(&child_op); let child_op = enf_to_scalar(&child_op); match child_op.clone().borrow().deref() { Op::Sub(_sub) => { self.build_integrity_constraint(&child_op)?; - } - _ => unreachable!("Enforced with unexpected operation: {:?}", child_op), + }, + Op::BusOp(bus_op) => { + let bus = bus_op.bus.to_link().unwrap(); + let latch = bus_op.latch.clone(); + + bus.borrow_mut().latches.push(latch.clone()); + bus.borrow_mut().columns.push(child_op.clone()); + }, + _ => { + let root = self.insert_mir_operation(&child_op)?; + let (trace_segment, domain) = self + .air + .constraint_graph() + .node_details(&root, ConstraintDomain::EveryRow)?; + self.air.constraints.insert_constraint(trace_segment, root, domain); + }, } - } + }, Op::Sub(sub) => { let lhs = sub.lhs.clone(); let rhs = sub.rhs.clone(); let lhs_node_index = self.insert_mir_operation(&lhs)?; let rhs_node_index = self.insert_mir_operation(&rhs)?; let root = self.insert_op(Operation::Sub(lhs_node_index, rhs_node_index)); - let (trace_segment, domain) = self - .air - .constraint_graph() - .node_details(&root, ConstraintDomain::EveryRow)?; - self.air - .constraints - .insert_constraint(trace_segment, root, domain); - } - _ => unreachable!(), + let (trace_segment, domain) = + self.air.constraint_graph().node_details(&root, ConstraintDomain::EveryRow)?; + self.air.constraints.insert_constraint(trace_segment, root, domain); + }, + _ => unreachable!("Unexpected integrity constraint root: {:?}", ic), } Ok(()) } @@ -559,9 +595,7 @@ impl AirBuilder<'_> { let mut column = vec![]; // Note: we have checked this will not panic in the MIR pass - let mir_bus_op = mir_column - .as_bus_op() - .expect("Bus column should be a bus operation"); + let mir_bus_op = mir_column.as_bus_op().expect("Bus column should be a bus operation"); let mir_bus_op_args = mir_bus_op.args.clone(); for arg in mir_bus_op_args.iter() { let arg = self.insert_mir_operation(arg)?; @@ -584,6 
+618,88 @@ impl AirBuilder<'_> { fn insert_op(&mut self, op: Operation) -> NodeIndex { self.air.constraint_graph_mut().insert_node(op) } + + /// Extracts the trace access information from a given [Mir] `Boundary`. + /// Returns a [Mir] `TraceAccess` with the corresponding segment id and column if the boundary + /// wraps a valid trace access column, or raises a diagnostic if the trace access has a size + /// greater than 1. + /// + /// Note: the boundary expression must only reference the constrained trace access, not the + /// whole boundary constraint expression. + /// + /// # Panics + /// Panics if the boundary does not wrap a trace access column, which should have been caught + /// during semantic analysis. + fn extract_trace_from_boundary( + &self, + boundary: MirBoundary, + ) -> Result { + let Op::Value(value) = boundary.expr.borrow().deref().clone() else { + unreachable!(); // Raise diag + }; + + let trace_access = match value.value.clone() { + SpannedMirValue { + value: MirValue::TraceAccess(trace_access), + .. + } => trace_access, + SpannedMirValue { + value: MirValue::TraceAccessBinding(trace_access_binding), + span, + } => { + if trace_access_binding.size != 1 { + self.diagnostics.diagnostic(Severity::Error) + .with_message("invalid boundary constraint") + .with_primary_label(span, "this has a trace access binding with a size greater than 1") + .with_note("Boundary constraints require both sides of the constraint to be single columns.") + .emit(); + return Err(CompileError::Failed); + } + MirTraceAccess { + segment: trace_access_binding.segment, + column: trace_access_binding.offset, + row_offset: 0, + } + }, + SpannedMirValue { + value: MirValue::BusAccess(bus_access), .. 
+ } => { + let bus = bus_access.bus; + let name = bus.borrow().deref().name(); + let column = self.bus_bindings_map.get(&name).unwrap(); + MirTraceAccess::new(TraceSegmentId::Aux, *column, bus_access.row_offset) + }, + _ => unreachable!("Expected TraceAccess or BusAccess, received {:?}", value.value), /* Raise diag */ + }; + + Ok(trace_access) + } + + /// Marks a boundary as constrained by the given trace access information. + /// This is used to ensure that we do not insert duplicate boundary constraints in the graph. + fn mark_constrained_boundary( + &mut self, + trace_access: MirTraceAccess, + boundary: &MirBoundary, + ) -> Result<(), CompileError> { + if let Some(prev) = self.trace_columns[trace_access.segment].mark_constrained( + boundary.span(), + trace_access.column, + boundary.kind, + ) { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("overlapping boundary constraints") + .with_primary_label( + boundary.span(), + "this constrains a column and boundary that has already been constrained", + ) + .with_secondary_label(prev, "previous constraint occurs here") + .emit(); + return Err(CompileError::Failed); + } + Ok(()) + } } // HELPERS FUNCTIONS @@ -603,8 +719,8 @@ fn build_bus_boundary( MirValue::PublicInputTable(public_input_table) => Ok( crate::ir::BusBoundary::PublicInputTable(crate::ir::PublicInputTableAccess::new( public_input_table.table_name, - public_input_table.bus_name(), public_input_table.num_cols, + public_input_table.bus_type(), )), ), // This represents an empty bus @@ -622,7 +738,7 @@ fn build_bus_boundary( ) .emit(); Err(CompileError::Failed) - } + }, _ => unreachable!("Unexpected Mir Op in bus boundary: {:#?}", mir_node_ref), } } diff --git a/air/src/tests/access.rs b/air/src/tests/access.rs index 698f9797d..9221d5500 100644 --- a/air/src/tests/access.rs +++ b/air/src/tests/access.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, expect_diagnostic}; +use super::expect_diagnostic; #[test] fn 
invalid_vector_access_in_boundary_constraint() { @@ -21,16 +21,7 @@ fn invalid_vector_access_in_boundary_constraint() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -54,16 +45,7 @@ fn invalid_matrix_row_access_in_boundary_constraint() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -87,16 +69,7 @@ fn invalid_matrix_column_access_in_boundary_constraint() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -120,16 +93,7 @@ fn invalid_vector_access_in_integrity_constraint() { enf clk' = clk + A + B[3] - C[1][2]; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -153,16 +117,7 @@ fn invalid_matrix_row_access_in_integrity_constraint() { enf clk' = clk + A + B[1] - C[3][2]; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - 
expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -186,14 +141,5 @@ fn invalid_matrix_column_access_in_integrity_constraint() { enf clk' = clk + A + B[1] - C[1][3]; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } diff --git a/air/src/tests/boundary_constraints.rs b/air/src/tests/boundary_constraints.rs index 47e7707d5..e25a1c6eb 100644 --- a/air/src/tests/boundary_constraints.rs +++ b/air/src/tests/boundary_constraints.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, compile, expect_diagnostic}; +use super::{compile_from_source, expect_diagnostic}; #[test] fn boundary_constraints() { @@ -18,8 +18,7 @@ fn boundary_constraints() { enf clk' = clk + 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -40,16 +39,7 @@ fn err_bc_duplicate_first() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "overlapping boundary constraints", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "overlapping boundary constraints", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "overlapping boundary constraints"); } #[test] @@ -70,14 +60,5 @@ fn err_bc_duplicate_last() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "overlapping boundary constraints", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "overlapping boundary constraints", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "overlapping boundary constraints"); } diff --git a/air/src/tests/buses.rs b/air/src/tests/buses.rs index 
7e3ccb859..d76a1f329 100644 --- a/air/src/tests/buses.rs +++ b/air/src/tests/buses.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, compile, expect_diagnostic}; +use super::{compile_from_source, expect_diagnostic}; #[test] fn buses_in_boundary_constraints() { @@ -29,12 +29,7 @@ fn buses_in_boundary_constraints() { enf a = 0; }"; - expect_diagnostic( - source, - "buses are not implemented for this Pipeline", - Pipeline::WithoutMIR, - ); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -74,12 +69,7 @@ fn buses_in_integrity_constraints() { q.remove(1, 2) with 2; }"; - expect_diagnostic( - source, - "buses are not implemented for this Pipeline", - Pipeline::WithoutMIR, - ); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } // Tests that should return errors @@ -110,8 +100,7 @@ fn err_buses_boundaries_to_const() { enf a = 0; }"; - expect_diagnostic(source, "error: invalid constraint", Pipeline::WithoutMIR); - expect_diagnostic(source, "error: invalid constraint", Pipeline::WithMIR); + expect_diagnostic(source, "error: invalid constraint"); } #[test] @@ -140,8 +129,7 @@ fn err_trace_columns_constrained_with_null() { enf a = 0; }"; - expect_diagnostic(source, "error: invalid constraint", Pipeline::WithoutMIR); - expect_diagnostic(source, "error: invalid constraint", Pipeline::WithMIR); + expect_diagnostic(source, "error: invalid constraint"); } #[test] @@ -172,10 +160,5 @@ fn err_buses_unconstrained() { enf a = 0; }"; - expect_diagnostic( - source, - "error: buses are not implemented for this Pipeline", - Pipeline::WithoutMIR, - ); - expect_diagnostic(source, "error: invalid bus boundary", Pipeline::WithMIR); + expect_diagnostic(source, "error: invalid bus boundary"); } diff --git a/air/src/tests/constant.rs b/air/src/tests/constant.rs index 70921c235..214ff85a6 100644 --- a/air/src/tests/constant.rs +++ b/air/src/tests/constant.rs @@ -1,4 +1,4 @@ -use 
super::{Pipeline, compile, expect_diagnostic}; +use super::{compile_from_source, expect_diagnostic}; #[test] fn boundary_constraint_with_constants() { @@ -21,8 +21,7 @@ fn boundary_constraint_with_constants() { enf clk' = clk - 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -45,8 +44,7 @@ fn integrity_constraint_with_constants() { enf clk' = clk + A + B[1] - C[1][2]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -68,14 +66,5 @@ fn invalid_matrix_constant() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "invalid matrix literal: mismatched dimensions", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "invalid matrix literal: mismatched dimensions", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "invalid matrix literal: mismatched dimensions"); } diff --git a/air/src/tests/evaluators.rs b/air/src/tests/evaluators.rs index 261e6a36e..52e7e4461 100644 --- a/air/src/tests/evaluators.rs +++ b/air/src/tests/evaluators.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, compile}; +use super::compile_from_source; #[test] fn simple_evaluator() { @@ -24,8 +24,7 @@ fn simple_evaluator() { enf advance_clock([clk]); }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -53,8 +52,7 @@ fn evaluator_with_variables() { enf advance_clock([clk]); }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -85,8 +83,7 @@ fn ev_call_inside_evaluator_with_main() { enf enforce_all_constraints([clk]); }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - 
assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -113,6 +110,5 @@ fn ev_fn_call_with_column_group() { enf clk_selectors([s, clk]); }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } diff --git a/air/src/tests/integrity_constraints/comprehension/constraint_comprehension.rs b/air/src/tests/integrity_constraints/comprehension/constraint_comprehension.rs index 81c47affd..bdef8f693 100644 --- a/air/src/tests/integrity_constraints/comprehension/constraint_comprehension.rs +++ b/air/src/tests/integrity_constraints/comprehension/constraint_comprehension.rs @@ -1,4 +1,4 @@ -use super::super::{Pipeline, compile}; +use super::super::compile_from_source; #[test] fn constraint_comprehension() { @@ -17,8 +17,7 @@ fn constraint_comprehension() { enf c = d for (c, d) in (c, d); }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -38,6 +37,5 @@ fn ic_comprehension_with_selectors() { enf c = d for (c, d) in (c, d) when !fmp[0]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } diff --git a/air/src/tests/integrity_constraints/comprehension/list_comprehension.rs b/air/src/tests/integrity_constraints/comprehension/list_comprehension.rs index 501cc979a..9925b0bab 100644 --- a/air/src/tests/integrity_constraints/comprehension/list_comprehension.rs +++ b/air/src/tests/integrity_constraints/comprehension/list_comprehension.rs @@ -1,4 +1,4 @@ -use super::super::{Pipeline, compile, expect_diagnostic}; +use super::super::{compile_from_source, expect_diagnostic}; #[test] fn list_comprehension() { @@ -18,8 +18,7 @@ fn list_comprehension() { enf clk = x[1]; }"; - 
assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -41,8 +40,7 @@ fn lc_with_const_exp() { enf clk = y[1] + z[1]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -63,16 +61,7 @@ fn lc_with_non_const_exp() { enf clk = enumerate[3]; }"; - expect_diagnostic( - source, - "expected exponent to be a constant", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "expected exponent to be a constant", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "expected exponent to be a constant"); } #[test] @@ -93,8 +82,7 @@ fn lc_with_two_lists() { enf clk = diff[0]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -115,8 +103,7 @@ fn lc_with_two_slices() { enf clk = diff[1]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -137,8 +124,7 @@ fn lc_with_multiple_lists() { enf a = x[0] + x[1] + x[2]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -160,16 +146,7 @@ fn err_index_out_of_range_lc_ident() { enf clk = x[2]; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -192,16 +169,7 @@ fn err_index_out_of_range_lc_slice() { enf clk = x[3]; }"; - expect_diagnostic( - 
source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -224,16 +192,7 @@ fn err_non_const_exp_ident_iterable() { enf clk = invalid_exp_lc[1]; }"; - expect_diagnostic( - source, - "expected exponent to be a constant", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "expected exponent to be a constant", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "expected exponent to be a constant"); } #[test] @@ -256,16 +215,7 @@ fn err_non_const_exp_slice_iterable() { enf clk = invalid_exp_lc[1]; }"; - expect_diagnostic( - source, - "expected exponent to be a constant", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "expected exponent to be a constant", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "expected exponent to be a constant"); } #[test] @@ -288,14 +238,5 @@ fn err_duplicate_member() { enf clk = duplicate_member_lc[1]; }"; - expect_diagnostic( - source, - "this name is already bound in this comprehension", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "this name is already bound in this comprehension", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "this name is already bound in this comprehension"); } diff --git a/air/src/tests/integrity_constraints/mod.rs b/air/src/tests/integrity_constraints/mod.rs index 33d87b7c8..0b3702f30 100644 --- a/air/src/tests/integrity_constraints/mod.rs +++ b/air/src/tests/integrity_constraints/mod.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, compile, expect_diagnostic}; +use super::{compile_from_source, expect_diagnostic}; mod comprehension; @@ -19,8 +19,7 @@ fn integrity_constraints() { enf clk' = clk + 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + 
assert!(compile_from_source(source).is_ok()); } #[test] @@ -40,8 +39,7 @@ fn ic_using_parens() { enf clk' = (clk + 1); }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -61,8 +59,7 @@ fn ic_op_mul() { enf clk' * clk = 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -82,8 +79,7 @@ fn ic_op_exp() { enf clk'^2 - clk = 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -104,14 +100,5 @@ fn err_non_const_exp_outside_lc() { enf clk = 2^ctx; }"; - expect_diagnostic( - source, - "expected exponent to be a constant", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "expected exponent to be a constant", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "expected exponent to be a constant"); } diff --git a/air/src/tests/list_folding.rs b/air/src/tests/list_folding.rs index 9f963029e..fa55d6b42 100644 --- a/air/src/tests/list_folding.rs +++ b/air/src/tests/list_folding.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, compile}; +use super::compile_from_source; #[test] fn list_folding_on_const() { @@ -20,8 +20,7 @@ fn list_folding_on_const() { enf clk = y - x; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -44,8 +43,7 @@ fn list_folding_on_variable() { enf clk = z - y; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -67,8 +65,7 @@ fn list_folding_on_vector() { enf clk = y - x; }"; - assert!(compile(source, 
Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -91,8 +88,7 @@ fn list_folding_on_lc() { enf clk = y - x; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -114,6 +110,5 @@ fn list_folding_in_lc() { enf clk = y[0]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } diff --git a/air/src/tests/mod.rs b/air/src/tests/mod.rs index ec275220e..4f717c80a 100644 --- a/air/src/tests/mod.rs +++ b/air/src/tests/mod.rs @@ -12,38 +12,32 @@ mod source_sections; mod trace; mod variables; -pub use crate::CompileError; - use std::sync::Arc; -use air_pass::Pass; use miden_diagnostics::{CodeMap, DiagnosticsConfig, DiagnosticsHandler, Verbosity}; -#[derive(Clone, Copy, Debug)] -pub enum Pipeline { - WithMIR, - WithoutMIR, -} +pub use crate::CompileError; +use crate::compile; -pub fn compile(source: &str, pipeline: Pipeline) -> Result { +pub fn compile_from_source(source: &str) -> Result { let compiler = Compiler::default(); - match compiler.compile(source, pipeline) { + match compiler.compile(source) { Ok(air) => Ok(air), Err(err) => { compiler.diagnostics.emit(err); compiler.emitter.print_captured_to_stderr(); Err(()) - } + }, } } #[track_caller] -pub fn expect_diagnostic(source: &str, expected: &str, pipeline: Pipeline) { +pub fn expect_diagnostic(source: &str, expected: &str) { let compiler = Compiler::default(); - let err = match compiler.compile(source, pipeline) { + let err = match compiler.compile(source) { Ok(ref ast) => { panic!("expected compilation to fail, got {ast:#?}"); - } + }, Err(err) => err, }; compiler.diagnostics.emit(err); @@ -51,10 +45,7 @@ pub fn expect_diagnostic(source: &str, expected: &str, pipeline: Pipeline) { if !found 
{ compiler.emitter.print_captured_to_stderr(); } - assert!( - found, - "With pipeline {pipeline:?}, expected diagnostic output to contain the string: '{expected}'" - ); + assert!(found, "expected diagnostic output to contain the string: '{expected}'"); } struct Compiler { @@ -76,45 +67,16 @@ impl Compiler { pub fn new(config: DiagnosticsConfig) -> Self { let codemap = Arc::new(CodeMap::new()); let emitter = Arc::new(SplitEmitter::new()); - let diagnostics = Arc::new(DiagnosticsHandler::new( - config, - codemap.clone(), - emitter.clone(), - )); + let diagnostics = + Arc::new(DiagnosticsHandler::new(config, codemap.clone(), emitter.clone())); - Self { - codemap, - emitter, - diagnostics, - } + Self { codemap, emitter, diagnostics } } - pub fn compile(&self, source: &str, pipeline: Pipeline) -> Result { - match pipeline { - Pipeline::WithMIR => air_parser::parse(&self.diagnostics, self.codemap.clone(), source) - .map_err(CompileError::Parse) - .and_then(|ast| { - let mut pipeline = - air_parser::transforms::ConstantPropagation::new(&self.diagnostics) - .chain(mir::passes::AstToMir::new(&self.diagnostics)) - .chain(mir::passes::Inlining::new(&self.diagnostics)) - .chain(mir::passes::Unrolling::new(&self.diagnostics)) - .chain(crate::passes::MirToAir::new(&self.diagnostics)) - .chain(crate::passes::BusOpExpand::new(&self.diagnostics)); - pipeline.run(ast) - }), - Pipeline::WithoutMIR => { - air_parser::parse(&self.diagnostics, self.codemap.clone(), source) - .map_err(CompileError::Parse) - .and_then(|ast| { - let mut pipeline = - air_parser::transforms::ConstantPropagation::new(&self.diagnostics) - .chain(air_parser::transforms::Inlining::new(&self.diagnostics)) - .chain(crate::passes::AstToAir::new(&self.diagnostics)); - pipeline.run(ast) - }) - } - } + pub fn compile(&self, source: &str) -> Result { + air_parser::parse(&self.diagnostics, self.codemap.clone(), source) + .map_err(CompileError::Parse) + .and_then(|program| compile(&self.diagnostics, program)) } } @@ 
-138,9 +100,10 @@ impl SplitEmitter { } pub fn print_captured_to_stderr(&self) { - use miden_diagnostics::Emitter; use std::io::Write; + use miden_diagnostics::Emitter; + let mut copy = self.default.buffer(); let captured = self.capture.captured(); copy.write_all(captured.as_bytes()).unwrap(); diff --git a/air/src/tests/pub_inputs.rs b/air/src/tests/pub_inputs.rs index 00744b40b..3b49e551a 100644 --- a/air/src/tests/pub_inputs.rs +++ b/air/src/tests/pub_inputs.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, compile}; +use super::compile_from_source; #[test] fn bc_with_public_inputs() { @@ -17,6 +17,5 @@ fn bc_with_public_inputs() { enf clk' = clk - 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } diff --git a/air/src/tests/selectors.rs b/air/src/tests/selectors.rs index 094785d9f..12e7c11b2 100644 --- a/air/src/tests/selectors.rs +++ b/air/src/tests/selectors.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, compile}; +use super::compile_from_source; #[test] fn single_selector() { @@ -18,8 +18,7 @@ fn single_selector() { enf clk' = clk when s[0]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -40,8 +39,7 @@ fn chained_selectors() { enf clk' = clk when (s[0] & !s[1]) | !s[2]'; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -68,8 +66,7 @@ fn multiconstraint_selectors() { }; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -96,8 +93,7 @@ fn selectors_in_evaluators() { enf evaluator_with_selector([s[0], clk]); }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - 
assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -124,8 +120,7 @@ fn multiple_selectors_in_evaluators() { enf evaluator_with_selector([s[0], s[1], clk]); }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -152,8 +147,7 @@ fn selector_with_evaluator_call() { enf unchanged([clk]) when s[0] & !s[1]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -192,6 +186,108 @@ fn selectors_inside_match() { }; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); +} + +/// This test ensures that nested selectors are well handled during compilation by the +/// MatchOptimizer pass. 
+#[test] +fn selectors_nested() { + let source = " + def test + trace_columns { + main: [s[3], a, b, c], + } + + public_inputs { + stack_inputs: [1], + } + + boundary_constraints { + enf c.first = 0; + } + + # Simple evaluator functions + ev ev_dummy_0([b]) { + enf b' = b + 1; + } + + ev ev_dummy_1([b]) { + enf b' = b; + } + + # Evaluator functions with match statements + ev ev_match_0([s, b]) { + enf match { + case s: ev_dummy_0([b]), + case !s: ev_dummy_1([b]), + }; + } + + ev ev_match_1([s, b]) { + # Here we invert the cases + enf match { + case s: ev_dummy_1([b]), + case !s: ev_dummy_0([b]), + }; + } + + # Evaluator functions with nested match statements + ev ev_nested_0([s0, s1, b]) { + enf match { + case s0: ev_match_0([s1, b]), + case !s0: ev_match_1([s1, b]), + }; + } + + ev ev_nested_1([s0, s1, b]) { + # Here we invert the cases + enf match { + case s0: ev_match_1([s1, b]), + case !s0: ev_match_0([s1, b]), + }; + } + + # Evaluator with doubly-nested match statements + ev ev_s1([s, a, b, c]) { + enf a * (a - 1) = 0; + enf b = 31; + enf c = 5; + + # This creates Vector nodes with Enf operations + enf match { + case a: ev_nested_0([a, s, b]), + case !a: ev_nested_1([a, s, b]), + }; + } + + # Other evaluators for the outer match statement + ev ev_s0([a, b, c]) { + enf a' = a + 1; + enf b' = b + 1; + enf c' = c + 1; + } + + ev ev_s2([a, b, c]) { + enf a' = a + 2; + enf b' = b + 2; + enf c' = c + 2; + } + + # Main constraint with nested evaluator calls + integrity_constraints { + let s0 = s[0]; + let s1 = !s[0] & s[1]; + let s2 = !s[0] & !s[1]; + + # This pattern creates the problematic Vector structures + # if we don't flatten the constraints correctly + enf match { + case s0: ev_s0([a, b, c]), + case s1: ev_s1([s[2], a, b, c]), # ← This call creates deeply-nested Vector->Enf structures + case s2: ev_s2([a, b, c]), + }; + }"; + + assert!(compile_from_source(source).is_ok()); } diff --git a/air/src/tests/source_sections.rs b/air/src/tests/source_sections.rs 
index addb25323..50d351695 100644 --- a/air/src/tests/source_sections.rs +++ b/air/src/tests/source_sections.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, expect_diagnostic}; +use super::expect_diagnostic; #[test] fn err_trace_cols_empty() { @@ -13,16 +13,7 @@ fn err_trace_cols_empty() { integrity_constraints { enf clk' = clk + 1"; - expect_diagnostic( - source, - "missing 'main' declaration in this section", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "missing 'main' declaration in this section", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "missing 'main' declaration in this section"); } #[test] @@ -40,12 +31,7 @@ fn err_trace_cols_omitted() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "missing trace_columns section", - Pipeline::WithoutMIR, - ); - expect_diagnostic(source, "missing trace_columns section", Pipeline::WithMIR); + expect_diagnostic(source, "missing trace_columns section"); } #[test] @@ -64,12 +50,7 @@ fn err_pub_inputs_empty() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "expected one of: 'identifier'", - Pipeline::WithoutMIR, - ); - expect_diagnostic(source, "expected one of: 'identifier'", Pipeline::WithMIR); + expect_diagnostic(source, "expected one of: 'identifier'"); } #[test] @@ -87,16 +68,7 @@ fn err_pub_inputs_omitted() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "root module must contain a public_inputs section", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "root module must contain a public_inputs section", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "root module must contain a public_inputs section"); } #[test] @@ -115,16 +87,7 @@ fn err_bc_empty() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "expected one of: '\"enf\"', '\"let\"'", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "expected one of: '\"enf\"', '\"let\"'", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "expected one of: '\"enf\"', '\"let\"'"); } 
#[test] @@ -145,12 +108,6 @@ fn err_bc_omitted() { expect_diagnostic( source, "root module must contain both boundary_constraints and integrity_constraints sections", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "root module must contain both boundary_constraints and integrity_constraints sections", - Pipeline::WithMIR, ); } @@ -170,16 +127,7 @@ fn err_ic_empty() { } integrity_constraints {}"; - expect_diagnostic( - source, - "expected one of: '\"enf\"', '\"let\"'", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "expected one of: '\"enf\"', '\"let\"'", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "expected one of: '\"enf\"', '\"let\"'"); } #[test] @@ -200,11 +148,5 @@ fn err_ic_omitted() { expect_diagnostic( source, "root module must contain both boundary_constraints and integrity_constraints sections", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "root module must contain both boundary_constraints and integrity_constraints sections", - Pipeline::WithMIR, ); } diff --git a/air/src/tests/trace.rs b/air/src/tests/trace.rs index ed58e2f08..8bb69e9c1 100644 --- a/air/src/tests/trace.rs +++ b/air/src/tests/trace.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, compile, expect_diagnostic}; +use super::{compile_from_source, expect_diagnostic}; #[test] fn trace_columns_index_access() { @@ -17,8 +17,7 @@ fn trace_columns_index_access() { enf $main[0]' - $main[1] = 0; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -42,8 +41,7 @@ fn trace_cols_groups() { enf a[0]' = a[1] - 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -64,16 +62,7 @@ fn err_bc_column_undeclared() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "this variable / bus is not defined", - 
Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "this variable / bus is not defined", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "this variable / bus is not defined"); } #[test] @@ -93,16 +82,7 @@ fn err_ic_column_undeclared() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "this variable / bus is not defined", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "this variable / bus is not defined", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "this variable / bus is not defined"); } #[test] @@ -126,16 +106,7 @@ fn err_bc_trace_cols_access_out_of_bounds() { enf a[0]' = a[0] - 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -160,16 +131,7 @@ fn err_ic_trace_cols_access_out_of_bounds() { enf a[4]' = a[4] - 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -189,6 +151,5 @@ fn err_ic_trace_cols_group_used_as_scalar() { enf a[0]' = a + clk; }"; - expect_diagnostic(source, "type mismatch", Pipeline::WithoutMIR); - expect_diagnostic(source, "type mismatch", Pipeline::WithMIR); + expect_diagnostic(source, "type mismatch"); } diff --git a/air/src/tests/variables.rs b/air/src/tests/variables.rs index 765adcc97..a6dd8d05f 100644 --- a/air/src/tests/variables.rs +++ b/air/src/tests/variables.rs @@ -1,4 +1,4 @@ -use super::{Pipeline, compile, expect_diagnostic}; +use super::{compile_from_source, expect_diagnostic}; #[test] fn 
let_scalar_constant_in_boundary_constraint() { @@ -18,8 +18,7 @@ fn let_scalar_constant_in_boundary_constraint() { enf clk' = clk + 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -40,8 +39,7 @@ fn let_vector_constant_in_boundary_constraint() { enf clk' = clk + 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -66,8 +64,7 @@ fn multi_constraint_nested_let_with_expressions_in_boundary_constraint() { enf clk' = clk + 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -89,8 +86,7 @@ fn let_scalar_constant_in_boundary_constraint_both_domains() { enf clk' = clk + 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -112,16 +108,7 @@ fn invalid_column_offset_in_boundary_constraint() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "invalid access of a trace column with offset", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "invalid access of a trace column with offset", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "invalid access of a trace column with offset"); } #[test] @@ -145,8 +132,7 @@ fn nested_let_with_expressions_in_integrity_constraint() { enf c[0][0] = 1; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -172,8 +158,7 @@ fn nested_let_with_vector_access_in_integrity_constraint() { enf clk' = c[0] + e[2][0] + e[0][1]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - 
assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } #[test] @@ -201,12 +186,6 @@ fn invalid_matrix_literal_with_leading_vector_binding() { expect_diagnostic( source, "expected one of: '\"!\"', '\"(\"', '\"null\"', '\"unconstrained\"', 'decl_ident_ref', 'function_identifier', 'identifier', 'int'", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "expected one of: '\"!\"', '\"(\"', '\"null\"', '\"unconstrained\"', 'decl_ident_ref', 'function_identifier', 'identifier', 'int'", - Pipeline::WithMIR, ); } @@ -232,8 +211,7 @@ fn invalid_matrix_literal_with_trailing_vector_binding() { enf clk' = d[0][0]; }"; - expect_diagnostic(source, "expected one of: '\"[\"'", Pipeline::WithoutMIR); - expect_diagnostic(source, "expected one of: '\"[\"'", Pipeline::WithMIR); + expect_diagnostic(source, "expected one of: '\"[\"'"); } #[test] @@ -256,16 +234,7 @@ fn invalid_variable_access_before_declaration() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "this variable / bus is not defined", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "this variable / bus is not defined", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "this variable / bus is not defined"); } #[test] @@ -288,16 +257,7 @@ fn invalid_trailing_let() { let a = 1; }"; - expect_diagnostic( - source, - "expected one of: '\"enf\"', '\"let\"'", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "expected one of: '\"enf\"', '\"let\"'", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "expected one of: '\"enf\"', '\"let\"'"); } #[test] @@ -320,16 +280,7 @@ fn invalid_reference_to_variable_defined_in_other_section() { enf clk' = clk + a; }"; - expect_diagnostic( - source, - "this variable / bus is not defined", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "this variable / bus is not defined", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "this variable / bus is not defined"); } #[test] @@ 
-352,16 +303,7 @@ fn invalid_vector_variable_access_out_of_bounds() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -383,16 +325,7 @@ fn invalid_matrix_column_variable_access_out_of_bounds() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -414,16 +347,7 @@ fn invalid_matrix_row_variable_access_out_of_bounds() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -447,16 +371,7 @@ fn invalid_index_into_scalar_variable() { enf clk' = clk + a[0]; }"; - expect_diagnostic( - source, - "attempted to index into a scalar value", - Pipeline::WithoutMIR, - ); - expect_diagnostic( - source, - "attempted to index into a scalar value", - Pipeline::WithMIR, - ); + expect_diagnostic(source, "attempted to index into a scalar value"); } #[test] @@ -480,6 +395,5 @@ fn trace_binding_access_in_integrity_constraint() { enf clk' = clk + a[0]; }"; - assert!(compile(source, Pipeline::WithoutMIR).is_ok()); - assert!(compile(source, Pipeline::WithMIR).is_ok()); + assert!(compile_from_source(source).is_ok()); } diff --git a/codegen/ace/Cargo.toml b/codegen/ace/Cargo.toml index 1d829bca3..7c669d029 
100644 --- a/codegen/ace/Cargo.toml +++ b/codegen/ace/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "air-codegen-ace" -version = "0.4.0" +version = "0.5.0" description = "Code generator from AirScript to the ACE chiplet of Miden's recursive verifier." authors.workspace = true readme = "README.md" @@ -12,15 +12,15 @@ edition.workspace = true rust-version.workspace = true [dependencies] -air-ir = { package = "air-ir", path = "../../air", version = "0.4" } +air-ir = { package = "air-ir", path = "../../air", version = "0.5" } anyhow = { workspace = true } miden-core = { package = "miden-core", version = "0.13", default-features = false } +mir = { package = "air-mir", path = "../../mir" } winter-math = { package = "winter-math", version = "0.12", default-features = false } [dev-dependencies] air-parser = { package = "air-parser", path = "../../parser" } air-pass = { package = "air-pass", path = "../../pass" } miden-diagnostics = { workspace = true } -mir = { package = "air-mir", path = "../../mir" } rand = "0.9" winter-utils = { version = "0.12", package = "winter-utils" } diff --git a/codegen/ace/README.md b/codegen/ace/README.md new file mode 100644 index 000000000..331563b99 --- /dev/null +++ b/codegen/ace/README.md @@ -0,0 +1,58 @@ +# ACE Code Generator + +This crate contains a code generator targeting Miden VM's ACE (Arithmetic Circuit Evaluation) chiplet. + +The purpose of this code generator is to convert a provided `Air` representation into arithmetic circuits that can be efficiently evaluated by the ACE chiplet for recursive STARK proof verification within Miden assembly programs. + +## Generating the ACE Circuit + +Generate an ACE circuit from an `Air` (AirScript's intermediate representation) by calling `build_ace_circuit`. The function will return the root node and complete circuit. 
+ +The circuit builder processes constraints in three groups (integrity, boundary-first, and boundary-last constraints), combines them using powers of a random challenge `α`, and creates a circuit that evaluates the formula: + +``` +z₋₂²⋅z₋₁⋅z₀⋅int + zₙ⋅z₋₂⋅bf + zₙ⋅z₀⋅bl - Q(z)⋅zₙ⋅z₀⋅z₋₂ = 0 +``` + +Example usage: + +```rust +use air_codegen_ace::{build_ace_circuit, AceVars}; +use air_ir::{Air, compile}; + +// Parse AirScript source and compile to AIR +let air = compile(&diagnostics, parsed_program)?; + +// Build ACE circuit +let (root_node, circuit) = build_ace_circuit(&air)?; + +// Collect inputs to circuit +// let air_inputs = ... + +// Prepare inputs for evaluation +let ace_vars = AceVars::from_air_inputs(air_inputs, &air); +let memory_inputs = ace_vars.to_memory_vec(&circuit.layout); + +// Evaluate circuit (should be zero for valid constraints) +let result = circuit.eval(root_node, &memory_inputs); +assert_eq!(result, QuadFelt::ZERO); + +// Encode for chiplet consumption +let encoded = circuit.to_ace(); +``` + +## ACE Circuit + +The ACE circuit represents constraints as a directed acyclic graph (DAG) with three types of nodes: +- `Input(usize)`: Variables at which the circuit is evaluated +- `Constant(usize)`: Fixed values stored in the circuit description +- `Operation(usize)`: Results of arithmetic operations (subtraction, multiplication, addition) + +The circuit includes: +- **Layout** (`AirLayout`): Defines the memory layout of inputs expected by the ACE chiplet +- **Encoded format** (`EncodedAceCircuit`): Serialized representation for chiplet consumption +- **Evaluation**: Circuits can be evaluated at given inputs and visualized as DOT graphs for debugging + +## References + +- [ACE Chiplet Documentation](https://0xmiden.github.io/miden-vm/design/chiplets/ace.html) diff --git a/codegen/ace/src/builder.rs b/codegen/ace/src/builder.rs index f0101c664..9a3490de9 100644 --- a/codegen/ace/src/builder.rs +++ b/codegen/ace/src/builder.rs @@ -1,10 +1,14 @@ -use 
crate::circuit::{ArithmeticOp, Circuit, Node, OperationNode}; -use crate::layout::{Layout, StarkVar}; +use std::collections::BTreeMap; + use air_ir::{ Air, NodeIndex, Operation as AirOperation, PeriodicColumnAccess, QualifiedIdentifier, Value, }; use miden_core::Felt; -use std::collections::BTreeMap; + +use crate::{ + circuit::{ArithmeticOp, Circuit, Node, OperationNode}, + layout::{Layout, StarkVar}, +}; /// [`CircuitBuilder`] is the only way to build a [`Circuit`]. It guarantees the following /// properties: @@ -80,13 +84,13 @@ impl CircuitBuilder { ArithmeticOp::Add => c_l + c_r, }; self.constant(c.as_int()) - } + }, // Store new `Operation` node _ => { let index = self.operations.len(); self.operations.push(operation); Node::Operation(index) - } + }, }; // Cache the operation node for future use. @@ -125,39 +129,41 @@ impl CircuitBuilder { let node = match air_op { AirOperation::Value(v) => match v { Value::Constant(c) => self.constant(*c), - Value::TraceAccess(access) => self - .layout - .trace_access_node(access) - .expect("invalid trace access"), - Value::PeriodicColumn(access) => self - .periodic_column(air, access) - .expect("invalid periodic column access"), + Value::TraceAccess(access) => { + self.layout.trace_access_node(access).expect("invalid trace access") + }, + Value::PeriodicColumn(access) => { + self.periodic_column(air, access).expect("invalid periodic column access") + }, Value::PublicInput(pi) => self.layout.public_inputs[&pi.name] .as_node(pi.index) .expect("invalid public input access"), - Value::RandomValue(idx) => self - .layout - .random_values - .as_node(*idx) - .expect("invalid random value index"), + Value::PublicInputTable(access) => { + let idx = self.layout.reduced_tables[access]; + self.layout + .reduced_tables_region + .as_node(idx) + .expect("invalid public input table access") + }, + Value::RandomValue(idx) => self.random(*idx), }, AirOperation::Add(l_idx, r_idx) => { let node_l = self.node_from_index(air, l_idx); let node_r = 
self.node_from_index(air, r_idx); self.add(node_l, node_r) - } + }, AirOperation::Sub(l_idx, r_idx) => { let node_l = self.node_from_index(air, l_idx); let node_r = self.node_from_index(air, r_idx); self.sub(node_l, node_r) - } + }, AirOperation::Mul(l_idx, r_idx) => { let node_l = self.node_from_index(air, l_idx); let node_r = self.node_from_index(air, r_idx); self.mul(node_l, node_r) - } + }, }; - self.air_node_cache.insert(*air_op, node); + self.air_node_cache.insert(air_op.clone(), node); node } @@ -252,14 +258,14 @@ impl CircuitBuilder { .unwrap_or_else(|| self.constant(0)) } - /// Returns a [`Node`] corresponding to the evaluation of the `periodic_column` at the appropriate - /// power of `z`. The evaluation is cached to avoid unnecessary computation. + /// Returns a [`Node`] corresponding to the evaluation of the `periodic_column` at the + /// appropriate power of `z`. The evaluation is cached to avoid unnecessary computation. fn periodic_column( &mut self, air: &Air, periodic_column: &PeriodicColumnAccess, ) -> Option { - let ident = periodic_column.name; + let ident = periodic_column.name.clone(); // Check if we have already computed this column's value if let Some(node) = self.periodic_columns_cache.get(&ident) { @@ -269,12 +275,7 @@ impl CircuitBuilder { let periodic_column = &air.periodic_columns.get(&ident)?; // Maximum length of all periodic columns in the air. - let max_col_len = air - .periodic_columns - .values() - .map(|col| col.values.len()) - .max() - .unwrap(); + let max_col_len = air.periodic_columns.values().map(|col| col.values.len()).max().unwrap(); // The power of `z` for the longest column. 
Let `k` such that `z_max_col = z^k`, // where `k = trace_len / max_cycle_len` @@ -293,17 +294,11 @@ impl CircuitBuilder { // Interpolate the values of the column, converting the resulting coefficients // to constant nodes let poly_nodes: Vec<_> = { - let mut column: Vec<_> = periodic_column - .values - .iter() - .map(|val| Felt::new(*val)) - .collect(); + let mut column: Vec<_> = + periodic_column.values.iter().map(|val| Felt::new(*val)).collect(); let inv_twiddles = winter_math::fft::get_inv_twiddles::(column.len()); winter_math::fft::interpolate_poly(&mut column, &inv_twiddles); - column - .into_iter() - .map(|coeff| self.constant(coeff.as_int())) - .collect() + column.into_iter().map(|coeff| self.constant(coeff.as_int())).collect() }; // Evaluate the polynomial at z_col @@ -313,6 +308,24 @@ impl CircuitBuilder { self.periodic_columns_cache.insert(ident, result); Some(result) } + + /// Returns a [`Node`] corresponding to the random challenge at the given index. + /// We assume that the challenges are `[α, 1, β, β², β³, … ]`. 
+ fn random(&mut self, index: usize) -> Node { + if index == 0 { + return self.layout.random_alpha_node(); + } + let mut beta_power = index - 1; + let beta_base = self.layout.random_beta_node(); + let mut beta = self.constant(1); + + while beta_power > 0 { + beta = self.mul(beta_base, beta); + beta_power -= 1; + } + + beta + } } /// Computes a linear combination with the powers of a random challenge alpha @@ -326,10 +339,7 @@ pub struct LinearCombination { impl LinearCombination { pub fn new(alpha: Node) -> Self { - Self { - alpha, - prev_alpha: None, - } + Self { alpha, prev_alpha: None } } /// Returns the linear combination of the [`Node`]s returned by the `els` [`Iterator`] with the @@ -434,16 +444,10 @@ mod tests { let res = cb.add(res_1, res_2); let alpha = QuadFelt::from(Felt::new(5u64)); - let coeffs: Vec<_> = (0..6) - .map(|i| QuadFelt::from(Felt::new(1 + i as u64))) - .collect(); + let coeffs: Vec<_> = (0..6).map(|i| QuadFelt::from(Felt::new(1 + i as u64))).collect(); - let result_expected = coeffs - .iter() - .rev() - .copied() - .reduce(|acc, coeff| acc * alpha + coeff) - .unwrap(); + let result_expected = + coeffs.iter().rev().copied().reduce(|acc, coeff| acc * alpha + coeff).unwrap(); let circuit = cb.into_ace_circuit(); let inputs: Vec<_> = [alpha].into_iter().chain(coeffs).collect(); diff --git a/codegen/ace/src/circuit.rs b/codegen/ace/src/circuit.rs index fe6ce52ad..97d2540e3 100644 --- a/codegen/ace/src/circuit.rs +++ b/codegen/ace/src/circuit.rs @@ -1,8 +1,9 @@ -use crate::QuadFelt; -use crate::layout::Layout; -use miden_core::Felt; use std::collections::BTreeMap; +use miden_core::Felt; + +use crate::{QuadFelt, layout::Layout}; + /// One of the 3 arithmetic operations supported by the ACE chiplet. #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub enum ArithmeticOp { @@ -18,7 +19,7 @@ pub enum Node { Input(usize), /// Index of a leaf node stored in the circuit description. 
Constant(usize), - /// Index of a non-leaf node representing the result of an [`ArithmeticOp`] applied + /// Index of a non-leaf node representing the result of an `ArithmeticOp` applied /// to two other [`Node`]s. Operation(usize), } @@ -31,8 +32,9 @@ pub struct OperationNode { } /// A circuit that can be consumed by the ACE chiplet. +/// /// The only way to build a circuit is through the CircuitBuilder, it can then be obtained -/// from [`CircuitBuilder::normalize`] and serialized to felts with [`Circuit::to_felts`]. +/// from `CircuitBuilder::normalize` and serialized to felts with `Circuit::to_elements`. #[derive(Clone, Debug, PartialEq)] pub struct Circuit { pub layout: Layout, @@ -41,7 +43,7 @@ pub struct Circuit { } impl Circuit { - /// Evaluates to a [`Quad`] the index `root`, given a vector of inputs to the circuit. + /// Evaluates to a `Quad` the index `root`, given a vector of inputs to the circuit. pub fn eval(&self, node: Node, inputs: &[QuadFelt]) -> QuadFelt { let mut evals: BTreeMap = BTreeMap::new(); // Insert inputs nodes with given values diff --git a/codegen/ace/src/dot.rs b/codegen/ace/src/dot.rs index 830c6901b..7f51f9725 100644 --- a/codegen/ace/src/dot.rs +++ b/codegen/ace/src/dot.rs @@ -1,10 +1,13 @@ -use crate::circuit::{ArithmeticOp, Circuit, Node, OperationNode}; -use crate::layout::StarkVar; use std::fmt::{Display, Write}; +use crate::{ + circuit::{ArithmeticOp, Circuit, Node, OperationNode}, + layout::StarkVar, +}; + impl Circuit { /// Serialization to Graphviz Dot format for debugging purposes. 
Display on - /// https://dreampuf.github.io/GraphvizOnline or using `dot -Tsvg tests/regressions/0.dot > 0.svg` + /// or using `dot -Tsvg tests/regressions/0.dot > 0.svg` pub fn to_dot(&self) -> Result { let mut f = String::new(); writeln!(f, "digraph G {{")?; @@ -25,9 +28,8 @@ impl Circuit { } // Random values - for (idx, node) in self.layout.random_values.iter_nodes().enumerate() { - writeln!(f, "{node} [label=\"R[{idx}]\"]")?; - } + writeln!(f, "{} [label=\"α\"]", self.layout.random_alpha_node())?; + writeln!(f, "{} [label=\"β\"]", self.layout.random_beta_node())?; // Main for (idx, node) in self.layout.trace_segments[0][0].iter_nodes().enumerate() { @@ -61,10 +63,7 @@ impl Circuit { for (op_idx, op_node) in self.operations.iter().enumerate() { let OperationNode { op, node_l, node_r } = op_node; let op_node = Node::Operation(op_idx); - writeln!( - f, - "{op_node} [label=\"{op_node}\\n{node_l} {op} {node_r}\"]" - )?; + writeln!(f, "{op_node} [label=\"{op_node}\\n{node_l} {op} {node_r}\"]")?; writeln!(f, "{node_l} -> {op_node}")?; writeln!(f, "{node_r} -> {op_node}")?; } @@ -93,13 +92,13 @@ impl Display for Node { match self { Node::Input(idx) => { write!(f, "input{idx}") - } + }, Node::Constant(idx) => { write!(f, "const{idx}") - } + }, Node::Operation(idx) => { write!(f, "op{idx}") - } + }, } } } diff --git a/codegen/ace/src/encoded.rs b/codegen/ace/src/encoded.rs index 10077fd83..22eee0310 100644 --- a/codegen/ace/src/encoded.rs +++ b/codegen/ace/src/encoded.rs @@ -1,21 +1,26 @@ -use crate::QuadFelt; -use crate::circuit::{ArithmeticOp, Circuit, Node, OperationNode}; -use miden_core::Felt; -use miden_core::crypto::hash::{Rpo256, RpoDigest}; +use miden_core::{ + Felt, + crypto::hash::{Rpo256, RpoDigest}, +}; use winter_math::FieldElement; +use crate::{ + QuadFelt, + circuit::{ArithmeticOp, Circuit, Node, OperationNode}, +}; + /// An encoded [`Circuit`] matching the required format for the ACE chiplet. 
/// The chiplet performs an evaluation by sequentially reading a region in memory with the following /// layout, where each region must be word-aligned. -/// - *Variables* correspond to all leaf nodes, stored two-by-two as extension field elements. -/// It is subdivided into the two following consecutive regions, which is achieved by padding -/// with zeros. +/// - *Variables* correspond to all leaf nodes, stored two-by-two as extension field elements. It is +/// subdivided into the two following consecutive regions, which is achieved by padding with +/// zeros. /// - *Inputs*: List of all inputs to the circuit. The internal layout is described in /// [`crate::AceVars::to_memory_vec`]. /// - *Constants*: Fixed values that can be referenced by instructions. /// - *Instructions*: List of arithmetic gates to be evaluated. Each instruction is encoded as a /// single field element. It is padded with instructions which square the output, as this still -/// ensures the final evaluation is still evaluates to zero. +/// ensures the final evaluation still evaluates to zero. pub struct EncodedCircuit { num_vars: usize, num_ops: usize, @@ -50,7 +55,7 @@ impl EncodedCircuit { /// Number of constant nodes. pub fn num_constants(&self) -> usize { - self.num_nodes() - self.num_ops + (self.instructions.len() - self.num_ops) / 2 } /// Number of nodes (variables and operations). @@ -63,6 +68,11 @@ impl EncodedCircuit { &self.instructions } + /// Returns the size of the encoded circuit in field elements. + pub fn size_in_felt(&self) -> usize { + self.instructions.len() + } + /// Returns the digest of the circuit, represented by the constants and instructions. 
pub fn circuit_hash(&self) -> RpoDigest { Rpo256::hash_elements(self.instructions()) @@ -97,34 +107,42 @@ impl Circuit { pub fn to_ace(&self) -> EncodedCircuit { const MAX_NODE_ID: u64 = (1 << 30) - 1; - assert!( - self.num_nodes() as u64 <= MAX_NODE_ID, - "more than 2^30 nodes" - ); + assert!(self.num_nodes() as u64 <= MAX_NODE_ID, "more than 2^30 nodes"); + + // Constants are encoded two-by-two as extension field elements, followed by operations, + // one per field element. + let num_const_nodes = self.constants.len().next_multiple_of(2); + let num_op_nodes = self.operations.len(); - // Constants are encoded two-by-two as extension field elements, followed by operations. - let num_const = self.constants.len().next_multiple_of(2); - let num_ops = self.operations.len().next_multiple_of(4); - let len_const = num_const * 2; - let len_circuit = len_const + num_ops; - let mut instructions = Vec::with_capacity(len_circuit); + let num_const_felts = num_const_nodes * 2; + let num_op_felts = num_op_nodes; + + // we need to have the size of the encoded circuit in field elements to be divisible + // by 8 so as to facilitate un-hashing in the VM + let len_circuit = num_const_felts + num_op_felts; + let len_circuit_padded = len_circuit.next_multiple_of(8); + + let mut instructions = Vec::with_capacity(len_circuit_padded); // Add constants - instructions.extend( - self.constants - .iter() - .flat_map(|c| QuadFelt::from(*c).to_base_elements()), - ); + instructions + .extend(self.constants.iter().flat_map(|c| QuadFelt::from(*c).to_base_elements())); // Since constants are treated as extension field elements, we pad this section with zeros // to ensure it is aligned in memory. - instructions.resize(len_const, Felt::ZERO); + instructions.resize(num_const_felts, Felt::ZERO); - let num_inputs = self.layout.num_inputs; + // we replace the padding above with squaring gates as these will not alter the final + // result, provided it is equal to zero. 
Note that this relies implicitly on the fact + // that fields do not have zero divisors + let num_input_nodes = self.layout.num_inputs; let num_constants = self.constants.len(); - let num_nodes = num_inputs + num_constants + num_ops; + let num_padding_felt = len_circuit_padded - len_circuit; + let num_padding_nodes = num_padding_felt; + let num_nodes = num_input_nodes + num_const_nodes + num_op_nodes + num_padding_nodes; + let node_id = |node: Node| -> u64 { let input_start = num_nodes - 1; - let constants_start = input_start - num_inputs; + let constants_start = input_start - num_input_nodes; let ops_start = constants_start - num_constants; match node { @@ -146,17 +164,16 @@ impl Circuit { let instruction = (id_0) + (id_1 * ID_1_OFFSET) + (op_tag * OP_OFFSET); Felt::new(instruction) }; + for operation in &self.operations { let encoded = operation_to_instruction(operation); instructions.push(encoded); } - // Since an ACE circuit's last node must evaluate to 0, we pad it with - // operations which square the last node. - // Each operation is encoded as a single field element, so the total number - // of operations must be a multiple of 4. + // Since an ACE circuit's last node must evaluate to 0, we add padding gates which + // square the output of the last node. 
let mut last_node_index = self.operations.len() - 1; - while instructions.len() % 4 != 0 { + while !instructions.len().is_multiple_of(8) { let last_node = Node::Operation(last_node_index); let dummy_op = OperationNode { op: ArithmeticOp::Mul, @@ -167,13 +184,11 @@ impl Circuit { instructions.push(encoded); last_node_index += 1; } + assert_eq!(instructions.len(), len_circuit_padded); - let num_vars = num_inputs + num_constants; - EncodedCircuit { - num_vars, - num_ops, - instructions, - } + let num_vars = num_input_nodes + num_constants; + let num_ops = num_op_nodes + num_padding_nodes; + EncodedCircuit { num_vars, num_ops, instructions } } /// Returns `true` when the circuit is properly padded, and each region is word-aligned. It @@ -181,19 +196,24 @@ impl Circuit { /// - Inputs and constants lie two-by-two in memory, treated as extension field elements, /// - Operations are encoded as single field elements. pub fn is_padded(&self) -> bool { - (self.layout.num_inputs % 2 == 0) - && (self.constants.len() % 2 == 0) - && (self.operations.len() % 4 == 0) + (self.layout.num_inputs.is_multiple_of(2)) + && (self.constants.len().is_multiple_of(2)) + && (self.operations.len().is_multiple_of(4)) } } #[cfg(test)] mod tests { - use super::*; - use crate::circuit::{ArithmeticOp, OperationNode}; - use crate::layout::{InputRegion, Layout}; use std::iter::zip; + use air_ir::FullTraceShape; + + use super::*; + use crate::{ + circuit::{ArithmeticOp, OperationNode}, + layout::{InputRegion, Layout}, + }; + /// Circuit evaluating `{[(i0 + 1) * i0] - 1}^2`, ensuring /// - All arithmetic operations are used /// - Padding is tested by taking @@ -206,41 +226,27 @@ mod tests { // and stark variables. 
let layout = Layout { public_inputs: Default::default(), - random_values: Default::default(), + reduced_tables_region: Default::default(), + reduced_tables: Default::default(), + random_alpha: 0, + random_beta: 0, trace_segments: [ - [ + FullTraceShape::new( // Main - InputRegion { - offset: 0, - width: 1, - }, + InputRegion { offset: 0, width: 1 }, // Aux - InputRegion { - offset: 1, - width: 0, - }, + InputRegion { offset: 1, width: 0 }, // Quotient - InputRegion { - offset: 1, - width: 0, - }, - ], - [ - InputRegion { - offset: 1, - width: 1, - }, + InputRegion { offset: 1, width: 0 }, + ), + FullTraceShape::new( + // Main + InputRegion { offset: 1, width: 1 }, // Aux - InputRegion { - offset: 2, - width: 0, - }, + InputRegion { offset: 2, width: 0 }, // Quotient - InputRegion { - offset: 2, - width: 0, - }, - ], + InputRegion { offset: 2, width: 0 }, + ), ], stark_vars: Default::default(), num_inputs: 2, @@ -279,13 +285,6 @@ mod tests { node_l: Node::Operation(1), // id = 2 node_r: one, // id = 5 }, - // Padding with squaring of last operation. - // // id = 0, {[(input + 1) * input] - 1}^2 - // OperationNode { - // op: ArithmeticOp::Mul, // op = 1 - // node_l: Node::Operation(1), // id = 1 - // node_r: Node::Operation(1), // id = 1 - // }, ], }; let encoded = circuit.to_ace(); diff --git a/codegen/ace/src/inputs.rs b/codegen/ace/src/inputs.rs index 355d8626c..ed6e09499 100644 --- a/codegen/ace/src/inputs.rs +++ b/codegen/ace/src/inputs.rs @@ -1,10 +1,14 @@ -use crate::QuadFelt; -use crate::layout::{InputRegion, Layout}; -use air_ir::Air; -use miden_core::Felt; use std::iter::zip; + +use air_ir::{Air, FullTraceShape}; +use miden_core::Felt; use winter_math::{FieldElement, StarkField}; +use crate::{ + QuadFelt, + layout::{InputRegion, Layout}, +}; + /// Set of all inputs required to perform the DEEP-ALI constraint evaluations check. /// Note that these should correspond to all values included in the proof transcript, /// and are ordered as such. 
@@ -14,14 +18,15 @@ pub struct AirInputs { pub log_trace_len: u32, /// Public inputs in the same order as [`Air::public_inputs`]. pub public: Vec>, - /// Evaluations of the *main* trace. - pub main: [Vec; 2], - /// Verifier challenges used to derive the *aux* trace. - pub rand: Vec, - /// Evaluations of the *aux* trace. - pub aux: [Vec; 2], - /// Evaluations of the *quotient* parts, including in the next row. - pub quotient: [Vec; 2], + /// Reduced public input table values used as boundaries for buses. + pub reduced_tables: Vec, + /// Evaluations of the segments in the order `main`, `aux`, `quotient`, + /// for the current row and next row. + pub segments: [FullTraceShape>; 2], + /// Verifier challenge α used to randomize the multi-set/logUp polynomials in the *aux* trace. + pub random_alpha: QuadFelt, + /// Verifier challenge β used to fingerprint bus messages for the *aux* trace. + pub random_beta: QuadFelt, /// Verifier challenge used to compute the linear combination of constraints. pub alpha: QuadFelt, /// Verifier challenge corresponding to the point at which the constraint evaluation check is @@ -33,8 +38,10 @@ pub struct AirInputs { #[derive(Clone, Debug)] pub struct AceVars { pub(crate) public: Vec>, - pub(crate) segments: [[Vec; 3]; 2], - pub(crate) rand: Vec, + pub(crate) reduced_tables: Vec, + pub(crate) segments: [FullTraceShape>; 2], + pub(crate) random_alpha: QuadFelt, + pub(crate) random_beta: QuadFelt, pub(crate) stark: StarkInputs, } @@ -43,17 +50,13 @@ impl AirInputs { /// the values that would be present in the proof's transcript. 
pub fn into_ace_vars(self, air: &Air) -> AceVars { let stark = StarkInputs::new(air, self.log_trace_len, self.alpha, self.z); - let [main_curr, main_next] = self.main; - let [aux_curr, aux_next] = self.aux; - let [quotient_curr, quotient_next] = self.quotient; - let segments = [ - [main_curr, aux_curr, quotient_curr], - [main_next, aux_next, quotient_next], - ]; + let segments = self.segments; AceVars { public: self.public, + reduced_tables: self.reduced_tables, segments, - rand: self.rand, + random_alpha: self.random_alpha, + random_beta: self.random_beta, stark, } } @@ -97,11 +100,7 @@ impl StarkInputs { let n = 1 << log_trace_len; let z_pow_n = z.exp_vartime(n); - let max_cycle_len = air - .periodic_columns - .values() - .map(|col| col.values.len() as u64) - .max(); + let max_cycle_len = air.periodic_columns.values().map(|col| col.values.len() as u64).max(); let z_max_cycle_pow = max_cycle_len.map(|cycle_len| n / cycle_len).unwrap_or(0); let z_max_cycle = z.exp_vartime(z_max_cycle_pow); @@ -118,12 +117,12 @@ impl StarkInputs { /// Returns all values as a `Vec` in the same order as [`crate::StarkVar`]. 
pub(crate) fn to_vec(&self) -> Vec { vec![ - self.gen_penultimate, - self.gen_last, self.alpha, self.z, self.z_pow_n, + self.gen_last, self.z_max_cycle, + self.gen_penultimate, ] } } @@ -145,15 +144,23 @@ impl AceVars { store(&mut mem, pi_region, inputs) } + // Reduced public input table values, ordered by accesses + for (index, reduced_table_value) in + zip(layout.reduced_tables.values(), &self.reduced_tables) + { + let mem_index = layout.reduced_tables_region.index(*index).unwrap(); + mem[mem_index] = *reduced_table_value; + } + // Random values - store(&mut mem, &layout.random_values, &self.rand); + mem[layout.random_alpha] = self.random_alpha; + mem[layout.random_beta] = self.random_beta; // Trace values - for row_offset in [0, 1] { - for (segment_row, region) in zip( - &self.segments[row_offset], - &layout.trace_segments[row_offset], - ) { + for (row_offset, _) in self.segments.iter().enumerate() { + for i in 0..3 { + let region = &layout.trace_segments[row_offset][i]; + let segment_row = &self.segments[row_offset][i]; store(&mut mem, region, segment_row.as_slice()); } } diff --git a/codegen/ace/src/layout.rs b/codegen/ace/src/layout.rs index 89a0180c0..53142b860 100644 --- a/codegen/ace/src/layout.rs +++ b/codegen/ace/src/layout.rs @@ -1,30 +1,56 @@ +use std::{collections::BTreeMap, ops::Range}; + +use air_ir::{ + Air, FullTraceShape, Identifier, PublicInput, PublicInputAccess, PublicInputTableAccess, + TraceAccess, +}; + use crate::circuit::Node; -use air_ir::{Air, Identifier, PublicInputAccess, TraceAccess}; -use std::collections::BTreeMap; -use std::ops::Range; -/// For each set of inputs read from the transcript, we treat them as extension field elements -/// and pad them with zeros to the next multiple of 4. They can then be unhashed to a double-word -/// aligned region in memory. -const HASH_ALIGNMENT: usize = 4; +/// Circuit inputs are represented as extension field elements and stored in a word-aligned region +/// in memory. 
Each region has specific alignment requirements dictated by the recursive verifier. +/// When the aligned region is larger than the actual number of inputs to the circuit, the +/// region can be padded with arbitrary values as these will not be accessed by the circuit. +/// In practice, we set these to zero. +enum Alignment { + Element = 1, + Word = 2, + DoubleWord = 4, + QuadWord = 8, +} const NUM_QUOTIENT_PARTS: usize = 8; /// Describes the layout of inputs given to an ACE circuit. /// Each set of variables is aligned to the next multiple of 4, ensuring they can be efficiently -/// unhashed from the transcript and that each input region is aligned to [`HASH_ALIGNMENT`]. +/// unhashed from the transcript and that each input region is aligned to `HASH_ALIGNMENT`. +/// An exception to this are the `public_inputs` input regions, which are padded to the next +/// multiple of 8. This is because, during recursive verification, we load (fixed) public inputs +/// in groups of 8 (base) field elements which are stored as 8 extension field elements. /// /// We assume the following about the underlying `Air` from which the layout is constructed /// - The proof always contains a `main` and `aux` segment, even when the latter is unused, /// - The maximal degree of an [`Air`] is `9`, such that the quotient can be decomposed in 8 chunks. -/// TODO(Issue: #391): Derive the degree generically. +/// TODO(Issue: #391): Derive the degree generically. #[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct Layout { - /// Region for each set of public inputs, sorted by `Identifier` + /// Region for each set of public inputs, sorted by `Identifier`. + /// The arrays of inputs are laid out contiguously. + /// The last array is padded to ensure this entire region is double-word aligned. pub public_inputs: BTreeMap, - /// Region for auxiliary random inputs. - pub random_values: InputRegion, - /// Regions containing the evaluations of each segment, ordered by `trace[row_offset][segment]`. 
+ /// Region containing the random-reduced public input tables for bus boundary constraints. + /// Each variable is word-aligned (interleaving with unused variables), + /// and the region is double-word aligned. + pub reduced_tables_region: InputRegion, + /// Index of a specific reduced table within the [`Self::reduced_tables_region`]. + pub reduced_tables: BTreeMap, + /// Index of the random challenge α used to randomize the multiset/logUp argument + /// in the *aux* trace. + pub random_alpha: usize, + /// Index of the random challenge β used to randomly reduce/fingerprint bus messages. + pub random_beta: usize, + /// Regions containing the evaluations of each segment, ordered by + /// `trace[row_offset][segment]`. /// /// # Detail: /// Note that we make the following assumptions which do not affect the evaluation of the @@ -34,43 +60,85 @@ pub struct Layout { /// - The [`Air`] from which the layout is derived can only contain a *main* and *aux* trace /// (the latter can be empty). /// - We treat the *quotient* as the third trace, handling it in the same way as the witness - /// traces. This requires the prover to provide the out-of-domain evaluations in - /// the *next* row of each quotient part. These are unused by the circuit. - /// - The rows must be ordered as follows: - /// ```ignore - /// main_curr, aux_curr, quotient_curr, main_next, aux_next, quotient_next. - /// ``` - /// - Each trace must be padded with zero columns such that each row is word-aligned. + /// traces. This requires the prover to provide the out-of-domain evaluations in the *next* + /// row of each quotient part. These are unused by the circuit. + /// - The rows must be ordered as follows: ```ignore main_curr, aux_curr, quotient_curr, + /// main_next, aux_next, quotient_next. ``` + /// - Each segment is double-word aligned to facilitate the Merkle-tree openings during the FRI + /// query phase. In practice, the traces are padded with empty columns. 
/// /// # TODO(Issue #391): /// The degree of the quotient is fixed to 8 matching the degree of the VM constraints, but /// the actual degree can be derived from the [`Air`]. - pub trace_segments: [[InputRegion; 3]; 2], - /// Index of the first auxiliary input describing variables + pub trace_segments: [FullTraceShape; 2], + /// Region containing the [`StarkVar`] variables. pub stark_vars: InputRegion, - /// Total number of inputs + /// Total number of inputs, padded to the next word-multiple. pub num_inputs: usize, } impl Layout { - /// Returns a new [`Layout`] from a description of an [`Air`]. All regions are padded according - /// to [`HASH_ALIGNMENT`], ensuring that each section starts at a word-aligned memory pointer. + /// Returns a new [`Layout`] from a description of an [`Air`]. + /// Each region is aligned according to the requirements of the MASM verifier. pub fn new(air: &Air) -> Self { - let mut inputs_offset = 0; + let offset = &mut 0; - fn next_region(current_offset: &mut usize, width: usize) -> InputRegion { + // Returns an `InputRegion` of a given width and increments the offset + // to satisfy the alignment. + fn allocate_region( + current_offset: &mut usize, + width: usize, + alignment: Alignment, + ) -> InputRegion { let offset = *current_offset; - *current_offset += width.next_multiple_of(HASH_ALIGNMENT); + *current_offset += width.next_multiple_of(alignment as usize); InputRegion { offset, width } } + fn align(offset: &mut usize, alignment: Alignment) { + *offset = offset.next_multiple_of(alignment as usize); + } + + // The arrays of all public inputs are stored contiguously. let public_inputs: BTreeMap<_, _> = air .public_inputs .iter() - .map(|(ident, pi)| (*ident, next_region(&mut inputs_offset, pi.size()))) + .filter_map(|(ident, pi)| { + if let PublicInput::Vector { .. 
} = pi { + Some((*ident, allocate_region(offset, pi.size(), Alignment::Element))) + } else { + None + } + }) + .collect(); + + // Ensure the entire region containing the public inputs is double-double-word aligned + // since it is hashed as one contiguous array and processed in batches of 8 inputs. + align(offset, Alignment::QuadWord); + + // List of all reduced public input table accesses in canonical order. + let reduced_table_accesses = air.reduced_public_input_table_accesses(); + + // Region containing all reduced public input table values. + // For MASM efficiency, we store one reduced table per word. + // Each variable therefore occupies two "variable slots". + let reduced_tables_region = + allocate_region(offset, 2 * reduced_table_accesses.len(), Alignment::Word); + + // Mapping of each access to its index within `reduced_tables_region` + // The index is doubled to match the "one variable per word" requirement. + let reduced_tables: BTreeMap<_, _> = reduced_table_accesses + .into_iter() + .enumerate() + .map(|(index, access)| (access, 2 * index)) .collect(); - let random_values = next_region(&mut inputs_offset, air.num_random_values as usize); + // Random challenges α, β. + let random_alpha = *offset; + let random_beta = *offset + 1; + *offset += 2; + // The next region must be word-aligned to facilitate hashing. + align(offset, Alignment::Word); // TODO(Issue: #391): Use the following to derive the degree generically, and maybe add it // to `Air` @@ -98,17 +166,32 @@ impl Layout { // Quotient is stored as a segment num_quotient_parts, ]; - let trace_segments = [0, 1] - .map(|_row_offset| segment_widths.map(|width| next_region(&mut inputs_offset, width))); - let stark_vars = next_region(&mut inputs_offset, StarkVar::num_vars()); + // Each segment must be double-word aligned to facilitate the opening of rows + // during the FRI query phase. + // At the moment, we do so by padding each trace with zero-valued columns. 
+ let trace_segments = [0, 1].map(|_row_offset| { + FullTraceShape::new( + allocate_region(offset, segment_widths[0], Alignment::DoubleWord), + allocate_region(offset, segment_widths[1], Alignment::DoubleWord), + allocate_region(offset, segment_widths[2], Alignment::DoubleWord), + ) + }); + + let stark_vars = allocate_region(offset, StarkVar::num_vars(), Alignment::Word); + + // Ensure the entire input region is word aligned + align(offset, Alignment::Word); Self { public_inputs, + reduced_tables_region, + reduced_tables, + random_alpha, + random_beta, trace_segments, - random_values, stark_vars, - num_inputs: inputs_offset, + num_inputs: *offset, } } @@ -119,30 +202,34 @@ impl Layout { .and_then(|region| region.as_node(public_input.index)) } + /// Input node associated with a reduced public input table variable. + pub fn reduced_table_node(&self, table_access: &PublicInputTableAccess) -> Option { + self.reduced_tables + .get(table_access) + .and_then(|index| self.reduced_tables_region.as_node(*index)) + } + /// Input node associated with a trace variable. pub fn trace_access_node(&self, trace_access: &TraceAccess) -> Option { - let TraceAccess { - segment, - column, - row_offset, - } = *trace_access; - // We should only be able to access the main and aux segments. - if segment > 1 { - return None; - }; + let TraceAccess { segment, column, row_offset } = *trace_access; let segments_in_row = self.trace_segments.get(row_offset)?; - let segment_region = segments_in_row.get(segment)?; + let segment_region = &segments_in_row[segment]; segment_region.as_node(column) } - /// Input node associated with a random challenge variable. - pub fn random_value_node(&self, index: usize) -> Option { - self.random_values.as_node(index) + /// Input node associated with the variable for the random challenge α. + pub fn random_alpha_node(&self) -> Node { + Node::Input(self.random_alpha) + } + + /// Input node associated with the variable for the random challenge β. 
+ pub fn random_beta_node(&self) -> Node { + Node::Input(self.random_beta) } /// Input nodes associated with the quotient polynomial coefficients. pub fn quotient_nodes(&self) -> Vec { - self.trace_segments[0][2].iter_nodes().collect() + self.trace_segments[0].quotient.iter_nodes().collect() } /// Input node associated with an auxiliary STARK challenge/variable. @@ -185,21 +272,22 @@ impl InputRegion { /// List of STARK variables and challenges, derived from the public parameters and proof transcript. #[derive(Copy, Clone, Debug)] pub enum StarkVar { - /// The variable g⁻² corresponding to the penultimate point in the subgroup over which the - /// trace is interpolated. - GenPenultimate = 0, - /// The variable g⁻¹ corresponding to the last point in the subgroup over which the trace is - /// interpolated. - GenLast = 1, /// The variable α used as for random linear-combination of constraints. - Alpha = 2, + Alpha = 0, /// The variable z at which the constraints evaluation check is performed. - Z = 3, + Z = 1, /// The variable zⁿ, where `n = trace_len` - ZPowN = 4, + ZPowN = 2, + /// The variable g⁻¹ corresponding to the last point in the subgroup over which the trace is + /// interpolated. + GenLast = 3, /// The variable `zᵐᵃˣ`, where `max` is equal to `trace_len / max_cycle_len`. Details can be - /// found in [`crate::builder::CircuitBuilder::periodic_column`] - ZMaxCycle = 5, + /// found in `CircuitBuilder::periodic_column`. + // TODO: Make this method public or fix the link to the correct location + ZMaxCycle = 4, + /// The variable g⁻² corresponding to the penultimate point in the subgroup over which the + /// trace is interpolated. 
+ GenPenultimate = 5, } impl StarkVar { @@ -213,12 +301,12 @@ impl TryFrom for StarkVar { fn try_from(value: usize) -> Result { match value { - 0 => Ok(Self::GenPenultimate), - 1 => Ok(Self::GenLast), - 2 => Ok(Self::Alpha), - 3 => Ok(Self::Z), - 4 => Ok(Self::ZPowN), - 5 => Ok(Self::ZMaxCycle), + 0 => Ok(Self::Alpha), + 1 => Ok(Self::Z), + 2 => Ok(Self::ZPowN), + 3 => Ok(Self::GenLast), + 4 => Ok(Self::ZMaxCycle), + 5 => Ok(Self::GenPenultimate), _ => Err(value), } } diff --git a/codegen/ace/src/lib.rs b/codegen/ace/src/lib.rs index 4fe8927c4..7a6d96cc4 100644 --- a/codegen/ace/src/lib.rs +++ b/codegen/ace/src/lib.rs @@ -7,24 +7,24 @@ mod layout; #[cfg(test)] mod tests; -pub use crate::circuit::{Circuit as AceCircuit, Node as AceNode}; -pub use crate::encoded::EncodedCircuit as EncodedAceCircuit; -pub use crate::inputs::{AceVars, AirInputs}; -pub use crate::layout::Layout as AirLayout; +use air_ir::{Air, ConstraintDomain, TraceSegmentId}; +pub use mir::ir::QuadFelt; use crate::builder::{CircuitBuilder, LinearCombination}; -use crate::layout::StarkVar; -use air_ir::{Air, ConstraintDomain}; -use miden_core::{Felt, QuadExtension}; - -type QuadFelt = QuadExtension; +pub use crate::{ + circuit::{Circuit as AceCircuit, Node as AceNode}, + encoded::EncodedCircuit as EncodedAceCircuit, + inputs::{AceVars, AirInputs}, + layout::{Layout as AirLayout, StarkVar}, +}; /// Air constraints are organized in 3 main groups: integrity roots, /// boundary-first roots and boundary-last roots. 
/// The roots in each group are linearly combined with powers of a random challenge `α`: /// - `int = ∑ᵢ int_roots[i]⋅αⁱ` for `i` from 0 to `|int_roots|` /// - `bf = ∑ᵢ bf_roots[i]⋅αⁱ⁺ᵒ` for `i` from 0 to `|int_roots|`, and `o = |int_roots|` -/// - `bl = ∑ᵢ_i bl_roots[i]⋅αⁱ⁺ᵒ` for `i` from 0 to `|bf_roots|` and `o = |int_roots| + |bf_roots|` +/// - `bl = ∑ᵢ_i bl_roots[i]⋅αⁱ⁺ᵒ` for `i` from 0 to `|bf_roots|` and `o = |int_roots| + +/// |bf_roots|` /// /// This function builds a circuit that computes the following formula which must evaluate to zero: /// `z₋₂²⋅z₋₁⋅z₀⋅int + zₙ⋅z₋₂⋅bf + zₙ⋅z₀⋅bl - Q(z)⋅zₙ⋅z₀⋅z₋₂`. The variables are given by: @@ -36,12 +36,12 @@ type QuadFelt = QuadExtension; /// - `n` is the length of the trace. /// /// This is equivalent to the check -/// ```ignore +/// ```text /// num₀/[(zⁿ - 1)/[(z - g⁻¹)(z - g⁻²)]] + num₁/(z - 1) + num₂/(z - g⁻²) = Q(z) /// ``` /// /// The ACE chiplet expects the inputs of the original AirScript, with the order defined by -/// [`AceLayout`]: +/// `AceLayout`: /// - the public inputs of the AirScript e.g. `public_inputs { stack_inputs[16] }`, /// - auxiliary randomness of the AirScript e.g. `random_values { rand: [2] }`, /// - the main segment of trace inputs of the AirScript e.g. `trace_columns { main: [a b] }`, @@ -52,12 +52,13 @@ type QuadFelt = QuadExtension; /// - a dummy section of 8 quotient evaluation for the next row, unused by the ACE circuit. /// /// Additionally, the ACE chiplet expects the following 5 auxiliary "STARK" inputs, whose order -/// is defined by [`StarkVar`], given by `[g⁻¹, g⁻¹, α, z, zⁿ, zᵐᵃˣ`]. +/// is defined by `StarkVar`, given by `[α, z, zⁿ, g⁻¹, zᵐᵃˣ, g⁻²]`. 
pub fn build_ace_circuit(air: &Air) -> anyhow::Result<(AceNode, AceCircuit)> { - // A circuit builder is instantiated with the inputs of the circuits plus the 13 needed by the ACE chiplet + // A circuit builder is instantiated with the inputs of the circuits plus the 13 needed by the + // ACE chiplet let mut cb = CircuitBuilder::new(air); - let segments = [0, 1]; + let segments = [TraceSegmentId::Main, TraceSegmentId::Aux]; let integrity_roots: Vec<_> = segments .iter() .flat_map(|&seg| air.integrity_constraints(seg)) @@ -66,11 +67,11 @@ pub fn build_ace_circuit(air: &Air) -> anyhow::Result<(AceNode, AceCircuit)> { // all-row constraints ConstraintDomain::EveryRow | ConstraintDomain::EveryFrame(2) => { Some(cb.node_from_index(air, constraint.node_index())) - } + }, ConstraintDomain::FirstRow | ConstraintDomain::LastRow => None, ConstraintDomain::EveryFrame(_) => { panic!("invalid integrity constraint domain") - } + }, }) .collect(); @@ -113,13 +114,7 @@ pub fn build_ace_circuit(air: &Air) -> anyhow::Result<(AceNode, AceCircuit)> { // z₋₂²⋅z₋₁⋅z₀⋅int { let int = lc.next_linear_combination(&mut cb, integrity_roots); - let res = cb.prod([ - vanish_first, - vanish_penultimate, - vanish_last, - vanish_penultimate, - int, - ]); + let res = cb.prod([vanish_first, vanish_penultimate, vanish_last, vanish_penultimate, int]); lhs = cb.add(lhs, res); }; diff --git a/codegen/ace/src/tests/mod.rs b/codegen/ace/src/tests/mod.rs index 02f5de796..92d4a6078 100644 --- a/codegen/ace/src/tests/mod.rs +++ b/codegen/ace/src/tests/mod.rs @@ -1,33 +1,28 @@ -use crate::circuit::{Circuit, Node}; -use crate::{AceVars, QuadFelt, build_ace_circuit}; -use air_ir::Air; -use miden_diagnostics::term::termcolor::ColorChoice; -use miden_diagnostics::{CodeMap, DefaultEmitter, DiagnosticsHandler}; use std::sync::Arc; + +use air_ir::{Air, compile}; +use miden_diagnostics::{ + CodeMap, DefaultEmitter, DiagnosticsHandler, term::termcolor::ColorChoice, +}; use winter_math::FieldElement; +use crate::{ + 
AceVars, QuadFelt, build_ace_circuit, + circuit::{Circuit, Node}, +}; + mod quotient; mod random; /// Generates an ACE circuit and its root index from an AirScript program. pub fn generate_circuit(source: &str) -> (Air, Circuit, Node) { - use air_pass::Pass; - let code_map = Arc::new(CodeMap::new()); let emitter = Arc::new(DefaultEmitter::new(ColorChoice::Auto)); let diagnostics = DiagnosticsHandler::new(Default::default(), code_map.clone(), emitter); let air = air_parser::parse(&diagnostics, code_map, source) .map_err(air_ir::CompileError::Parse) - .and_then(|ast| { - let mut pipeline = air_parser::transforms::ConstantPropagation::new(&diagnostics) - .chain(mir::passes::AstToMir::new(&diagnostics)) - .chain(mir::passes::Inlining::new(&diagnostics)) - .chain(mir::passes::Unrolling::new(&diagnostics)) - .chain(air_ir::passes::MirToAir::new(&diagnostics)) - .chain(air_ir::passes::BusOpExpand::new(&diagnostics)); - pipeline.run(ast) - }) + .and_then(|program| compile(&diagnostics, program)) .expect("lowering failed"); let (root, circuit) = build_ace_circuit(&air).expect("codegen failed"); diff --git a/codegen/ace/src/tests/quotient.rs b/codegen/ace/src/tests/quotient.rs index 815f6eed0..3dbc99b73 100644 --- a/codegen/ace/src/tests/quotient.rs +++ b/codegen/ace/src/tests/quotient.rs @@ -1,10 +1,14 @@ -use crate::QuadFelt; -use crate::inputs::{AceVars, StarkInputs}; -use air_ir::{Air, ConstraintDomain, NodeIndex, Operation, Value}; -use miden_core::Felt; use std::collections::BTreeMap; + +use air_ir::{Air, ConstraintDomain, NodeIndex, Operation, TraceSegmentId, Value}; +use miden_core::Felt; use winter_math::FieldElement; +use crate::{ + QuadFelt, + inputs::{AceVars, StarkInputs}, +}; + /// Evaluates the quotient polynomial of the Air. 
pub fn eval_quotient(air: &Air, ace_vars: &AceVars, log_trace_len: u32) -> QuadFelt { let StarkInputs { @@ -27,27 +31,25 @@ pub fn eval_quotient(air: &Air, ace_vars: &AceVars, log_trace_len: u32) -> QuadF let z_col_pow = trace_len / col.values.len(); let z_col = z.exp_vartime(z_col_pow as u64); - let mut poly: Vec<_> = col - .values - .iter() - .copied() - .map(Felt::new) - .map(QuadFelt::from) - .collect(); + let mut poly: Vec<_> = + col.values.iter().copied().map(Felt::new).map(QuadFelt::from).collect(); let twiddles = winter_math::fft::get_inv_twiddles::(poly.len()); winter_math::fft::interpolate_poly(&mut poly, &twiddles); let eval = poly_eval(&poly, z_col); - (*ident, QuadFelt::from(eval)) + (ident.clone(), QuadFelt::from(eval)) }) .collect(); // Map public inputs from identifier to index matching the AirLayout format - let public: BTreeMap<_, _> = air - .public_inputs - .keys() + let public: BTreeMap<_, _> = + air.public_inputs.keys().enumerate().map(|(i, ident)| (*ident, i)).collect(); + + let reduced_tables: BTreeMap<_, _> = air + .reduced_public_input_table_accesses() + .into_iter() .enumerate() - .map(|(i, ident)| (*ident, i)) + .map(|(i, access)| (access, i)) .collect(); // Prepare a vector containing evaluations of all nodes in the Air graph. 
@@ -60,18 +62,29 @@ pub fn eval_quotient(air: &Air, ace_vars: &AceVars, log_trace_len: u32) -> QuadF for node_idx in 0..num_nodes { let node: NodeIndex = node_idx.into(); let op = graph.node(&node).op(); - let eval = match *op { + let eval = match op.clone() { Operation::Value(v) => match v { Value::Constant(c) => QuadFelt::from(Felt::new(c)), Value::TraceAccess(access) => { ace_vars.segments[access.row_offset][access.segment][access.column] - } + }, Value::PeriodicColumn(access) => periodic[&access.name], Value::PublicInput(access) => { let idx = public[&access.name]; ace_vars.public[idx][access.index] - } - Value::RandomValue(idx) => ace_vars.rand[idx], + }, + Value::PublicInputTable(access) => { + let idx = reduced_tables[&access]; + ace_vars.reduced_tables[idx] + }, + Value::RandomValue(idx) => { + if idx == 0 { + ace_vars.random_alpha + } else { + let beta_power = idx - 1; + ace_vars.random_beta.exp_vartime(beta_power as u64) + } + }, }, Operation::Add(l, r) => evals[usize::from(l)] + evals[usize::from(r)], Operation::Sub(l, r) => evals[usize::from(l)] - evals[usize::from(r)], @@ -81,32 +94,26 @@ pub fn eval_quotient(air: &Air, ace_vars: &AceVars, log_trace_len: u32) -> QuadF } // Iterator for all powers of alpha - let mut alpha_pow_iter = std::iter::successors(Some(QuadFelt::ONE), move |alpha_prev| { - Some(*alpha_prev * alpha) - }); + let mut alpha_pow_iter = + std::iter::successors(Some(QuadFelt::ONE), move |alpha_prev| Some(*alpha_prev * alpha)); // Evaluate linear-combination of integrity constraints. 
- let integrity: QuadFelt = [0, 1] + let integrity: QuadFelt = [TraceSegmentId::Main, TraceSegmentId::Aux] .into_iter() .flat_map(|segment| { - air.constraints - .integrity_constraints(segment) - .iter() - .map(|c| { - // TODO(Issue #392): Technically we should separate the transition from - // all-row constraints - // assert_eq!(c.domain(), ConstraintDomain::EveryFrame(2)); - let idx = usize::from(*c.node_index()); - evals[idx] - }) + air.constraints.integrity_constraints(segment).iter().map(|c| { + // TODO(Issue #392): Technically we should separate the transition from + // all-row constraints + // assert_eq!(c.domain(), ConstraintDomain::EveryFrame(2)); + let idx = usize::from(*c.node_index()); + evals[idx] + }) }) .zip(alpha_pow_iter.by_ref()) - .fold(QuadFelt::ZERO, |acc, (eval, alpha_pow)| { - acc + eval * alpha_pow - }); + .fold(QuadFelt::ZERO, |acc, (eval, alpha_pow)| acc + eval * alpha_pow); // Evaluate linear-combination of integrity constraints for the first row - let boundary_first = [0, 1] + let boundary_first = [TraceSegmentId::Main, TraceSegmentId::Aux] .into_iter() .flat_map(|segment| { air.constraints @@ -119,12 +126,10 @@ pub fn eval_quotient(air: &Air, ace_vars: &AceVars, log_trace_len: u32) -> QuadF }) }) .zip(alpha_pow_iter.by_ref()) - .fold(QuadFelt::ZERO, |acc, (eval, alpha_pow)| { - acc + eval * alpha_pow - }); + .fold(QuadFelt::ZERO, |acc, (eval, alpha_pow)| acc + eval * alpha_pow); // Evaluate linear-combination of integrity constraints for the last row - let boundary_last = [0, 1] + let boundary_last = [TraceSegmentId::Main, TraceSegmentId::Aux] .into_iter() .flat_map(|segment| { air.constraints @@ -137,9 +142,7 @@ pub fn eval_quotient(air: &Air, ace_vars: &AceVars, log_trace_len: u32) -> QuadF }) }) .zip(alpha_pow_iter.by_ref()) - .fold(QuadFelt::ZERO, |acc, (eval, alpha_pow)| { - acc + eval * alpha_pow - }); + .fold(QuadFelt::ZERO, |acc, (eval, alpha_pow)| acc + eval * alpha_pow); // z-1 = z − g⁰ let vanishing_first = z - QuadFelt::ONE; 
diff --git a/codegen/ace/src/tests/random.rs b/codegen/ace/src/tests/random.rs index b3306fc67..c5c18da7b 100644 --- a/codegen/ace/src/tests/random.rs +++ b/codegen/ace/src/tests/random.rs @@ -1,11 +1,14 @@ -use crate::inputs::StarkInputs; -use crate::layout::{InputRegion, Layout}; -use crate::tests::quotient::{eval_quotient, poly_eval}; -use crate::{AceVars, QuadFelt}; -use air_ir::Air; +use air_ir::{Air, FullTraceShape}; use rand::Rng; use winter_utils::Randomizable; +use crate::{ + AceVars, QuadFelt, + inputs::StarkInputs, + layout::{InputRegion, Layout}, + tests::quotient::{eval_quotient, poly_eval}, +}; + impl InputRegion { /// Generates a random list of values that fit in this region. pub fn random(&self) -> Vec { @@ -17,20 +20,22 @@ impl AceVars { /// Samples fully random inputs for the ACE circuit. pub fn random(air: &Air, log_trace_len: u32) -> Self { let layout = Layout::new(air); - let public = layout - .public_inputs - .values() - .map(|pi| pi.random()) - .collect(); - let segments = layout - .trace_segments - .map(|segment_row| segment_row.map(|row_region| row_region.random())); - let rand = layout.random_values.random(); + let public = layout.public_inputs.values().map(|pi| pi.random()).collect(); + let reduced_tables = layout.reduced_tables_region.random(); + let segments = layout.trace_segments.map(|segment_row| { + FullTraceShape::new( + segment_row.segments.main.random(), + segment_row.segments.aux.random(), + segment_row.quotient.random(), + ) + }); let stark = StarkInputs::random(air, log_trace_len); Self { public, + reduced_tables, segments, - rand, + random_alpha: rand_quad(), + random_beta: rand_quad(), stark, } } diff --git a/codegen/ace/tests/regressions/Buses.dot b/codegen/ace/tests/regressions/Buses.dot index 7768ce5c4..687cdd632 100644 --- a/codegen/ace/tests/regressions/Buses.dot +++ b/codegen/ace/tests/regressions/Buses.dot @@ -2,169 +2,166 @@ digraph G { const0 [label="0"] const1 [label="1"] input0 [label="PI[stack_inputs][0]"] 
-input4 [label="R[0]"] -input5 [label="R[1]"] -input8 [label="M[0]"] -input9 [label="M[1]"] -input24 [label="M'[0]"] -input25 [label="M'[1]"] -input12 [label="A[0]"] -input28 [label="A'[0]"] -input12 [label="Q[0]"] -input40 [label="g⁻²"] -input41 [label="g⁻¹"] +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input11 [label="M[1]"] +input26 [label="M'[0]"] +input27 [label="M'[1]"] +input14 [label="A[0]"] +input30 [label="A'[0]"] +input14 [label="Q[0]"] input42 [label="⍺"] input43 [label="z"] input44 [label="zⁿ"] -input45 [label="zᵐᵃˣ"] -op0 [label="op0\ninput24 - input8"] -input24 -> op0 -input8 -> op0 -op1 [label="op1\ninput5 × input8"] -input5 -> op1 +input45 [label="g⁻¹"] +input46 [label="zᵐᵃˣ"] +input47 [label="g⁻²"] +op0 [label="op0\ninput26 - input10"] +input26 -> op0 +input10 -> op0 +op1 [label="op1\ninput8 + input10"] input8 -> op1 -op2 [label="op2\ninput4 + op1"] -input4 -> op2 +input10 -> op1 +op2 [label="op2\ninput11 × op1"] +input11 -> op2 op1 -> op2 -op3 [label="op3\ninput9 × op2"] -input9 -> op3 -op2 -> op3 -op4 [label="op4\nconst1 - input9"] -const1 -> op4 -input9 -> op4 -op5 [label="op5\nop3 + op4"] -op3 -> op5 +op3 [label="op3\nconst1 - input11"] +const1 -> op3 +input11 -> op3 +op4 [label="op4\nop2 + op3"] +op2 -> op4 +op3 -> op4 +op5 [label="op5\ninput14 × op4"] +input14 -> op5 op4 -> op5 -op6 [label="op6\ninput12 × op5"] -input12 -> op6 +op6 [label="op6\nop5 - input30"] op5 -> op6 -op7 [label="op7\nop6 - input28"] -op6 -> op7 -input28 -> op7 -op8 [label="op8\ninput8 - const1"] -input8 -> op8 +input30 -> op6 +op7 [label="op7\ninput10 - const1"] +input10 -> op7 +const1 -> op7 +op8 [label="op8\ninput14 - const1"] +input14 -> op8 const1 -> op8 -op9 [label="op9\ninput12 - const1"] -input12 -> op9 +op9 [label="op9\ninput43 - const1"] +input43 -> op9 const1 -> op9 -op10 [label="op10\ninput43 - const1"] +op10 [label="op10\ninput43 - input47"] input43 -> op10 -const1 -> op10 -op11 [label="op11\ninput43 - input40"] +input47 -> op10 +op11 
[label="op11\ninput43 - input45"] input43 -> op11 -input40 -> op11 -op12 [label="op12\ninput43 - input41"] -input43 -> op12 -input41 -> op12 -op13 [label="op13\ninput44 - const1"] -input44 -> op13 -const1 -> op13 -op14 [label="op14\ninput42 × op7"] -input42 -> op14 -op7 -> op14 -op15 [label="op15\nop0 + op14"] -op0 -> op15 -op14 -> op15 -op16 [label="op16\nop10 × op11"] -op10 -> op16 +input45 -> op11 +op12 [label="op12\ninput44 - const1"] +input44 -> op12 +const1 -> op12 +op13 [label="op13\ninput42 × op6"] +input42 -> op13 +op6 -> op13 +op14 [label="op14\nop0 + op13"] +op0 -> op14 +op13 -> op14 +op15 [label="op15\nop9 × op10"] +op9 -> op15 +op10 -> op15 +op16 [label="op16\nop11 × op15"] op11 -> op16 -op17 [label="op17\nop12 × op16"] -op12 -> op17 +op15 -> op16 +op17 [label="op17\nop10 × op16"] +op10 -> op17 op16 -> op17 -op18 [label="op18\nop11 × op17"] -op11 -> op18 +op18 [label="op18\nop14 × op17"] +op14 -> op18 op17 -> op18 -op19 [label="op19\nop15 × op18"] -op15 -> op19 -op18 -> op19 -op20 [label="op20\ninput42 × input42"] -input42 -> op20 -input42 -> op20 -op21 [label="op21\nop8 × op20"] -op8 -> op21 -op20 -> op21 -op22 [label="op22\ninput42 × op20"] -input42 -> op22 -op20 -> op22 -op23 [label="op23\nop9 × op22"] -op9 -> op23 +op19 [label="op19\ninput42 × input42"] +input42 -> op19 +input42 -> op19 +op20 [label="op20\nop7 × op19"] +op7 -> op20 +op19 -> op20 +op21 [label="op21\ninput42 × op19"] +input42 -> op21 +op19 -> op21 +op22 [label="op22\nop8 × op21"] +op8 -> op22 +op21 -> op22 +op23 [label="op23\nop20 + op22"] +op20 -> op23 op22 -> op23 -op24 [label="op24\nop21 + op23"] -op21 -> op24 -op23 -> op24 -op25 [label="op25\nop11 × op13"] -op11 -> op25 -op13 -> op25 -op26 [label="op26\nop24 × op25"] -op24 -> op26 +op24 [label="op24\nop10 × op12"] +op10 -> op24 +op12 -> op24 +op25 [label="op25\nop23 × op24"] +op23 -> op25 +op24 -> op25 +op26 [label="op26\nop18 + op25"] +op18 -> op26 op25 -> op26 -op27 [label="op27\nop19 + op26"] -op19 -> op27 -op26 -> op27 -op28 
[label="op28\ninput42 × op22"] -input42 -> op28 -op22 -> op28 -op29 [label="op29\nop9 × op28"] +op27 [label="op27\ninput42 × op21"] +input42 -> op27 +op21 -> op27 +op28 [label="op28\nop8 × op27"] +op8 -> op28 +op27 -> op28 +op29 [label="op29\nop9 × op12"] op9 -> op29 -op28 -> op29 -op30 [label="op30\nop10 × op13"] -op10 -> op30 -op13 -> op30 -op31 [label="op31\nop29 × op30"] -op29 -> op31 +op12 -> op29 +op30 [label="op30\nop28 × op29"] +op28 -> op30 +op29 -> op30 +op31 [label="op31\nop26 + op30"] +op26 -> op31 op30 -> op31 -op32 [label="op32\nop27 + op31"] -op27 -> op32 -op31 -> op32 -op33 [label="op33\ninput23 × input44"] -input23 -> op33 -input44 -> op33 -op34 [label="op34\ninput22 + op33"] -input22 -> op34 +op32 [label="op32\ninput25 × input44"] +input25 -> op32 +input44 -> op32 +op33 [label="op33\ninput24 + op32"] +input24 -> op33 +op32 -> op33 +op34 [label="op34\ninput44 × op33"] +input44 -> op34 op33 -> op34 -op35 [label="op35\ninput44 × op34"] -input44 -> op35 +op35 [label="op35\ninput23 + op34"] +input23 -> op35 op34 -> op35 -op36 [label="op36\ninput21 + op35"] -input21 -> op36 +op36 [label="op36\ninput44 × op35"] +input44 -> op36 op35 -> op36 -op37 [label="op37\ninput44 × op36"] -input44 -> op37 +op37 [label="op37\ninput22 + op36"] +input22 -> op37 op36 -> op37 -op38 [label="op38\ninput20 + op37"] -input20 -> op38 +op38 [label="op38\ninput44 × op37"] +input44 -> op38 op37 -> op38 -op39 [label="op39\ninput44 × op38"] -input44 -> op39 +op39 [label="op39\ninput21 + op38"] +input21 -> op39 op38 -> op39 -op40 [label="op40\ninput19 + op39"] -input19 -> op40 +op40 [label="op40\ninput44 × op39"] +input44 -> op40 op39 -> op40 -op41 [label="op41\ninput44 × op40"] -input44 -> op41 +op41 [label="op41\ninput20 + op40"] +input20 -> op41 op40 -> op41 -op42 [label="op42\ninput18 + op41"] -input18 -> op42 +op42 [label="op42\ninput44 × op41"] +input44 -> op42 op41 -> op42 -op43 [label="op43\ninput44 × op42"] -input44 -> op43 +op43 [label="op43\ninput19 + op42"] +input19 -> 
op43 op42 -> op43 -op44 [label="op44\ninput17 + op43"] -input17 -> op44 +op44 [label="op44\ninput44 × op43"] +input44 -> op44 op43 -> op44 -op45 [label="op45\ninput44 × op44"] -input44 -> op45 +op45 [label="op45\ninput18 + op44"] +input18 -> op45 op44 -> op45 -op46 [label="op46\ninput16 + op45"] -input16 -> op46 -op45 -> op46 -op47 [label="op47\nop13 × op16"] -op13 -> op47 -op16 -> op47 -op48 [label="op48\nop46 × op47"] -op46 -> op48 +op46 [label="op46\nop12 × op15"] +op12 -> op46 +op15 -> op46 +op47 [label="op47\nop45 × op46"] +op45 -> op47 +op46 -> op47 +op48 [label="op48\nop31 - op47"] +op31 -> op48 op47 -> op48 -op49 [label="op49\nop32 - op48"] -op32 -> op49 -op48 -> op49 } diff --git a/codegen/ace/tests/regressions/ComplexBoundary.dot b/codegen/ace/tests/regressions/ComplexBoundary.dot index fb63f1e59..45234a9e3 100644 --- a/codegen/ace/tests/regressions/ComplexBoundary.dot +++ b/codegen/ace/tests/regressions/ComplexBoundary.dot @@ -3,194 +3,196 @@ const0 [label="0"] const1 [label="1"] input0 [label="PI[stack_inputs][0]"] input1 [label="PI[stack_inputs][1]"] -input4 [label="PI[stack_outputs][0]"] -input5 [label="PI[stack_outputs][1]"] -input8 [label="M[0]"] -input9 [label="M[1]"] -input10 [label="M[2]"] -input11 [label="M[3]"] -input12 [label="M[4]"] -input13 [label="M[5]"] -input28 [label="M'[0]"] -input29 [label="M'[1]"] -input30 [label="M'[2]"] -input31 [label="M'[3]"] -input32 [label="M'[4]"] -input33 [label="M'[5]"] -input16 [label="A[0]"] -input36 [label="A'[0]"] -input16 [label="Q[0]"] -input48 [label="g⁻²"] -input49 [label="g⁻¹"] +input2 [label="PI[stack_outputs][0]"] +input3 [label="PI[stack_outputs][1]"] +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input11 [label="M[1]"] +input12 [label="M[2]"] +input13 [label="M[3]"] +input14 [label="M[4]"] +input15 [label="M[5]"] +input30 [label="M'[0]"] +input31 [label="M'[1]"] +input32 [label="M'[2]"] +input33 [label="M'[3]"] +input34 [label="M'[4]"] +input35 [label="M'[5]"] +input18 
[label="A[0]"] +input38 [label="A'[0]"] +input18 [label="Q[0]"] input50 [label="⍺"] input51 [label="z"] input52 [label="zⁿ"] -input53 [label="zᵐᵃˣ"] -op0 [label="op0\ninput8 + input9"] -input8 -> op0 -input9 -> op0 -op1 [label="op1\ninput16 - input36"] -input16 -> op1 -input36 -> op1 -op2 [label="op2\ninput8 - input0"] -input8 -> op2 -input0 -> op2 -op3 [label="op3\ninput9 - input1"] -input9 -> op3 -input1 -> op3 -op4 [label="op4\ninput10 - const1"] -input10 -> op4 +input53 [label="g⁻¹"] +input54 [label="zᵐᵃˣ"] +input55 [label="g⁻²"] +op0 [label="op0\ninput10 + input11"] +input10 -> op0 +input11 -> op0 +op1 [label="op1\ninput10 - input0"] +input10 -> op1 +input0 -> op1 +op2 [label="op2\ninput11 - input1"] +input11 -> op2 +input1 -> op2 +op3 [label="op3\ninput12 - const1"] +input12 -> op3 +const1 -> op3 +op4 [label="op4\ninput13 - const1"] +input13 -> op4 const1 -> op4 -op5 [label="op5\ninput11 - const1"] -input11 -> op5 +op5 [label="op5\ninput15 - const1"] +input15 -> op5 const1 -> op5 -op6 [label="op6\ninput13 - const1"] -input13 -> op6 +op6 [label="op6\ninput18 - const1"] +input18 -> op6 const1 -> op6 -op7 [label="op7\ninput16 - const1"] -input16 -> op7 -const1 -> op7 -op8 [label="op8\ninput8 - input4"] -input8 -> op8 -input4 -> op8 -op9 [label="op9\ninput9 - input5"] -input9 -> op9 -input5 -> op9 -op10 [label="op10\ninput51 - const1"] +op7 [label="op7\ninput10 - input2"] +input10 -> op7 +input2 -> op7 +op8 [label="op8\ninput11 - input3"] +input11 -> op8 +input3 -> op8 +op9 [label="op9\ninput51 - const1"] +input51 -> op9 +const1 -> op9 +op10 [label="op10\ninput51 - input55"] input51 -> op10 -const1 -> op10 -op11 [label="op11\ninput51 - input48"] +input55 -> op10 +op11 [label="op11\ninput51 - input53"] input51 -> op11 -input48 -> op11 -op12 [label="op12\ninput51 - input49"] -input51 -> op12 -input49 -> op12 -op13 [label="op13\ninput52 - const1"] -input52 -> op13 -const1 -> op13 -op14 [label="op14\ninput50 × op1"] -input50 -> op14 -op1 -> op14 -op15 
[label="op15\nop0 + op14"] -op0 -> op15 +input53 -> op11 +op12 [label="op12\ninput52 - const1"] +input52 -> op12 +const1 -> op12 +op13 [label="op13\nop9 × op10"] +op9 -> op13 +op10 -> op13 +op14 [label="op14\nop11 × op13"] +op11 -> op14 +op13 -> op14 +op15 [label="op15\nop10 × op14"] +op10 -> op15 op14 -> op15 -op16 [label="op16\nop10 × op11"] -op10 -> op16 -op11 -> op16 -op17 [label="op17\nop12 × op16"] -op12 -> op17 -op16 -> op17 -op18 [label="op18\nop11 × op17"] -op11 -> op18 -op17 -> op18 -op19 [label="op19\nop15 × op18"] -op15 -> op19 +op16 [label="op16\nop0 × op15"] +op0 -> op16 +op15 -> op16 +op17 [label="op17\ninput50 × op1"] +input50 -> op17 +op1 -> op17 +op18 [label="op18\ninput50 × input50"] +input50 -> op18 +input50 -> op18 +op19 [label="op19\nop2 × op18"] +op2 -> op19 op18 -> op19 -op20 [label="op20\ninput50 × input50"] -input50 -> op20 -input50 -> op20 -op21 [label="op21\nop2 × op20"] -op2 -> op21 -op20 -> op21 -op22 [label="op22\ninput50 × op20"] -input50 -> op22 -op20 -> op22 -op23 [label="op23\nop3 × op22"] -op3 -> op23 +op20 [label="op20\nop17 + op19"] +op17 -> op20 +op19 -> op20 +op21 [label="op21\ninput50 × op18"] +input50 -> op21 +op18 -> op21 +op22 [label="op22\nop3 × op21"] +op3 -> op22 +op21 -> op22 +op23 [label="op23\nop20 + op22"] +op20 -> op23 op22 -> op23 -op24 [label="op24\nop21 + op23"] +op24 [label="op24\ninput50 × op21"] +input50 -> op24 op21 -> op24 -op23 -> op24 -op25 [label="op25\ninput50 × op22"] -input50 -> op25 -op22 -> op25 -op26 [label="op26\nop4 × op25"] -op4 -> op26 +op25 [label="op25\nop4 × op24"] +op4 -> op25 +op24 -> op25 +op26 [label="op26\nop23 + op25"] +op23 -> op26 op25 -> op26 -op27 [label="op27\nop24 + op26"] +op27 [label="op27\ninput50 × op24"] +input50 -> op27 op24 -> op27 -op26 -> op27 -op28 [label="op28\ninput50 × op25"] -input50 -> op28 -op25 -> op28 -op29 [label="op29\nop5 × op28"] -op5 -> op29 +op28 [label="op28\ninput14 × op27"] +input14 -> op28 +op27 -> op28 +op29 [label="op29\nop26 + op28"] +op26 -> op29 
op28 -> op29 -op30 [label="op30\nop27 + op29"] +op30 [label="op30\ninput50 × op27"] +input50 -> op30 op27 -> op30 -op29 -> op30 -op31 [label="op31\ninput50 × op28"] -input50 -> op31 -op28 -> op31 -op32 [label="op32\ninput12 × op31"] -input12 -> op32 +op31 [label="op31\nop5 × op30"] +op5 -> op31 +op30 -> op31 +op32 [label="op32\nop29 + op31"] +op29 -> op32 op31 -> op32 -op33 [label="op33\nop30 + op32"] +op33 [label="op33\ninput50 × op30"] +input50 -> op33 op30 -> op33 -op32 -> op33 -op34 [label="op34\ninput50 × op31"] -input50 -> op34 -op31 -> op34 -op35 [label="op35\nop6 × op34"] -op6 -> op35 +op34 [label="op34\nop6 × op33"] +op6 -> op34 +op33 -> op34 +op35 [label="op35\nop32 + op34"] +op32 -> op35 op34 -> op35 -op36 [label="op36\nop33 + op35"] -op33 -> op36 -op35 -> op36 -op37 [label="op37\ninput50 × op34"] -input50 -> op37 -op34 -> op37 -op38 [label="op38\nop7 × op37"] -op7 -> op38 +op36 [label="op36\nop10 × op12"] +op10 -> op36 +op12 -> op36 +op37 [label="op37\nop35 × op36"] +op35 -> op37 +op36 -> op37 +op38 [label="op38\nop16 + op37"] +op16 -> op38 op37 -> op38 -op39 [label="op39\nop36 + op38"] -op36 -> op39 -op38 -> op39 -op40 [label="op40\nop11 × op13"] -op11 -> op40 -op13 -> op40 -op41 [label="op41\nop39 × op40"] +op39 [label="op39\ninput50 × op33"] +input50 -> op39 +op33 -> op39 +op40 [label="op40\nop7 × op39"] +op7 -> op40 +op39 -> op40 +op41 [label="op41\ninput50 × op39"] +input50 -> op41 op39 -> op41 -op40 -> op41 -op42 [label="op42\nop19 + op41"] -op19 -> op42 +op42 [label="op42\nop8 × op41"] +op8 -> op42 op41 -> op42 -op43 [label="op43\ninput50 × op37"] -input50 -> op43 -op37 -> op43 -op44 [label="op44\nop8 × op43"] -op8 -> op44 -op43 -> op44 -op45 [label="op45\ninput50 × op43"] -input50 -> op45 -op43 -> op45 -op46 [label="op46\nop9 × op45"] -op9 -> op46 +op43 [label="op43\nop40 + op42"] +op40 -> op43 +op42 -> op43 +op44 [label="op44\ninput50 × op41"] +input50 -> op44 +op41 -> op44 +op45 [label="op45\nop6 × op44"] +op6 -> op45 +op44 -> op45 +op46 
[label="op46\nop43 + op45"] +op43 -> op46 op45 -> op46 -op47 [label="op47\nop44 + op46"] -op44 -> op47 -op46 -> op47 -op48 [label="op48\ninput50 × op45"] -input50 -> op48 -op45 -> op48 -op49 [label="op49\nop7 × op48"] -op7 -> op49 +op47 [label="op47\nop9 × op12"] +op9 -> op47 +op12 -> op47 +op48 [label="op48\nop46 × op47"] +op46 -> op48 +op47 -> op48 +op49 [label="op49\nop38 + op48"] +op38 -> op49 op48 -> op49 -op50 [label="op50\nop47 + op49"] -op47 -> op50 -op49 -> op50 -op51 [label="op51\nop10 × op13"] -op10 -> op51 -op13 -> op51 -op52 [label="op52\nop50 × op51"] -op50 -> op52 +op50 [label="op50\ninput29 × input52"] +input29 -> op50 +input52 -> op50 +op51 [label="op51\ninput28 + op50"] +input28 -> op51 +op50 -> op51 +op52 [label="op52\ninput52 × op51"] +input52 -> op52 op51 -> op52 -op53 [label="op53\nop42 + op52"] -op42 -> op53 +op53 [label="op53\ninput27 + op52"] +input27 -> op53 op52 -> op53 -op54 [label="op54\ninput27 × input52"] -input27 -> op54 +op54 [label="op54\ninput52 × op53"] input52 -> op54 +op53 -> op54 op55 [label="op55\ninput26 + op54"] input26 -> op55 op54 -> op55 @@ -218,25 +220,13 @@ op61 -> op62 op63 [label="op63\ninput22 + op62"] input22 -> op63 op62 -> op63 -op64 [label="op64\ninput52 × op63"] -input52 -> op64 -op63 -> op64 -op65 [label="op65\ninput21 + op64"] -input21 -> op65 +op64 [label="op64\nop12 × op13"] +op12 -> op64 +op13 -> op64 +op65 [label="op65\nop63 × op64"] +op63 -> op65 op64 -> op65 -op66 [label="op66\ninput52 × op65"] -input52 -> op66 +op66 [label="op66\nop49 - op65"] +op49 -> op66 op65 -> op66 -op67 [label="op67\ninput20 + op66"] -input20 -> op67 -op66 -> op67 -op68 [label="op68\nop13 × op16"] -op13 -> op68 -op16 -> op68 -op69 [label="op69\nop67 × op68"] -op67 -> op69 -op68 -> op69 -op70 [label="op70\nop53 - op69"] -op53 -> op70 -op69 -> op70 } diff --git a/codegen/ace/tests/regressions/ConstantsAir.dot b/codegen/ace/tests/regressions/ConstantsAir.dot index a51440c35..6850d5a36 100644 --- 
a/codegen/ace/tests/regressions/ConstantsAir.dot +++ b/codegen/ace/tests/regressions/ConstantsAir.dot @@ -7,66 +7,68 @@ const4 [label="10"] const5 [label="35"] const6 [label="18446744069414584288"] input0 [label="PI[stack_inputs][0]"] -input4 [label="M[0]"] -input5 [label="M[1]"] -input6 [label="M[2]"] -input16 [label="M'[0]"] -input17 [label="M'[1]"] -input18 [label="M'[2]"] -input28 [label="g⁻²"] -input29 [label="g⁻¹"] -input30 [label="⍺"] -input31 [label="z"] -input32 [label="zⁿ"] -input33 [label="zᵐᵃˣ"] -op0 [label="op0\ninput4 + const0"] -input4 -> op0 +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input11 [label="M[1]"] +input12 [label="M[2]"] +input22 [label="M'[0]"] +input23 [label="M'[1]"] +input24 [label="M'[2]"] +input34 [label="⍺"] +input35 [label="z"] +input36 [label="zⁿ"] +input37 [label="g⁻¹"] +input38 [label="zᵐᵃˣ"] +input39 [label="g⁻²"] +op0 [label="op0\ninput10 + const0"] +input10 -> op0 const0 -> op0 -op1 [label="op1\ninput16 - op0"] -input16 -> op1 +op1 [label="op1\ninput22 - op0"] +input22 -> op1 op0 -> op1 -op2 [label="op2\ninput5 × const2"] -input5 -> op2 +op2 [label="op2\ninput11 × const2"] +input11 -> op2 const2 -> op2 -op3 [label="op3\ninput17 - op2"] -input17 -> op3 +op3 [label="op3\ninput23 - op2"] +input23 -> op3 op2 -> op3 -op4 [label="op4\ninput6 × const4"] -input6 -> op4 +op4 [label="op4\ninput12 × const4"] +input12 -> op4 const4 -> op4 -op5 [label="op5\ninput18 - op4"] -input18 -> op5 +op5 [label="op5\ninput24 - op4"] +input24 -> op5 op4 -> op5 -op6 [label="op6\ninput4 - const0"] -input4 -> op6 +op6 [label="op6\ninput10 - const0"] +input10 -> op6 const0 -> op6 -op7 [label="op7\ninput5 - const5"] -input5 -> op7 +op7 [label="op7\ninput11 - const5"] +input11 -> op7 const5 -> op7 -op8 [label="op8\ninput6 - const6"] -input6 -> op8 +op8 [label="op8\ninput12 - const6"] +input12 -> op8 const6 -> op8 -op9 [label="op9\ninput31 - const3"] -input31 -> op9 +op9 [label="op9\ninput35 - const3"] +input35 -> op9 const3 -> op9 
-op10 [label="op10\ninput31 - input28"] -input31 -> op10 -input28 -> op10 -op11 [label="op11\ninput31 - input29"] -input31 -> op11 -input29 -> op11 -op12 [label="op12\ninput32 - const3"] -input32 -> op12 +op10 [label="op10\ninput35 - input39"] +input35 -> op10 +input39 -> op10 +op11 [label="op11\ninput35 - input37"] +input35 -> op11 +input37 -> op11 +op12 [label="op12\ninput36 - const3"] +input36 -> op12 const3 -> op12 -op13 [label="op13\ninput30 × op3"] -input30 -> op13 +op13 [label="op13\ninput34 × op3"] +input34 -> op13 op3 -> op13 op14 [label="op14\nop1 + op13"] op1 -> op14 op13 -> op14 -op15 [label="op15\ninput30 × input30"] -input30 -> op15 -input30 -> op15 +op15 [label="op15\ninput34 × input34"] +input34 -> op15 +input34 -> op15 op16 [label="op16\nop5 × op15"] op5 -> op16 op15 -> op16 @@ -85,14 +87,14 @@ op19 -> op20 op21 [label="op21\nop17 × op20"] op17 -> op21 op20 -> op21 -op22 [label="op22\ninput30 × op15"] -input30 -> op22 +op22 [label="op22\ninput34 × op15"] +input34 -> op22 op15 -> op22 op23 [label="op23\nop6 × op22"] op6 -> op23 op22 -> op23 -op24 [label="op24\ninput30 × op22"] -input30 -> op24 +op24 [label="op24\ninput34 × op22"] +input34 -> op24 op22 -> op24 op25 [label="op25\nop7 × op24"] op7 -> op25 @@ -109,8 +111,8 @@ op27 -> op28 op29 [label="op29\nop21 + op28"] op21 -> op29 op28 -> op29 -op30 [label="op30\ninput30 × op24"] -input30 -> op30 +op30 [label="op30\ninput34 × op24"] +input34 -> op30 op24 -> op30 op31 [label="op31\nop8 × op30"] op8 -> op31 @@ -124,47 +126,47 @@ op32 -> op33 op34 [label="op34\nop29 + op33"] op29 -> op34 op33 -> op34 -op35 [label="op35\ninput15 × input32"] -input15 -> op35 -input32 -> op35 -op36 [label="op36\ninput14 + op35"] -input14 -> op36 +op35 [label="op35\ninput21 × input36"] +input21 -> op35 +input36 -> op35 +op36 [label="op36\ninput20 + op35"] +input20 -> op36 op35 -> op36 -op37 [label="op37\ninput32 × op36"] -input32 -> op37 +op37 [label="op37\ninput36 × op36"] +input36 -> op37 op36 -> op37 -op38 
[label="op38\ninput13 + op37"] -input13 -> op38 +op38 [label="op38\ninput19 + op37"] +input19 -> op38 op37 -> op38 -op39 [label="op39\ninput32 × op38"] -input32 -> op39 +op39 [label="op39\ninput36 × op38"] +input36 -> op39 op38 -> op39 -op40 [label="op40\ninput12 + op39"] -input12 -> op40 +op40 [label="op40\ninput18 + op39"] +input18 -> op40 op39 -> op40 -op41 [label="op41\ninput32 × op40"] -input32 -> op41 +op41 [label="op41\ninput36 × op40"] +input36 -> op41 op40 -> op41 -op42 [label="op42\ninput11 + op41"] -input11 -> op42 +op42 [label="op42\ninput17 + op41"] +input17 -> op42 op41 -> op42 -op43 [label="op43\ninput32 × op42"] -input32 -> op43 +op43 [label="op43\ninput36 × op42"] +input36 -> op43 op42 -> op43 -op44 [label="op44\ninput10 + op43"] -input10 -> op44 +op44 [label="op44\ninput16 + op43"] +input16 -> op44 op43 -> op44 -op45 [label="op45\ninput32 × op44"] -input32 -> op45 +op45 [label="op45\ninput36 × op44"] +input36 -> op45 op44 -> op45 -op46 [label="op46\ninput9 + op45"] -input9 -> op46 +op46 [label="op46\ninput15 + op45"] +input15 -> op46 op45 -> op46 -op47 [label="op47\ninput32 × op46"] -input32 -> op47 +op47 [label="op47\ninput36 × op46"] +input36 -> op47 op46 -> op47 -op48 [label="op48\ninput8 + op47"] -input8 -> op48 +op48 [label="op48\ninput14 + op47"] +input14 -> op48 op47 -> op48 op49 [label="op49\nop12 × op18"] op12 -> op49 diff --git a/codegen/ace/tests/regressions/Exp.dot b/codegen/ace/tests/regressions/Exp.dot index e75943323..75cab230b 100644 --- a/codegen/ace/tests/regressions/Exp.dot +++ b/codegen/ace/tests/regressions/Exp.dot @@ -2,57 +2,59 @@ digraph G { const0 [label="0"] const1 [label="1"] input0 [label="PI[stack_inputs][0]"] -input4 [label="M[0]"] -input5 [label="M[1]"] -input16 [label="M'[0]"] -input17 [label="M'[1]"] -input28 [label="g⁻²"] -input29 [label="g⁻¹"] -input30 [label="⍺"] -input31 [label="z"] -input32 [label="zⁿ"] -input33 [label="zᵐᵃˣ"] -op0 [label="op0\ninput5 × input5"] -input5 -> op0 -input5 -> op0 -op1 
[label="op1\ninput5 × op0"] -input5 -> op1 +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input11 [label="M[1]"] +input22 [label="M'[0]"] +input23 [label="M'[1]"] +input34 [label="⍺"] +input35 [label="z"] +input36 [label="zⁿ"] +input37 [label="g⁻¹"] +input38 [label="zᵐᵃˣ"] +input39 [label="g⁻²"] +op0 [label="op0\ninput11 × input11"] +input11 -> op0 +input11 -> op0 +op1 [label="op1\ninput11 × op0"] +input11 -> op1 op0 -> op1 op2 [label="op2\nop0 × op0"] op0 -> op2 op0 -> op2 -op3 [label="op3\ninput5 × op2"] -input5 -> op3 +op3 [label="op3\ninput11 × op2"] +input11 -> op3 op2 -> op3 -op4 [label="op4\ninput31 - const1"] -input31 -> op4 +op4 [label="op4\ninput35 - const1"] +input35 -> op4 const1 -> op4 -op5 [label="op5\ninput31 - input28"] -input31 -> op5 -input28 -> op5 -op6 [label="op6\ninput31 - input29"] -input31 -> op6 -input29 -> op6 -op7 [label="op7\ninput32 - const1"] -input32 -> op7 +op5 [label="op5\ninput35 - input39"] +input35 -> op5 +input39 -> op5 +op6 [label="op6\ninput35 - input37"] +input35 -> op6 +input37 -> op6 +op7 [label="op7\ninput36 - const1"] +input36 -> op7 const1 -> op7 -op8 [label="op8\ninput30 × op0"] -input30 -> op8 +op8 [label="op8\ninput34 × op0"] +input34 -> op8 op0 -> op8 -op9 [label="op9\ninput5 + op8"] -input5 -> op9 +op9 [label="op9\ninput11 + op8"] +input11 -> op9 op8 -> op9 -op10 [label="op10\ninput30 × input30"] -input30 -> op10 -input30 -> op10 +op10 [label="op10\ninput34 × input34"] +input34 -> op10 +input34 -> op10 op11 [label="op11\nop1 × op10"] op1 -> op11 op10 -> op11 op12 [label="op12\nop9 + op11"] op9 -> op12 op11 -> op12 -op13 [label="op13\ninput30 × op10"] -input30 -> op13 +op13 [label="op13\ninput34 × op10"] +input34 -> op13 op10 -> op13 op14 [label="op14\nop2 × op13"] op2 -> op14 @@ -60,8 +62,8 @@ op13 -> op14 op15 [label="op15\nop12 + op14"] op12 -> op15 op14 -> op15 -op16 [label="op16\ninput30 × op13"] -input30 -> op16 +op16 [label="op16\ninput34 × op13"] +input34 -> op16 op13 -> op16 op17 
[label="op17\nop3 × op16"] op3 -> op17 @@ -81,11 +83,11 @@ op20 -> op21 op22 [label="op22\nop18 × op21"] op18 -> op22 op21 -> op22 -op23 [label="op23\ninput30 × op16"] -input30 -> op23 +op23 [label="op23\ninput34 × op16"] +input34 -> op23 op16 -> op23 -op24 [label="op24\ninput4 × op23"] -input4 -> op24 +op24 [label="op24\ninput10 × op23"] +input10 -> op24 op23 -> op24 op25 [label="op25\nop5 × op7"] op5 -> op25 @@ -99,47 +101,47 @@ op26 -> op27 op28 [label="op28\nop4 × op7"] op4 -> op28 op7 -> op28 -op29 [label="op29\ninput15 × input32"] -input15 -> op29 -input32 -> op29 -op30 [label="op30\ninput14 + op29"] -input14 -> op30 +op29 [label="op29\ninput21 × input36"] +input21 -> op29 +input36 -> op29 +op30 [label="op30\ninput20 + op29"] +input20 -> op30 op29 -> op30 -op31 [label="op31\ninput32 × op30"] -input32 -> op31 +op31 [label="op31\ninput36 × op30"] +input36 -> op31 op30 -> op31 -op32 [label="op32\ninput13 + op31"] -input13 -> op32 +op32 [label="op32\ninput19 + op31"] +input19 -> op32 op31 -> op32 -op33 [label="op33\ninput32 × op32"] -input32 -> op33 +op33 [label="op33\ninput36 × op32"] +input36 -> op33 op32 -> op33 -op34 [label="op34\ninput12 + op33"] -input12 -> op34 +op34 [label="op34\ninput18 + op33"] +input18 -> op34 op33 -> op34 -op35 [label="op35\ninput32 × op34"] -input32 -> op35 +op35 [label="op35\ninput36 × op34"] +input36 -> op35 op34 -> op35 -op36 [label="op36\ninput11 + op35"] -input11 -> op36 +op36 [label="op36\ninput17 + op35"] +input17 -> op36 op35 -> op36 -op37 [label="op37\ninput32 × op36"] -input32 -> op37 +op37 [label="op37\ninput36 × op36"] +input36 -> op37 op36 -> op37 -op38 [label="op38\ninput10 + op37"] -input10 -> op38 +op38 [label="op38\ninput16 + op37"] +input16 -> op38 op37 -> op38 -op39 [label="op39\ninput32 × op38"] -input32 -> op39 +op39 [label="op39\ninput36 × op38"] +input36 -> op39 op38 -> op39 -op40 [label="op40\ninput9 + op39"] -input9 -> op40 +op40 [label="op40\ninput15 + op39"] +input15 -> op40 op39 -> op40 -op41 
[label="op41\ninput32 × op40"] -input32 -> op41 +op41 [label="op41\ninput36 × op40"] +input36 -> op41 op40 -> op41 -op42 [label="op42\ninput8 + op41"] -input8 -> op42 +op42 [label="op42\ninput14 + op41"] +input14 -> op42 op41 -> op42 op43 [label="op43\nop7 × op19"] op7 -> op43 diff --git a/codegen/ace/tests/regressions/LongTrace.dot b/codegen/ace/tests/regressions/LongTrace.dot index 0d3880cfa..fa7567000 100644 --- a/codegen/ace/tests/regressions/LongTrace.dot +++ b/codegen/ace/tests/regressions/LongTrace.dot @@ -2,53 +2,55 @@ digraph G { const0 [label="0"] const1 [label="1"] input0 [label="PI[stack_inputs][0]"] -input4 [label="M[0]"] -input5 [label="M[1]"] -input6 [label="M[2]"] -input7 [label="M[3]"] -input8 [label="M[4]"] -input9 [label="M[5]"] -input10 [label="M[6]"] -input11 [label="M[7]"] -input12 [label="M[8]"] -input24 [label="M'[0]"] -input25 [label="M'[1]"] -input26 [label="M'[2]"] -input27 [label="M'[3]"] -input28 [label="M'[4]"] -input29 [label="M'[5]"] -input30 [label="M'[6]"] -input31 [label="M'[7]"] -input32 [label="M'[8]"] -input44 [label="g⁻²"] -input45 [label="g⁻¹"] -input46 [label="⍺"] -input47 [label="z"] -input48 [label="zⁿ"] -input49 [label="zᵐᵃˣ"] -op0 [label="op0\ninput4 × input5"] -input4 -> op0 -input5 -> op0 -op1 [label="op1\ninput6 × op0"] -input6 -> op1 +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input11 [label="M[1]"] +input12 [label="M[2]"] +input13 [label="M[3]"] +input14 [label="M[4]"] +input15 [label="M[5]"] +input16 [label="M[6]"] +input17 [label="M[7]"] +input18 [label="M[8]"] +input30 [label="M'[0]"] +input31 [label="M'[1]"] +input32 [label="M'[2]"] +input33 [label="M'[3]"] +input34 [label="M'[4]"] +input35 [label="M'[5]"] +input36 [label="M'[6]"] +input37 [label="M'[7]"] +input38 [label="M'[8]"] +input50 [label="⍺"] +input51 [label="z"] +input52 [label="zⁿ"] +input53 [label="g⁻¹"] +input54 [label="zᵐᵃˣ"] +input55 [label="g⁻²"] +op0 [label="op0\ninput10 × input11"] +input10 -> op0 +input11 -> op0 +op1 
[label="op1\ninput12 × op0"] +input12 -> op1 op0 -> op1 -op2 [label="op2\ninput7 + op1"] -input7 -> op2 +op2 [label="op2\ninput13 + op1"] +input13 -> op2 op1 -> op2 -op3 [label="op3\nop2 - input8"] +op3 [label="op3\nop2 - input14"] op2 -> op3 -input8 -> op3 -op4 [label="op4\ninput47 - const1"] -input47 -> op4 +input14 -> op3 +op4 [label="op4\ninput51 - const1"] +input51 -> op4 const1 -> op4 -op5 [label="op5\ninput47 - input44"] -input47 -> op5 -input44 -> op5 -op6 [label="op6\ninput47 - input45"] -input47 -> op6 -input45 -> op6 -op7 [label="op7\ninput48 - const1"] -input48 -> op7 +op5 [label="op5\ninput51 - input55"] +input51 -> op5 +input55 -> op5 +op6 [label="op6\ninput51 - input53"] +input51 -> op6 +input53 -> op6 +op7 [label="op7\ninput52 - const1"] +input52 -> op7 const1 -> op7 op8 [label="op8\nop4 × op5"] op4 -> op8 @@ -62,9 +64,9 @@ op9 -> op10 op11 [label="op11\nop3 × op10"] op3 -> op11 op10 -> op11 -op12 [label="op12\ninput4 × input46"] -input4 -> op12 -input46 -> op12 +op12 [label="op12\ninput10 × input50"] +input10 -> op12 +input50 -> op12 op13 [label="op13\nop5 × op7"] op5 -> op13 op7 -> op13 @@ -77,47 +79,47 @@ op14 -> op15 op16 [label="op16\nop4 × op7"] op4 -> op16 op7 -> op16 -op17 [label="op17\ninput23 × input48"] -input23 -> op17 -input48 -> op17 -op18 [label="op18\ninput22 + op17"] -input22 -> op18 +op17 [label="op17\ninput29 × input52"] +input29 -> op17 +input52 -> op17 +op18 [label="op18\ninput28 + op17"] +input28 -> op18 op17 -> op18 -op19 [label="op19\ninput48 × op18"] -input48 -> op19 +op19 [label="op19\ninput52 × op18"] +input52 -> op19 op18 -> op19 -op20 [label="op20\ninput21 + op19"] -input21 -> op20 +op20 [label="op20\ninput27 + op19"] +input27 -> op20 op19 -> op20 -op21 [label="op21\ninput48 × op20"] -input48 -> op21 +op21 [label="op21\ninput52 × op20"] +input52 -> op21 op20 -> op21 -op22 [label="op22\ninput20 + op21"] -input20 -> op22 +op22 [label="op22\ninput26 + op21"] +input26 -> op22 op21 -> op22 -op23 [label="op23\ninput48 × op22"] 
-input48 -> op23 +op23 [label="op23\ninput52 × op22"] +input52 -> op23 op22 -> op23 -op24 [label="op24\ninput19 + op23"] -input19 -> op24 +op24 [label="op24\ninput25 + op23"] +input25 -> op24 op23 -> op24 -op25 [label="op25\ninput48 × op24"] -input48 -> op25 +op25 [label="op25\ninput52 × op24"] +input52 -> op25 op24 -> op25 -op26 [label="op26\ninput18 + op25"] -input18 -> op26 +op26 [label="op26\ninput24 + op25"] +input24 -> op26 op25 -> op26 -op27 [label="op27\ninput48 × op26"] -input48 -> op27 +op27 [label="op27\ninput52 × op26"] +input52 -> op27 op26 -> op27 -op28 [label="op28\ninput17 + op27"] -input17 -> op28 +op28 [label="op28\ninput23 + op27"] +input23 -> op28 op27 -> op28 -op29 [label="op29\ninput48 × op28"] -input48 -> op29 +op29 [label="op29\ninput52 × op28"] +input52 -> op29 op28 -> op29 -op30 [label="op30\ninput16 + op29"] -input16 -> op30 +op30 [label="op30\ninput22 + op29"] +input22 -> op30 op29 -> op30 op31 [label="op31\nop7 × op8"] op7 -> op31 diff --git a/codegen/ace/tests/regressions/MultipleAux.dot b/codegen/ace/tests/regressions/MultipleAux.dot index 1deabc6c8..5e40189ef 100644 --- a/codegen/ace/tests/regressions/MultipleAux.dot +++ b/codegen/ace/tests/regressions/MultipleAux.dot @@ -22,81 +22,83 @@ input12 [label="PI[stack_inputs][12]"] input13 [label="PI[stack_inputs][13]"] input14 [label="PI[stack_inputs][14]"] input15 [label="PI[stack_inputs][15]"] -input16 [label="M[0]"] -input17 [label="M[1]"] -input18 [label="M[2]"] -input28 [label="M'[0]"] -input29 [label="M'[1]"] -input30 [label="M'[2]"] -input40 [label="g⁻²"] -input41 [label="g⁻¹"] +input16 [label="α"] +input17 [label="β"] +input18 [label="M[0]"] +input19 [label="M[1]"] +input20 [label="M[2]"] +input30 [label="M'[0]"] +input31 [label="M'[1]"] +input32 [label="M'[2]"] input42 [label="⍺"] input43 [label="z"] input44 [label="zⁿ"] -input45 [label="zᵐᵃˣ"] -op0 [label="op0\ninput45 × input45"] -input45 -> op0 -input45 -> op0 +input45 [label="g⁻¹"] +input46 [label="zᵐᵃˣ"] +input47 
[label="g⁻²"] +op0 [label="op0\ninput46 × input46"] +input46 -> op0 +input46 -> op0 op1 [label="op1\nconst2 × op0"] const2 -> op1 op0 -> op1 op2 [label="op2\nconst2 + op1"] const2 -> op2 op1 -> op2 -op3 [label="op3\ninput16 × op2"] -input16 -> op3 +op3 [label="op3\ninput18 × op2"] +input18 -> op3 op2 -> op3 -op4 [label="op4\ninput45 × const6"] -input45 -> op4 +op4 [label="op4\ninput46 × const6"] +input46 -> op4 const6 -> op4 op5 [label="op5\nconst5 + op4"] const5 -> op5 op4 -> op5 -op6 [label="op6\ninput45 × op5"] -input45 -> op6 +op6 [label="op6\ninput46 × op5"] +input46 -> op6 op5 -> op6 op7 [label="op7\nconst4 + op6"] const4 -> op7 op6 -> op7 -op8 [label="op8\ninput45 × op7"] -input45 -> op8 +op8 [label="op8\ninput46 × op7"] +input46 -> op8 op7 -> op8 op9 [label="op9\nconst3 + op8"] const3 -> op9 op8 -> op9 -op10 [label="op10\ninput17 × op9"] -input17 -> op10 +op10 [label="op10\ninput19 × op9"] +input19 -> op10 op9 -> op10 -op11 [label="op11\ninput45 × const5"] -input45 -> op11 +op11 [label="op11\ninput46 × const5"] +input46 -> op11 const5 -> op11 op12 [label="op12\nconst5 + op11"] const5 -> op12 op11 -> op12 -op13 [label="op13\ninput45 × op12"] -input45 -> op13 +op13 [label="op13\ninput46 × op12"] +input46 -> op13 op12 -> op13 op14 [label="op14\nconst5 + op13"] const5 -> op14 op13 -> op14 -op15 [label="op15\ninput45 × op14"] -input45 -> op15 +op15 [label="op15\ninput46 × op14"] +input46 -> op15 op14 -> op15 op16 [label="op16\nconst5 + op15"] const5 -> op16 op15 -> op16 -op17 [label="op17\ninput18 × op16"] -input18 -> op17 +op17 [label="op17\ninput20 × op16"] +input20 -> op17 op16 -> op17 op18 [label="op18\ninput43 - const1"] input43 -> op18 const1 -> op18 -op19 [label="op19\ninput43 - input40"] +op19 [label="op19\ninput43 - input47"] input43 -> op19 -input40 -> op19 -op20 [label="op20\ninput43 - input41"] +input47 -> op19 +op20 [label="op20\ninput43 - input45"] input43 -> op20 -input41 -> op20 +input45 -> op20 op21 [label="op21\ninput44 - const1"] input44 -> 
op21 const1 -> op21 @@ -130,8 +132,8 @@ op29 -> op30 op31 [label="op31\ninput42 × op24"] input42 -> op31 op24 -> op31 -op32 [label="op32\ninput16 × op31"] -input16 -> op32 +op32 [label="op32\ninput18 × op31"] +input18 -> op32 op31 -> op32 op33 [label="op33\nop19 × op21"] op19 -> op33 @@ -145,47 +147,47 @@ op34 -> op35 op36 [label="op36\nop18 × op21"] op18 -> op36 op21 -> op36 -op37 [label="op37\ninput27 × input44"] -input27 -> op37 +op37 [label="op37\ninput29 × input44"] +input29 -> op37 input44 -> op37 -op38 [label="op38\ninput26 + op37"] -input26 -> op38 +op38 [label="op38\ninput28 + op37"] +input28 -> op38 op37 -> op38 op39 [label="op39\ninput44 × op38"] input44 -> op39 op38 -> op39 -op40 [label="op40\ninput25 + op39"] -input25 -> op40 +op40 [label="op40\ninput27 + op39"] +input27 -> op40 op39 -> op40 op41 [label="op41\ninput44 × op40"] input44 -> op41 op40 -> op41 -op42 [label="op42\ninput24 + op41"] -input24 -> op42 +op42 [label="op42\ninput26 + op41"] +input26 -> op42 op41 -> op42 op43 [label="op43\ninput44 × op42"] input44 -> op43 op42 -> op43 -op44 [label="op44\ninput23 + op43"] -input23 -> op44 +op44 [label="op44\ninput25 + op43"] +input25 -> op44 op43 -> op44 op45 [label="op45\ninput44 × op44"] input44 -> op45 op44 -> op45 -op46 [label="op46\ninput22 + op45"] -input22 -> op46 +op46 [label="op46\ninput24 + op45"] +input24 -> op46 op45 -> op46 op47 [label="op47\ninput44 × op46"] input44 -> op47 op46 -> op47 -op48 [label="op48\ninput21 + op47"] -input21 -> op48 +op48 [label="op48\ninput23 + op47"] +input23 -> op48 op47 -> op48 op49 [label="op49\ninput44 × op48"] input44 -> op49 op48 -> op49 -op50 [label="op50\ninput20 + op49"] -input20 -> op50 +op50 [label="op50\ninput22 + op49"] +input22 -> op50 op49 -> op50 op51 [label="op51\nop21 × op27"] op21 -> op51 diff --git a/codegen/ace/tests/regressions/MultipleRows.dot b/codegen/ace/tests/regressions/MultipleRows.dot index 250d7de25..5cc30677a 100644 --- a/codegen/ace/tests/regressions/MultipleRows.dot +++ 
b/codegen/ace/tests/regressions/MultipleRows.dot @@ -3,42 +3,44 @@ const0 [label="2"] const1 [label="0"] const2 [label="1"] input0 [label="PI[stack_inputs][0]"] -input4 [label="M[0]"] -input5 [label="M[1]"] -input16 [label="M'[0]"] -input17 [label="M'[1]"] -input28 [label="g⁻²"] -input29 [label="g⁻¹"] -input30 [label="⍺"] -input31 [label="z"] -input32 [label="zⁿ"] -input33 [label="zᵐᵃˣ"] -op0 [label="op0\ninput4 × const0"] -input4 -> op0 +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input11 [label="M[1]"] +input22 [label="M'[0]"] +input23 [label="M'[1]"] +input34 [label="⍺"] +input35 [label="z"] +input36 [label="zⁿ"] +input37 [label="g⁻¹"] +input38 [label="zᵐᵃˣ"] +input39 [label="g⁻²"] +op0 [label="op0\ninput10 × const0"] +input10 -> op0 const0 -> op0 -op1 [label="op1\ninput16 - op0"] -input16 -> op1 +op1 [label="op1\ninput22 - op0"] +input22 -> op1 op0 -> op1 -op2 [label="op2\ninput4 + input5"] -input4 -> op2 -input5 -> op2 -op3 [label="op3\ninput17 - op2"] -input17 -> op3 +op2 [label="op2\ninput10 + input11"] +input10 -> op2 +input11 -> op2 +op3 [label="op3\ninput23 - op2"] +input23 -> op3 op2 -> op3 -op4 [label="op4\ninput31 - const2"] -input31 -> op4 +op4 [label="op4\ninput35 - const2"] +input35 -> op4 const2 -> op4 -op5 [label="op5\ninput31 - input28"] -input31 -> op5 -input28 -> op5 -op6 [label="op6\ninput31 - input29"] -input31 -> op6 -input29 -> op6 -op7 [label="op7\ninput32 - const2"] -input32 -> op7 +op5 [label="op5\ninput35 - input39"] +input35 -> op5 +input39 -> op5 +op6 [label="op6\ninput35 - input37"] +input35 -> op6 +input37 -> op6 +op7 [label="op7\ninput36 - const2"] +input36 -> op7 const2 -> op7 -op8 [label="op8\ninput30 × op3"] -input30 -> op8 +op8 [label="op8\ninput34 × op3"] +input34 -> op8 op3 -> op8 op9 [label="op9\nop1 + op8"] op1 -> op9 @@ -55,11 +57,11 @@ op11 -> op12 op13 [label="op13\nop9 × op12"] op9 -> op13 op12 -> op13 -op14 [label="op14\ninput30 × input30"] -input30 -> op14 -input30 -> op14 -op15 
[label="op15\ninput4 × op14"] -input4 -> op15 +op14 [label="op14\ninput34 × input34"] +input34 -> op14 +input34 -> op14 +op15 [label="op15\ninput10 × op14"] +input10 -> op15 op14 -> op15 op16 [label="op16\nop5 × op7"] op5 -> op16 @@ -73,47 +75,47 @@ op17 -> op18 op19 [label="op19\nop4 × op7"] op4 -> op19 op7 -> op19 -op20 [label="op20\ninput15 × input32"] -input15 -> op20 -input32 -> op20 -op21 [label="op21\ninput14 + op20"] -input14 -> op21 +op20 [label="op20\ninput21 × input36"] +input21 -> op20 +input36 -> op20 +op21 [label="op21\ninput20 + op20"] +input20 -> op21 op20 -> op21 -op22 [label="op22\ninput32 × op21"] -input32 -> op22 +op22 [label="op22\ninput36 × op21"] +input36 -> op22 op21 -> op22 -op23 [label="op23\ninput13 + op22"] -input13 -> op23 +op23 [label="op23\ninput19 + op22"] +input19 -> op23 op22 -> op23 -op24 [label="op24\ninput32 × op23"] -input32 -> op24 +op24 [label="op24\ninput36 × op23"] +input36 -> op24 op23 -> op24 -op25 [label="op25\ninput12 + op24"] -input12 -> op25 +op25 [label="op25\ninput18 + op24"] +input18 -> op25 op24 -> op25 -op26 [label="op26\ninput32 × op25"] -input32 -> op26 +op26 [label="op26\ninput36 × op25"] +input36 -> op26 op25 -> op26 -op27 [label="op27\ninput11 + op26"] -input11 -> op27 +op27 [label="op27\ninput17 + op26"] +input17 -> op27 op26 -> op27 -op28 [label="op28\ninput32 × op27"] -input32 -> op28 +op28 [label="op28\ninput36 × op27"] +input36 -> op28 op27 -> op28 -op29 [label="op29\ninput10 + op28"] -input10 -> op29 +op29 [label="op29\ninput16 + op28"] +input16 -> op29 op28 -> op29 -op30 [label="op30\ninput32 × op29"] -input32 -> op30 +op30 [label="op30\ninput36 × op29"] +input36 -> op30 op29 -> op30 -op31 [label="op31\ninput9 + op30"] -input9 -> op31 +op31 [label="op31\ninput15 + op30"] +input15 -> op31 op30 -> op31 -op32 [label="op32\ninput32 × op31"] -input32 -> op32 +op32 [label="op32\ninput36 × op31"] +input36 -> op32 op31 -> op32 -op33 [label="op33\ninput8 + op32"] -input8 -> op33 +op33 [label="op33\ninput14 + 
op32"] +input14 -> op33 op32 -> op33 op34 [label="op34\nop7 × op10"] op7 -> op34 diff --git a/codegen/ace/tests/regressions/PublicInput.dot b/codegen/ace/tests/regressions/PublicInput.dot index 597f01453..3643371c8 100644 --- a/codegen/ace/tests/regressions/PublicInput.dot +++ b/codegen/ace/tests/regressions/PublicInput.dot @@ -3,30 +3,32 @@ const0 [label="0"] const1 [label="1"] input0 [label="PI[m][0]"] input1 [label="PI[m][1]"] -input4 [label="PI[z][0]"] -input5 [label="PI[z][1]"] -input8 [label="M[0]"] -input9 [label="M[1]"] -input20 [label="M'[0]"] -input21 [label="M'[1]"] -input32 [label="g⁻²"] -input33 [label="g⁻¹"] +input2 [label="PI[z][0]"] +input3 [label="PI[z][1]"] +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input11 [label="M[1]"] +input22 [label="M'[0]"] +input23 [label="M'[1]"] input34 [label="⍺"] input35 [label="z"] input36 [label="zⁿ"] -input37 [label="zᵐᵃˣ"] -op0 [label="op0\ninput8 - input0"] -input8 -> op0 +input37 [label="g⁻¹"] +input38 [label="zᵐᵃˣ"] +input39 [label="g⁻²"] +op0 [label="op0\ninput10 - input0"] +input10 -> op0 input0 -> op0 op1 [label="op1\ninput35 - const1"] input35 -> op1 const1 -> op1 -op2 [label="op2\ninput35 - input32"] +op2 [label="op2\ninput35 - input39"] input35 -> op2 -input32 -> op2 -op3 [label="op3\ninput35 - input33"] +input39 -> op2 +op3 [label="op3\ninput35 - input37"] input35 -> op3 -input33 -> op3 +input37 -> op3 op4 [label="op4\ninput36 - const1"] input36 -> op4 const1 -> op4 @@ -39,8 +41,8 @@ op5 -> op6 op7 [label="op7\nop2 × op6"] op2 -> op7 op6 -> op7 -op8 [label="op8\ninput8 × op7"] -input8 -> op8 +op8 [label="op8\ninput10 × op7"] +input10 -> op8 op7 -> op8 op9 [label="op9\ninput34 × op0"] input34 -> op9 @@ -57,47 +59,47 @@ op11 -> op12 op13 [label="op13\nop1 × op4"] op1 -> op13 op4 -> op13 -op14 [label="op14\ninput19 × input36"] -input19 -> op14 +op14 [label="op14\ninput21 × input36"] +input21 -> op14 input36 -> op14 -op15 [label="op15\ninput18 + op14"] -input18 -> op15 +op15 
[label="op15\ninput20 + op14"] +input20 -> op15 op14 -> op15 op16 [label="op16\ninput36 × op15"] input36 -> op16 op15 -> op16 -op17 [label="op17\ninput17 + op16"] -input17 -> op17 +op17 [label="op17\ninput19 + op16"] +input19 -> op17 op16 -> op17 op18 [label="op18\ninput36 × op17"] input36 -> op18 op17 -> op18 -op19 [label="op19\ninput16 + op18"] -input16 -> op19 +op19 [label="op19\ninput18 + op18"] +input18 -> op19 op18 -> op19 op20 [label="op20\ninput36 × op19"] input36 -> op20 op19 -> op20 -op21 [label="op21\ninput15 + op20"] -input15 -> op21 +op21 [label="op21\ninput17 + op20"] +input17 -> op21 op20 -> op21 op22 [label="op22\ninput36 × op21"] input36 -> op22 op21 -> op22 -op23 [label="op23\ninput14 + op22"] -input14 -> op23 +op23 [label="op23\ninput16 + op22"] +input16 -> op23 op22 -> op23 op24 [label="op24\ninput36 × op23"] input36 -> op24 op23 -> op24 -op25 [label="op25\ninput13 + op24"] -input13 -> op25 +op25 [label="op25\ninput15 + op24"] +input15 -> op25 op24 -> op25 op26 [label="op26\ninput36 × op25"] input36 -> op26 op25 -> op26 -op27 [label="op27\ninput12 + op26"] -input12 -> op27 +op27 [label="op27\ninput14 + op26"] +input14 -> op27 op26 -> op27 op28 [label="op28\nop4 × op5"] op4 -> op28 diff --git a/codegen/ace/tests/regressions/Simple.dot b/codegen/ace/tests/regressions/Simple.dot index f1e664466..40d7cff4f 100644 --- a/codegen/ace/tests/regressions/Simple.dot +++ b/codegen/ace/tests/regressions/Simple.dot @@ -2,28 +2,30 @@ digraph G { const0 [label="0"] const1 [label="1"] input0 [label="PI[stack_inputs][0]"] -input4 [label="M[0]"] -input16 [label="M'[0]"] -input28 [label="g⁻²"] -input29 [label="g⁻¹"] -input30 [label="⍺"] -input31 [label="z"] -input32 [label="zⁿ"] -input33 [label="zᵐᵃˣ"] -op0 [label="op0\ninput4 + input4"] -input4 -> op0 -input4 -> op0 -op1 [label="op1\ninput31 - const1"] -input31 -> op1 +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input22 [label="M'[0]"] +input34 [label="⍺"] +input35 [label="z"] +input36 
[label="zⁿ"] +input37 [label="g⁻¹"] +input38 [label="zᵐᵃˣ"] +input39 [label="g⁻²"] +op0 [label="op0\ninput10 + input10"] +input10 -> op0 +input10 -> op0 +op1 [label="op1\ninput35 - const1"] +input35 -> op1 const1 -> op1 -op2 [label="op2\ninput31 - input28"] -input31 -> op2 -input28 -> op2 -op3 [label="op3\ninput31 - input29"] -input31 -> op3 -input29 -> op3 -op4 [label="op4\ninput32 - const1"] -input32 -> op4 +op2 [label="op2\ninput35 - input39"] +input35 -> op2 +input39 -> op2 +op3 [label="op3\ninput35 - input37"] +input35 -> op3 +input37 -> op3 +op4 [label="op4\ninput36 - const1"] +input36 -> op4 const1 -> op4 op5 [label="op5\nop1 × op2"] op1 -> op5 @@ -37,9 +39,9 @@ op6 -> op7 op8 [label="op8\nop0 × op7"] op0 -> op8 op7 -> op8 -op9 [label="op9\ninput4 × input30"] -input4 -> op9 -input30 -> op9 +op9 [label="op9\ninput10 × input34"] +input10 -> op9 +input34 -> op9 op10 [label="op10\nop2 × op4"] op2 -> op10 op4 -> op10 @@ -52,47 +54,47 @@ op11 -> op12 op13 [label="op13\nop1 × op4"] op1 -> op13 op4 -> op13 -op14 [label="op14\ninput15 × input32"] -input15 -> op14 -input32 -> op14 -op15 [label="op15\ninput14 + op14"] -input14 -> op15 +op14 [label="op14\ninput21 × input36"] +input21 -> op14 +input36 -> op14 +op15 [label="op15\ninput20 + op14"] +input20 -> op15 op14 -> op15 -op16 [label="op16\ninput32 × op15"] -input32 -> op16 +op16 [label="op16\ninput36 × op15"] +input36 -> op16 op15 -> op16 -op17 [label="op17\ninput13 + op16"] -input13 -> op17 +op17 [label="op17\ninput19 + op16"] +input19 -> op17 op16 -> op17 -op18 [label="op18\ninput32 × op17"] -input32 -> op18 +op18 [label="op18\ninput36 × op17"] +input36 -> op18 op17 -> op18 -op19 [label="op19\ninput12 + op18"] -input12 -> op19 +op19 [label="op19\ninput18 + op18"] +input18 -> op19 op18 -> op19 -op20 [label="op20\ninput32 × op19"] -input32 -> op20 +op20 [label="op20\ninput36 × op19"] +input36 -> op20 op19 -> op20 -op21 [label="op21\ninput11 + op20"] -input11 -> op21 +op21 [label="op21\ninput17 + op20"] +input17 -> 
op21 op20 -> op21 -op22 [label="op22\ninput32 × op21"] -input32 -> op22 +op22 [label="op22\ninput36 × op21"] +input36 -> op22 op21 -> op22 -op23 [label="op23\ninput10 + op22"] -input10 -> op23 +op23 [label="op23\ninput16 + op22"] +input16 -> op23 op22 -> op23 -op24 [label="op24\ninput32 × op23"] -input32 -> op24 +op24 [label="op24\ninput36 × op23"] +input36 -> op24 op23 -> op24 -op25 [label="op25\ninput9 + op24"] -input9 -> op25 +op25 [label="op25\ninput15 + op24"] +input15 -> op25 op24 -> op25 -op26 [label="op26\ninput32 × op25"] -input32 -> op26 +op26 [label="op26\ninput36 × op25"] +input36 -> op26 op25 -> op26 -op27 [label="op27\ninput8 + op26"] -input8 -> op27 +op27 [label="op27\ninput14 + op26"] +input14 -> op27 op26 -> op27 op28 [label="op28\nop4 × op5"] op4 -> op28 diff --git a/codegen/ace/tests/regressions/SimpleArithmetic.dot b/codegen/ace/tests/regressions/SimpleArithmetic.dot index b21bc9b89..d6df856d5 100644 --- a/codegen/ace/tests/regressions/SimpleArithmetic.dot +++ b/codegen/ace/tests/regressions/SimpleArithmetic.dot @@ -2,54 +2,56 @@ digraph G { const0 [label="0"] const1 [label="1"] input0 [label="PI[stack_inputs][0]"] -input4 [label="M[0]"] -input5 [label="M[1]"] -input16 [label="M'[0]"] -input17 [label="M'[1]"] -input28 [label="g⁻²"] -input29 [label="g⁻¹"] -input30 [label="⍺"] -input31 [label="z"] -input32 [label="zⁿ"] -input33 [label="zᵐᵃˣ"] -op0 [label="op0\ninput4 + input4"] -input4 -> op0 -input4 -> op0 -op1 [label="op1\ninput4 × input4"] -input4 -> op1 -input4 -> op1 -op2 [label="op2\ninput4 + input5"] -input4 -> op2 -input5 -> op2 -op3 [label="op3\ninput5 - input4"] -input5 -> op3 -input4 -> op3 -op4 [label="op4\ninput4 × input5"] -input4 -> op4 -input5 -> op4 -op5 [label="op5\ninput31 - const1"] -input31 -> op5 +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input11 [label="M[1]"] +input22 [label="M'[0]"] +input23 [label="M'[1]"] +input34 [label="⍺"] +input35 [label="z"] +input36 [label="zⁿ"] +input37 [label="g⁻¹"] +input38 
[label="zᵐᵃˣ"] +input39 [label="g⁻²"] +op0 [label="op0\ninput10 + input10"] +input10 -> op0 +input10 -> op0 +op1 [label="op1\ninput10 × input10"] +input10 -> op1 +input10 -> op1 +op2 [label="op2\ninput10 + input11"] +input10 -> op2 +input11 -> op2 +op3 [label="op3\ninput11 - input10"] +input11 -> op3 +input10 -> op3 +op4 [label="op4\ninput10 × input11"] +input10 -> op4 +input11 -> op4 +op5 [label="op5\ninput35 - const1"] +input35 -> op5 const1 -> op5 -op6 [label="op6\ninput31 - input28"] -input31 -> op6 -input28 -> op6 -op7 [label="op7\ninput31 - input29"] -input31 -> op7 -input29 -> op7 -op8 [label="op8\ninput32 - const1"] -input32 -> op8 +op6 [label="op6\ninput35 - input39"] +input35 -> op6 +input39 -> op6 +op7 [label="op7\ninput35 - input37"] +input35 -> op7 +input37 -> op7 +op8 [label="op8\ninput36 - const1"] +input36 -> op8 const1 -> op8 -op9 [label="op9\ninput30 × input30"] -input30 -> op9 -input30 -> op9 +op9 [label="op9\ninput34 × input34"] +input34 -> op9 +input34 -> op9 op10 [label="op10\nop1 × op9"] op1 -> op10 op9 -> op10 op11 [label="op11\nop0 + op10"] op0 -> op11 op10 -> op11 -op12 [label="op12\ninput30 × op9"] -input30 -> op12 +op12 [label="op12\ninput34 × op9"] +input34 -> op12 op9 -> op12 op13 [label="op13\nop2 × op12"] op2 -> op13 @@ -57,8 +59,8 @@ op12 -> op13 op14 [label="op14\nop11 + op13"] op11 -> op14 op13 -> op14 -op15 [label="op15\ninput30 × op12"] -input30 -> op15 +op15 [label="op15\ninput34 × op12"] +input34 -> op15 op12 -> op15 op16 [label="op16\nop3 × op15"] op3 -> op16 @@ -66,8 +68,8 @@ op15 -> op16 op17 [label="op17\nop14 + op16"] op14 -> op17 op16 -> op17 -op18 [label="op18\ninput30 × op15"] -input30 -> op18 +op18 [label="op18\ninput34 × op15"] +input34 -> op18 op15 -> op18 op19 [label="op19\nop4 × op18"] op4 -> op19 @@ -87,11 +89,11 @@ op22 -> op23 op24 [label="op24\nop20 × op23"] op20 -> op24 op23 -> op24 -op25 [label="op25\ninput30 × op18"] -input30 -> op25 +op25 [label="op25\ninput34 × op18"] +input34 -> op25 op18 -> op25 -op26 
[label="op26\ninput4 × op25"] -input4 -> op26 +op26 [label="op26\ninput10 × op25"] +input10 -> op26 op25 -> op26 op27 [label="op27\nop6 × op8"] op6 -> op27 @@ -105,47 +107,47 @@ op28 -> op29 op30 [label="op30\nop5 × op8"] op5 -> op30 op8 -> op30 -op31 [label="op31\ninput15 × input32"] -input15 -> op31 -input32 -> op31 -op32 [label="op32\ninput14 + op31"] -input14 -> op32 +op31 [label="op31\ninput21 × input36"] +input21 -> op31 +input36 -> op31 +op32 [label="op32\ninput20 + op31"] +input20 -> op32 op31 -> op32 -op33 [label="op33\ninput32 × op32"] -input32 -> op33 +op33 [label="op33\ninput36 × op32"] +input36 -> op33 op32 -> op33 -op34 [label="op34\ninput13 + op33"] -input13 -> op34 +op34 [label="op34\ninput19 + op33"] +input19 -> op34 op33 -> op34 -op35 [label="op35\ninput32 × op34"] -input32 -> op35 +op35 [label="op35\ninput36 × op34"] +input36 -> op35 op34 -> op35 -op36 [label="op36\ninput12 + op35"] -input12 -> op36 +op36 [label="op36\ninput18 + op35"] +input18 -> op36 op35 -> op36 -op37 [label="op37\ninput32 × op36"] -input32 -> op37 +op37 [label="op37\ninput36 × op36"] +input36 -> op37 op36 -> op37 -op38 [label="op38\ninput11 + op37"] -input11 -> op38 +op38 [label="op38\ninput17 + op37"] +input17 -> op38 op37 -> op38 -op39 [label="op39\ninput32 × op38"] -input32 -> op39 +op39 [label="op39\ninput36 × op38"] +input36 -> op39 op38 -> op39 -op40 [label="op40\ninput10 + op39"] -input10 -> op40 +op40 [label="op40\ninput16 + op39"] +input16 -> op40 op39 -> op40 -op41 [label="op41\ninput32 × op40"] -input32 -> op41 +op41 [label="op41\ninput36 × op40"] +input36 -> op41 op40 -> op41 -op42 [label="op42\ninput9 + op41"] -input9 -> op42 +op42 [label="op42\ninput15 + op41"] +input15 -> op42 op41 -> op42 -op43 [label="op43\ninput32 × op42"] -input32 -> op43 +op43 [label="op43\ninput36 × op42"] +input36 -> op43 op42 -> op43 -op44 [label="op44\ninput8 + op43"] -input8 -> op44 +op44 [label="op44\ninput14 + op43"] +input14 -> op44 op43 -> op44 op45 [label="op45\nop8 × op21"] op8 
-> op45 diff --git a/codegen/ace/tests/regressions/SimpleBoundary.dot b/codegen/ace/tests/regressions/SimpleBoundary.dot index 0b20538bd..c1ca81dc5 100644 --- a/codegen/ace/tests/regressions/SimpleBoundary.dot +++ b/codegen/ace/tests/regressions/SimpleBoundary.dot @@ -2,50 +2,52 @@ digraph G { const0 [label="0"] const1 [label="1"] input0 [label="PI[target][0]"] -input4 [label="M[0]"] -input5 [label="M[1]"] -input6 [label="M[2]"] -input16 [label="M'[0]"] -input17 [label="M'[1]"] -input18 [label="M'[2]"] -input28 [label="g⁻²"] -input29 [label="g⁻¹"] -input30 [label="⍺"] -input31 [label="z"] -input32 [label="zⁿ"] -input33 [label="zᵐᵃˣ"] -op0 [label="op0\ninput4 + input5"] -input4 -> op0 -input5 -> op0 -op1 [label="op1\ninput16 - op0"] -input16 -> op1 +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input11 [label="M[1]"] +input12 [label="M[2]"] +input22 [label="M'[0]"] +input23 [label="M'[1]"] +input24 [label="M'[2]"] +input34 [label="⍺"] +input35 [label="z"] +input36 [label="zⁿ"] +input37 [label="g⁻¹"] +input38 [label="zᵐᵃˣ"] +input39 [label="g⁻²"] +op0 [label="op0\ninput10 + input11"] +input10 -> op0 +input11 -> op0 +op1 [label="op1\ninput22 - op0"] +input22 -> op1 op0 -> op1 -op2 [label="op2\ninput17 - input4"] -input17 -> op2 -input4 -> op2 -op3 [label="op3\ninput4 - const1"] -input4 -> op3 +op2 [label="op2\ninput23 - input10"] +input23 -> op2 +input10 -> op2 +op3 [label="op3\ninput10 - const1"] +input10 -> op3 const1 -> op3 -op4 [label="op4\ninput5 - const1"] -input5 -> op4 +op4 [label="op4\ninput11 - const1"] +input11 -> op4 const1 -> op4 -op5 [label="op5\ninput6 - input0"] -input6 -> op5 +op5 [label="op5\ninput12 - input0"] +input12 -> op5 input0 -> op5 -op6 [label="op6\ninput31 - const1"] -input31 -> op6 +op6 [label="op6\ninput35 - const1"] +input35 -> op6 const1 -> op6 -op7 [label="op7\ninput31 - input28"] -input31 -> op7 -input28 -> op7 -op8 [label="op8\ninput31 - input29"] -input31 -> op8 -input29 -> op8 -op9 [label="op9\ninput32 - const1"] 
-input32 -> op9 +op7 [label="op7\ninput35 - input39"] +input35 -> op7 +input39 -> op7 +op8 [label="op8\ninput35 - input37"] +input35 -> op8 +input37 -> op8 +op9 [label="op9\ninput36 - const1"] +input36 -> op9 const1 -> op9 -op10 [label="op10\ninput30 × op2"] -input30 -> op10 +op10 [label="op10\ninput34 × op2"] +input34 -> op10 op2 -> op10 op11 [label="op11\nop1 + op10"] op1 -> op11 @@ -62,14 +64,14 @@ op13 -> op14 op15 [label="op15\nop11 × op14"] op11 -> op15 op14 -> op15 -op16 [label="op16\ninput30 × input30"] -input30 -> op16 -input30 -> op16 +op16 [label="op16\ninput34 × input34"] +input34 -> op16 +input34 -> op16 op17 [label="op17\nop3 × op16"] op3 -> op17 op16 -> op17 -op18 [label="op18\ninput30 × op16"] -input30 -> op18 +op18 [label="op18\ninput34 × op16"] +input34 -> op18 op16 -> op18 op19 [label="op19\nop4 × op18"] op4 -> op19 @@ -77,11 +79,11 @@ op18 -> op19 op20 [label="op20\nop17 + op19"] op17 -> op20 op19 -> op20 -op21 [label="op21\ninput30 × op18"] -input30 -> op21 +op21 [label="op21\ninput34 × op18"] +input34 -> op21 op18 -> op21 -op22 [label="op22\ninput6 × op21"] -input6 -> op22 +op22 [label="op22\ninput12 × op21"] +input12 -> op22 op21 -> op22 op23 [label="op23\nop20 + op22"] op20 -> op23 @@ -95,8 +97,8 @@ op24 -> op25 op26 [label="op26\nop15 + op25"] op15 -> op26 op25 -> op26 -op27 [label="op27\ninput30 × op21"] -input30 -> op27 +op27 [label="op27\ninput34 × op21"] +input34 -> op27 op21 -> op27 op28 [label="op28\nop5 × op27"] op5 -> op28 @@ -110,47 +112,47 @@ op29 -> op30 op31 [label="op31\nop26 + op30"] op26 -> op31 op30 -> op31 -op32 [label="op32\ninput15 × input32"] -input15 -> op32 -input32 -> op32 -op33 [label="op33\ninput14 + op32"] -input14 -> op33 +op32 [label="op32\ninput21 × input36"] +input21 -> op32 +input36 -> op32 +op33 [label="op33\ninput20 + op32"] +input20 -> op33 op32 -> op33 -op34 [label="op34\ninput32 × op33"] -input32 -> op34 +op34 [label="op34\ninput36 × op33"] +input36 -> op34 op33 -> op34 -op35 [label="op35\ninput13 + 
op34"] -input13 -> op35 +op35 [label="op35\ninput19 + op34"] +input19 -> op35 op34 -> op35 -op36 [label="op36\ninput32 × op35"] -input32 -> op36 +op36 [label="op36\ninput36 × op35"] +input36 -> op36 op35 -> op36 -op37 [label="op37\ninput12 + op36"] -input12 -> op37 +op37 [label="op37\ninput18 + op36"] +input18 -> op37 op36 -> op37 -op38 [label="op38\ninput32 × op37"] -input32 -> op38 +op38 [label="op38\ninput36 × op37"] +input36 -> op38 op37 -> op38 -op39 [label="op39\ninput11 + op38"] -input11 -> op39 +op39 [label="op39\ninput17 + op38"] +input17 -> op39 op38 -> op39 -op40 [label="op40\ninput32 × op39"] -input32 -> op40 +op40 [label="op40\ninput36 × op39"] +input36 -> op40 op39 -> op40 -op41 [label="op41\ninput10 + op40"] -input10 -> op41 +op41 [label="op41\ninput16 + op40"] +input16 -> op41 op40 -> op41 -op42 [label="op42\ninput32 × op41"] -input32 -> op42 +op42 [label="op42\ninput36 × op41"] +input36 -> op42 op41 -> op42 -op43 [label="op43\ninput9 + op42"] -input9 -> op43 +op43 [label="op43\ninput15 + op42"] +input15 -> op43 op42 -> op43 -op44 [label="op44\ninput32 × op43"] -input32 -> op44 +op44 [label="op44\ninput36 × op43"] +input36 -> op44 op43 -> op44 -op45 [label="op45\ninput8 + op44"] -input8 -> op45 +op45 [label="op45\ninput14 + op44"] +input14 -> op45 op44 -> op45 op46 [label="op46\nop9 × op12"] op9 -> op46 diff --git a/codegen/ace/tests/regressions/SimpleIntegrityAux.dot b/codegen/ace/tests/regressions/SimpleIntegrityAux.dot index 18c704477..af7a23064 100644 --- a/codegen/ace/tests/regressions/SimpleIntegrityAux.dot +++ b/codegen/ace/tests/regressions/SimpleIntegrityAux.dot @@ -1,27 +1,29 @@ digraph G { -const0 [label="0"] -const1 [label="1"] +const0 [label="1"] +const1 [label="0"] input0 [label="PI[stack_inputs][0]"] -input4 [label="M[0]"] -input16 [label="M'[0]"] -input28 [label="g⁻²"] -input29 [label="g⁻¹"] -input30 [label="⍺"] -input31 [label="z"] -input32 [label="zⁿ"] -input33 [label="zᵐᵃˣ"] -op0 [label="op0\ninput31 - const1"] -input31 -> op0 
-const1 -> op0 -op1 [label="op1\ninput31 - input28"] -input31 -> op1 -input28 -> op1 -op2 [label="op2\ninput31 - input29"] -input31 -> op2 -input29 -> op2 -op3 [label="op3\ninput32 - const1"] -input32 -> op3 -const1 -> op3 +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input22 [label="M'[0]"] +input34 [label="⍺"] +input35 [label="z"] +input36 [label="zⁿ"] +input37 [label="g⁻¹"] +input38 [label="zᵐᵃˣ"] +input39 [label="g⁻²"] +op0 [label="op0\ninput35 - const0"] +input35 -> op0 +const0 -> op0 +op1 [label="op1\ninput35 - input39"] +input35 -> op1 +input39 -> op1 +op2 [label="op2\ninput35 - input37"] +input35 -> op2 +input37 -> op2 +op3 [label="op3\ninput36 - const0"] +input36 -> op3 +const0 -> op3 op4 [label="op4\nop0 × op1"] op0 -> op4 op1 -> op4 @@ -31,12 +33,12 @@ op4 -> op5 op6 [label="op6\nop1 × op5"] op1 -> op6 op5 -> op6 -op7 [label="op7\ninput4 × op6"] -input4 -> op7 +op7 [label="op7\ninput10 × op6"] +input10 -> op7 op6 -> op7 -op8 [label="op8\ninput4 × input30"] -input4 -> op8 -input30 -> op8 +op8 [label="op8\ninput10 × input34"] +input10 -> op8 +input34 -> op8 op9 [label="op9\nop1 × op3"] op1 -> op9 op3 -> op9 @@ -49,47 +51,47 @@ op10 -> op11 op12 [label="op12\nop0 × op3"] op0 -> op12 op3 -> op12 -op13 [label="op13\ninput15 × input32"] -input15 -> op13 -input32 -> op13 -op14 [label="op14\ninput14 + op13"] -input14 -> op14 +op13 [label="op13\ninput21 × input36"] +input21 -> op13 +input36 -> op13 +op14 [label="op14\ninput20 + op13"] +input20 -> op14 op13 -> op14 -op15 [label="op15\ninput32 × op14"] -input32 -> op15 +op15 [label="op15\ninput36 × op14"] +input36 -> op15 op14 -> op15 -op16 [label="op16\ninput13 + op15"] -input13 -> op16 +op16 [label="op16\ninput19 + op15"] +input19 -> op16 op15 -> op16 -op17 [label="op17\ninput32 × op16"] -input32 -> op17 +op17 [label="op17\ninput36 × op16"] +input36 -> op17 op16 -> op17 -op18 [label="op18\ninput12 + op17"] -input12 -> op18 +op18 [label="op18\ninput18 + op17"] +input18 -> op18 op17 -> op18 -op19 
[label="op19\ninput32 × op18"] -input32 -> op19 +op19 [label="op19\ninput36 × op18"] +input36 -> op19 op18 -> op19 -op20 [label="op20\ninput11 + op19"] -input11 -> op20 +op20 [label="op20\ninput17 + op19"] +input17 -> op20 op19 -> op20 -op21 [label="op21\ninput32 × op20"] -input32 -> op21 +op21 [label="op21\ninput36 × op20"] +input36 -> op21 op20 -> op21 -op22 [label="op22\ninput10 + op21"] -input10 -> op22 +op22 [label="op22\ninput16 + op21"] +input16 -> op22 op21 -> op22 -op23 [label="op23\ninput32 × op22"] -input32 -> op23 +op23 [label="op23\ninput36 × op22"] +input36 -> op23 op22 -> op23 -op24 [label="op24\ninput9 + op23"] -input9 -> op24 +op24 [label="op24\ninput15 + op23"] +input15 -> op24 op23 -> op24 -op25 [label="op25\ninput32 × op24"] -input32 -> op25 +op25 [label="op25\ninput36 × op24"] +input36 -> op25 op24 -> op25 -op26 [label="op26\ninput8 + op25"] -input8 -> op26 +op26 [label="op26\ninput14 + op25"] +input14 -> op26 op25 -> op26 op27 [label="op27\nop3 × op4"] op3 -> op27 diff --git a/codegen/ace/tests/regressions/Vector.dot b/codegen/ace/tests/regressions/Vector.dot index 8518250ed..11abc157f 100644 --- a/codegen/ace/tests/regressions/Vector.dot +++ b/codegen/ace/tests/regressions/Vector.dot @@ -2,35 +2,37 @@ digraph G { const0 [label="0"] const1 [label="1"] input0 [label="PI[stack_inputs][0]"] -input4 [label="M[0]"] -input5 [label="M[1]"] -input6 [label="M[2]"] -input16 [label="M'[0]"] -input17 [label="M'[1]"] -input18 [label="M'[2]"] -input28 [label="g⁻²"] -input29 [label="g⁻¹"] -input30 [label="⍺"] -input31 [label="z"] -input32 [label="zⁿ"] -input33 [label="zᵐᵃˣ"] -op0 [label="op0\ninput4 - input5"] -input4 -> op0 -input5 -> op0 -op1 [label="op1\ninput6 + op0"] -input6 -> op1 +input8 [label="α"] +input9 [label="β"] +input10 [label="M[0]"] +input11 [label="M[1]"] +input12 [label="M[2]"] +input22 [label="M'[0]"] +input23 [label="M'[1]"] +input24 [label="M'[2]"] +input34 [label="⍺"] +input35 [label="z"] +input36 [label="zⁿ"] +input37 [label="g⁻¹"] 
+input38 [label="zᵐᵃˣ"] +input39 [label="g⁻²"] +op0 [label="op0\ninput10 - input11"] +input10 -> op0 +input11 -> op0 +op1 [label="op1\ninput12 + op0"] +input12 -> op1 op0 -> op1 -op2 [label="op2\ninput31 - const1"] -input31 -> op2 +op2 [label="op2\ninput35 - const1"] +input35 -> op2 const1 -> op2 -op3 [label="op3\ninput31 - input28"] -input31 -> op3 -input28 -> op3 -op4 [label="op4\ninput31 - input29"] -input31 -> op4 -input29 -> op4 -op5 [label="op5\ninput32 - const1"] -input32 -> op5 +op3 [label="op3\ninput35 - input39"] +input35 -> op3 +input39 -> op3 +op4 [label="op4\ninput35 - input37"] +input35 -> op4 +input37 -> op4 +op5 [label="op5\ninput36 - const1"] +input36 -> op5 const1 -> op5 op6 [label="op6\nop2 × op3"] op2 -> op6 @@ -44,9 +46,9 @@ op7 -> op8 op9 [label="op9\nop1 × op8"] op1 -> op9 op8 -> op9 -op10 [label="op10\ninput4 × input30"] -input4 -> op10 -input30 -> op10 +op10 [label="op10\ninput10 × input34"] +input10 -> op10 +input34 -> op10 op11 [label="op11\nop3 × op5"] op3 -> op11 op5 -> op11 @@ -59,47 +61,47 @@ op12 -> op13 op14 [label="op14\nop2 × op5"] op2 -> op14 op5 -> op14 -op15 [label="op15\ninput15 × input32"] -input15 -> op15 -input32 -> op15 -op16 [label="op16\ninput14 + op15"] -input14 -> op16 +op15 [label="op15\ninput21 × input36"] +input21 -> op15 +input36 -> op15 +op16 [label="op16\ninput20 + op15"] +input20 -> op16 op15 -> op16 -op17 [label="op17\ninput32 × op16"] -input32 -> op17 +op17 [label="op17\ninput36 × op16"] +input36 -> op17 op16 -> op17 -op18 [label="op18\ninput13 + op17"] -input13 -> op18 +op18 [label="op18\ninput19 + op17"] +input19 -> op18 op17 -> op18 -op19 [label="op19\ninput32 × op18"] -input32 -> op19 +op19 [label="op19\ninput36 × op18"] +input36 -> op19 op18 -> op19 -op20 [label="op20\ninput12 + op19"] -input12 -> op20 +op20 [label="op20\ninput18 + op19"] +input18 -> op20 op19 -> op20 -op21 [label="op21\ninput32 × op20"] -input32 -> op21 +op21 [label="op21\ninput36 × op20"] +input36 -> op21 op20 -> op21 -op22 
[label="op22\ninput11 + op21"] -input11 -> op22 +op22 [label="op22\ninput17 + op21"] +input17 -> op22 op21 -> op22 -op23 [label="op23\ninput32 × op22"] -input32 -> op23 +op23 [label="op23\ninput36 × op22"] +input36 -> op23 op22 -> op23 -op24 [label="op24\ninput10 + op23"] -input10 -> op24 +op24 [label="op24\ninput16 + op23"] +input16 -> op24 op23 -> op24 -op25 [label="op25\ninput32 × op24"] -input32 -> op25 +op25 [label="op25\ninput36 × op24"] +input36 -> op25 op24 -> op25 -op26 [label="op26\ninput9 + op25"] -input9 -> op26 +op26 [label="op26\ninput15 + op25"] +input15 -> op26 op25 -> op26 -op27 [label="op27\ninput32 × op26"] -input32 -> op27 +op27 [label="op27\ninput36 × op26"] +input36 -> op27 op26 -> op27 -op28 [label="op28\ninput8 + op27"] -input8 -> op28 +op28 [label="op28\ninput14 + op27"] +input14 -> op28 op27 -> op28 op29 [label="op29\nop5 × op6"] op5 -> op29 diff --git a/codegen/winterfell/Cargo.toml b/codegen/winterfell/Cargo.toml index 30f8b6b13..32b753408 100644 --- a/codegen/winterfell/Cargo.toml +++ b/codegen/winterfell/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "air-codegen-winter" -version = "0.4.0" +version = "0.5.0" description = "Winterfell code generator for the AirScript language" authors.workspace = true readme = "README.md" @@ -12,6 +12,6 @@ edition.workspace = true rust-version.workspace = true [dependencies] -air-ir = { package = "air-ir", path = "../../air", version = "0.4" } +air-ir = { package = "air-ir", path = "../../air", version = "0.5" } anyhow = { workspace = true } codegen = "0.2" diff --git a/codegen/winterfell/README.md b/codegen/winterfell/README.md index e84d6eff6..dd88d0aa9 100644 --- a/codegen/winterfell/README.md +++ b/codegen/winterfell/README.md @@ -16,11 +16,11 @@ Example usage: // parse the source string to a Result containing the AST or an Error let ast = parse(source.as_str()).expect("Parsing failed"); -// process the AST to get a Result containing the AirIR or an Error -let ir = AirIR::new(&ast).expect("AIR is 
invalid"); +// Compile AST into AIR +let air = compile(&diagnostics, ast).expect("compilation failed"); // generate Rust code targeting the Winterfell prover -let rust_code = CodeGenerator::new(&ir); +let rust_code = CodeGenerator::new(&air); ``` ## Generated Winterfell Rust Code diff --git a/codegen/winterfell/src/air/boundary_constraints.rs b/codegen/winterfell/src/air/boundary_constraints.rs index cc43bb372..41acc5dc5 100644 --- a/codegen/winterfell/src/air/boundary_constraints.rs +++ b/codegen/winterfell/src/air/boundary_constraints.rs @@ -1,10 +1,11 @@ use core::panic; -use air_ir::{Air, AlgebraicGraph, ConstraintDomain, NodeIndex, Operation, TraceAccess, Value}; - -use crate::air::call_bus_boundary_varlen_pubinput; +use air_ir::{ + Air, AlgebraicGraph, ConstraintDomain, NodeIndex, Operation, TraceAccess, TraceSegmentId, +}; use super::{Codegen, ElemType, Impl}; +use crate::air::call_bus_boundary_varlen_pubinput; // HELPERS TO GENERATE THE WINTERFELL BOUNDARY CONSTRAINT METHODS // ================================================================================================ @@ -14,10 +15,8 @@ use super::{Codegen, ElemType, Impl}; /// TODO: add result types to these functions. pub(super) fn add_fn_get_assertions(impl_ref: &mut Impl, ir: &Air) { // define the function - let get_assertions = impl_ref - .new_fn("get_assertions") - .arg_ref_self() - .ret("Vec>"); + let get_assertions = + impl_ref.new_fn("get_assertions").arg_ref_self().ret("Vec>"); // add the boundary constraints add_main_trace_assertions(get_assertions, ir); @@ -47,19 +46,19 @@ pub(super) fn add_fn_get_aux_assertions(impl_ref: &mut Impl, ir: &Air) { /// Declares a result vector and adds assertions for boundary constraints to it for the main /// trace segment fn add_main_trace_assertions(func_body: &mut codegen::Function, ir: &Air) { - let elem_type = ElemType::Base; - let main_trace_segment = 0; - // declare the result vector to be returned. 
func_body.line("let mut result = Vec::new();"); // add the main boundary constraints - for constraint in ir.boundary_constraints(main_trace_segment) { + for constraint in ir.boundary_constraints(TraceSegmentId::Main) { let (trace_access, expr_root) = split_boundary_constraint(ir.constraint_graph(), constraint.node_index()); - debug_assert_eq!(trace_access.segment, main_trace_segment); + debug_assert_eq!(trace_access.segment, TraceSegmentId::Main); - let expr_root_string = expr_root.to_string(ir, elem_type, main_trace_segment); + let expr_root_string = match expr_root { + Some(node_index) => node_index.to_string(ir, ElemType::Base, TraceSegmentId::Main), + None => "Felt::ZERO".to_string(), // If no root, the expression is zero + }; let assertion = format!( "result.push(Assertion::single({}, {}, {}));", @@ -75,20 +74,44 @@ fn add_main_trace_assertions(func_body: &mut codegen::Function, ir: &Air) { /// Declares a result vector and adds assertions for boundary constraints to it for the aux /// trace segment (used for buses boundary constraints for variable length public inputs) fn add_aux_trace_assertions(func_body: &mut codegen::Function, ir: &Air) { - let elem_type = ElemType::Ext; - let aux_trace_segment = 1; - // declare the result vector to be returned. func_body.line("let mut result = Vec::new();"); + // Add expressions for evaluating the reduced public input table. Its expression is defined as + // `reduced_{TABLE_NAME}_{BUS_TYPE}`. + // This ensures that if two busses of the same type are constrained at a boundary to the same + // public input table, the codegen generates the same lines. These should easily be optimized + // by the compiler. + // TODO: These values are constant across all rows and therefore can be computed only once + // before starting the constraint evaluation. 
+ for access in ir.reduced_public_input_table_accesses() { + let boundary_value = air_ir::Value::PublicInputTable(access).to_string( + ir, + ElemType::Ext, + TraceSegmentId::Aux, + ); + let expr_root_string = call_bus_boundary_varlen_pubinput(access); + + let boundary_value_init = format!("let {boundary_value} = {expr_root_string};"); + + func_body.line(boundary_value_init); + } + // add the boundary constraints that have already be expanded in the algebraic graph // (currently, empty buses constraints) - for constraint in ir.boundary_constraints(aux_trace_segment) { + for constraint in ir.boundary_constraints(TraceSegmentId::Aux) { let (trace_access, expr_root) = split_boundary_constraint(ir.constraint_graph(), constraint.node_index()); - debug_assert_eq!(trace_access.segment, aux_trace_segment); + debug_assert_eq!(trace_access.segment, TraceSegmentId::Aux); - let expr_root_string = expr_root.to_string(ir, elem_type, aux_trace_segment); + // In the graph, empty buses are either constrained by 0 (for logup buses) or 1 (for + // multiset buses). However, because of Common Subexpression Elimination, the `0` + // constant will not be inserted into the graph and the `split_boundary_constraint` function + // will return a `None` value, so we should handle this case separately. 
+ let expr_root_string = match expr_root { + Some(node_index) => node_index.to_string(ir, ElemType::Ext, TraceSegmentId::Aux), + None => "E::ZERO".to_string(), + }; let assertion = format!( "result.push(Assertion::single({}, {}, {}));", @@ -99,39 +122,6 @@ fn add_aux_trace_assertions(func_body: &mut codegen::Function, ir: &Air) { func_body.line(assertion); } - - let domains = [ConstraintDomain::FirstRow, ConstraintDomain::LastRow]; - - for domain in &domains { - for (index, bus) in ir.buses.values().enumerate() { - let bus_boundary = match domain { - ConstraintDomain::FirstRow => &bus.first, - ConstraintDomain::LastRow => &bus.last, - _ => unreachable!("Invalid domain for bus boundary constraint"), - }; - - match bus_boundary { - air_ir::BusBoundary::PublicInputTable(air_ir::PublicInputTableAccess { - bus_name, - table_name, - .. - }) => { - let expr_root_string = - call_bus_boundary_varlen_pubinput(ir, *bus_name, *table_name); - - let assertion = format!( - "result.push(Assertion::single({}, {}, {}));", - index, - domain_to_str(*domain), - expr_root_string - ); - - func_body.line(assertion); - } - air_ir::BusBoundary::Null | air_ir::BusBoundary::Unconstrained => {} - } - } - } } /// Returns a string slice representing the provided constraint domain. @@ -152,23 +142,31 @@ fn domain_to_str(domain: ConstraintDomain) -> String { /// boundary constraint expression must hold, as well as the node index that represents the root /// of the constraint expression that must equal zero during evaluation. /// -/// TODO: replace panics with Result and Error +/// Note: If, after the CSE pass, the boundary constraint is a single trace access, +/// we return None for the constraint expression. This expression should then be assumed to be zero +/// during evaluation by the caller. 
pub fn split_boundary_constraint( graph: &AlgebraicGraph, index: &NodeIndex, -) -> (TraceAccess, NodeIndex) { +) -> (TraceAccess, Option) { let node = graph.node(index); - match node.op() { + match *node.op() { Operation::Sub(lhs, rhs) => { - if let Operation::Value(Value::TraceAccess(trace_access)) = graph.node(lhs).op() { + if let Operation::Value(air_ir::Value::TraceAccess(trace_access)) = + graph.node(&lhs).op() + { debug_assert_eq!(trace_access.row_offset, 0); - (*trace_access, *rhs) + (*trace_access, Some(rhs)) } else { panic!( "InvalidUsage: index {index:?} is not the constraint root of a boundary constraint" ); } - } + }, + Operation::Value(air_ir::Value::TraceAccess(trace_access)) => { + debug_assert_eq!(trace_access.row_offset, 0); + (trace_access, None) + }, _ => panic!("InvalidUsage: index {index:?} is not the root index of a constraint"), } } diff --git a/codegen/winterfell/src/air/graph.rs b/codegen/winterfell/src/air/graph.rs index 676b84223..6ea55b9cd 100644 --- a/codegen/winterfell/src/air/graph.rs +++ b/codegen/winterfell/src/air/graph.rs @@ -25,28 +25,24 @@ impl Codegen for IntegrityConstraintDegree { .map(|cycle_len| cycle_len.to_string()) .collect::>() .join(", "); - format!( - "TransitionConstraintDegree::with_cycles({}, vec![{}])", - self.base(), - cycles - ) + format!("TransitionConstraintDegree::with_cycles({}, vec![{}])", self.base(), cycles) } } } impl Codegen for TraceAccess { fn to_string(&self, _ir: &Air, _elem_type: ElemType, trace_segment: TraceSegmentId) -> String { - let frame = if self.segment == 0 { "main" } else { "aux" }; + let frame = self.segment.to_string(); let row_offset = match self.row_offset { 0 => { format!("current[{}]", self.column) - } + }, 1 => { format!("next[{}]", self.column) - } + }, _ => panic!("Winterfell doesn't support row offsets greater than 1."), }; - if self.segment == 0 && self.segment != trace_segment { + if self.segment == TraceSegmentId::Main && self.segment != trace_segment { 
format!("E::from({frame}_{row_offset})") } else { format!("{frame}_{row_offset}") @@ -65,9 +61,9 @@ impl Codegen for Operation { fn to_string(&self, ir: &Air, elem_type: ElemType, trace_segment: TraceSegmentId) -> String { match self { Operation::Value(value) => value.to_string(ir, elem_type, trace_segment), - Operation::Add(_, _) => binary_op_to_string(ir, self, elem_type, trace_segment), - Operation::Sub(_, _) => binary_op_to_string(ir, self, elem_type, trace_segment), - Operation::Mul(_, _) => binary_op_to_string(ir, self, elem_type, trace_segment), + Operation::Add(..) => binary_op_to_string(ir, self, elem_type, trace_segment), + Operation::Sub(..) => binary_op_to_string(ir, self, elem_type, trace_segment), + Operation::Mul(..) => binary_op_to_string(ir, self, elem_type, trace_segment), } } } @@ -90,21 +86,25 @@ impl Codegen for Value { }, Value::TraceAccess(trace_access) => { trace_access.to_string(ir, elem_type, trace_segment) - } + }, Value::PeriodicColumn(pc) => { - let index = ir - .periodic_columns - .iter() - .position(|(qid, _)| qid == &pc.name) - .unwrap(); + let index = + ir.periodic_columns.iter().position(|(qid, _)| qid == &pc.name).unwrap(); format!("periodic_values[{index}]") - } + }, Value::PublicInput(air_ir::PublicInputAccess { name, index }) => { format!("self.{name}[{index}]") - } + }, + Value::PublicInputTable(air_ir::PublicInputTableAccess { + table_name, + bus_type, + num_cols: _, + }) => { + format!("reduced_{table_name}_{bus_type}") + }, Value::RandomValue(idx) => { format!("aux_rand_elements.rand_elements()[{idx}]") - } + }, } } } @@ -121,7 +121,7 @@ fn binary_op_to_string( let lhs = l_idx.to_string(ir, elem_type, trace_segment); let rhs = r_idx.to_string(ir, elem_type, trace_segment); format!("{lhs} + {rhs}") - } + }, Operation::Sub(l_idx, r_idx) => { let lhs = l_idx.to_string(ir, elem_type, trace_segment); let rhs = if ir.constraint_graph().node(r_idx).op().precedence() <= op.precedence() { @@ -130,7 +130,7 @@ fn binary_op_to_string( 
r_idx.to_string(ir, elem_type, trace_segment) }; format!("{lhs} - {rhs}") - } + }, Operation::Mul(l_idx, r_idx) => { let lhs = if ir.constraint_graph().node(l_idx).op().precedence() < op.precedence() { format!("({})", l_idx.to_string(ir, elem_type, trace_segment)) @@ -143,7 +143,7 @@ fn binary_op_to_string( r_idx.to_string(ir, elem_type, trace_segment) }; format!("{lhs} * {rhs}") - } + }, _ => panic!("unsupported operation"), } } diff --git a/codegen/winterfell/src/air/mod.rs b/codegen/winterfell/src/air/mod.rs index 1cc6590ff..e8f18e78b 100644 --- a/codegen/winterfell/src/air/mod.rs +++ b/codegen/winterfell/src/air/mod.rs @@ -12,10 +12,9 @@ mod boundary_constraints; use boundary_constraints::{add_fn_get_assertions, add_fn_get_aux_assertions}; mod transition_constraints; +use air_ir::{Air, BusBoundary, BusType, ConstraintDomain, PublicInputTableAccess, TraceSegmentId}; use transition_constraints::{add_fn_evaluate_aux_transition, add_fn_evaluate_transition}; -use air_ir::{Air, BusBoundary, BusType, ConstraintDomain, Identifier, TraceSegmentId}; - use super::{Impl, Scope}; // HELPER TYPES @@ -48,17 +47,11 @@ pub(super) fn add_air(scope: &mut Scope, ir: &Air) { /// Updates the provided scope with a custom Air struct. fn add_air_struct(scope: &mut Scope, ir: &Air, name: &str) { // define the custom Air struct. - let air_struct = scope - .new_struct(name) - .vis("pub") - .field("context", "AirContext"); + let air_struct = scope.new_struct(name).vis("pub").field("context", "AirContext"); // add public inputs for public_input in ir.public_inputs() { - air_struct.field( - public_input.name().as_str(), - public_input_type_to_string(public_input), - ); + air_struct.field(public_input.name().as_str(), public_input_type_to_string(public_input)); } // add the custom Air implementation block @@ -73,17 +66,17 @@ fn add_air_struct(scope: &mut Scope, ir: &Air, name: &str) { // add a method to get the variable length public inputs bus boundary constraints. 
let (mut add_bus_multiset_boundary_varlen, mut add_bus_logup_boundary_varlen) = (false, false); for bus in ir.buses.values() { - // Check which bus type is refering to variable length public inputs + // Check which bus type is referring to variable length public inputs let bus_constraints = [&bus.first, &bus.last]; for fl in bus_constraints { if let BusBoundary::PublicInputTable(_) = fl { match bus.bus_type { BusType::Multiset => { add_bus_multiset_boundary_varlen = true; - } + }, BusType::Logup => { add_bus_logup_boundary_varlen = true; - } + }, } } } @@ -120,15 +113,15 @@ fn impl_bus_multiset_boundary_varlen(base_impl: &mut Impl) { .new_fn("bus_multiset_boundary_varlen") .generic("'a") .generic("const N: usize") - .generic("I: IntoIterator + Clone") + .generic("I: IntoIterator") .generic("E: FieldElement") .arg("aux_rand_elements", "&AuxRandElements") - .arg("public_inputs", "&I") + .arg("public_inputs", "I") .ret("E") .vis("pub") .line("let mut bus_p_last: E = E::ONE;") .line("let rand = aux_rand_elements.rand_elements();") - .line("for row in public_inputs.clone().into_iter() {") + .line("for row in public_inputs {") .line(" let mut p_last = rand[0];") .line(" for (c, p_i) in row.iter().enumerate() {") .line(" p_last += E::from(*p_i) * rand[c + 1];") @@ -164,15 +157,15 @@ fn impl_bus_logup_boundary_varlen(base_impl: &mut Impl) { .new_fn("bus_logup_boundary_varlen") .generic("'a") .generic("const N: usize") - .generic("I: IntoIterator + Clone") + .generic("I: IntoIterator") .generic("E: FieldElement") .arg("aux_rand_elements", "&AuxRandElements") - .arg("public_inputs", "&I") + .arg("public_inputs", "I") .ret("E") .vis("pub") .line("let mut bus_q_last = E::ZERO;") .line("let rand = aux_rand_elements.rand_elements();") - .line("for row in public_inputs.clone().into_iter() {") + .line("for row in public_inputs {") .line(" let mut q_last = rand[0];") .line(" for (c, p_i) in row.iter().enumerate() {") .line(" let p_i = *p_i;") @@ -194,10 +187,7 @@ fn 
add_air_trait(scope: &mut Scope, ir: &Air, name: &str) { .associate_type("PublicInputs", "PublicInputs"); // add default function "context". - let fn_context = air_impl - .new_fn("context") - .arg_ref_self() - .ret("&AirContext"); + let fn_context = air_impl.new_fn("context").arg_ref_self().ret("&AirContext"); fn_context.line("&self.context"); // add the method implementations required by the AIR trait. @@ -226,22 +216,19 @@ fn add_fn_new(impl_ref: &mut Impl, ir: &Air) { .ret("Self"); // define the integrity constraint degrees of the main trace `main_degrees`. - add_constraint_degrees(new, ir, 0, "main_degrees"); + add_constraint_degrees(new, ir, TraceSegmentId::Main, "main_degrees"); // define the integrity constraint degrees of the aux trace `aux_degrees`. - add_constraint_degrees(new, ir, 1, "aux_degrees"); + add_constraint_degrees(new, ir, TraceSegmentId::Aux, "aux_degrees"); // define the number of main trace boundary constraints `num_main_assertions`. new.line(format!( "let num_main_assertions = {};", - ir.num_boundary_constraints(0) + ir.num_boundary_constraints(TraceSegmentId::Main) )); // define the number of aux trace boundary constraints `num_aux_assertions`. - new.line(format!( - "let num_aux_assertions = {};", - num_bus_boundary_constraints(ir) - )); + new.line(format!("let num_aux_assertions = {};", num_bus_boundary_constraints(ir))); // define the context. 
let context = " @@ -283,19 +270,20 @@ fn add_constraint_degrees( func_body.line(format!("let {decl_name} = vec![{}];", degrees.join(", "))); } -fn call_bus_boundary_varlen_pubinput( - ir: &Air, - bus_name: Identifier, - table_name: Identifier, -) -> String { - let bus = ir.buses.get(&bus_name).expect("bus not found"); - match bus.bus_type { - BusType::Multiset => format!( - "Self::bus_multiset_boundary_varlen(aux_rand_elements, &self.{table_name}.iter())", - ), +fn call_bus_boundary_varlen_pubinput(access: PublicInputTableAccess) -> String { + match access.bus_type { + BusType::Multiset => { + format!( + "Self::bus_multiset_boundary_varlen(aux_rand_elements, &self.{})", + access.table_name + ) + }, BusType::Logup => { - format!("Self::bus_logup_boundary_varlen(aux_rand_elements, &self.{table_name}.iter())",) - } + format!( + "Self::bus_logup_boundary_varlen(aux_rand_elements, &self.{})", + access.table_name + ) + }, } } @@ -314,8 +302,8 @@ fn num_bus_boundary_constraints(ir: &Air) -> usize { match bus_boundary { air_ir::BusBoundary::PublicInputTable(_) | air_ir::BusBoundary::Null => { num_bus_boundary_constraints += 1; - } - air_ir::BusBoundary::Unconstrained => {} + }, + air_ir::BusBoundary::Unconstrained => {}, } } } diff --git a/codegen/winterfell/src/air/periodic_columns.rs b/codegen/winterfell/src/air/periodic_columns.rs index 4a9c24ef4..234b518a5 100644 --- a/codegen/winterfell/src/air/periodic_columns.rs +++ b/codegen/winterfell/src/air/periodic_columns.rs @@ -30,13 +30,13 @@ impl Codegen for &BTreeMap { match row { 0 => { rows.push("Felt::ZERO".to_string()); - } + }, 1 => { rows.push("Felt::ONE".to_string()); - } + }, row => { rows.push(format!("Felt::new({row})")); - } + }, } } columns.push(format!("vec![{}]", rows.join(", "))); diff --git a/codegen/winterfell/src/air/public_inputs.rs b/codegen/winterfell/src/air/public_inputs.rs index 79e0a1474..f4bf0d827 100644 --- a/codegen/winterfell/src/air/public_inputs.rs +++ 
b/codegen/winterfell/src/air/public_inputs.rs @@ -16,19 +16,15 @@ pub(super) fn add_public_inputs_struct(scope: &mut Scope, ir: &Air) { let pub_inputs_struct = scope.new_struct(name).vis("pub"); for public_input in ir.public_inputs() { - pub_inputs_struct.field( - public_input.name().as_str(), - public_input_type_to_string(public_input), - ); + pub_inputs_struct + .field(public_input.name().as_str(), public_input_type_to_string(public_input)); } // add the public inputs implementation block let base_impl = scope.new_impl(name); - let pub_inputs_values: Vec = ir - .public_inputs() - .map(|input| input.name().to_string()) - .collect(); + let pub_inputs_values: Vec = + ir.public_inputs().map(|input| input.name().to_string()).collect(); // add a constructor for public inputs let new_fn = base_impl @@ -37,37 +33,27 @@ pub(super) fn add_public_inputs_struct(scope: &mut Scope, ir: &Air) { .ret("Self") .line(format!("Self {{ {} }}", pub_inputs_values.join(", "))); for public_input in ir.public_inputs() { - new_fn.arg( - public_input.name().as_str(), - public_input_type_to_string(public_input), - ); + new_fn.arg(public_input.name().as_str(), public_input_type_to_string(public_input)); } add_serializable_impl(scope, pub_inputs_values.clone()); // add a to_elements implementation - let to_elements_impl = scope - .new_impl("PublicInputs") - .impl_trait("ToElements"); - let to_elements_fn = to_elements_impl - .new_fn("to_elements") - .arg_ref_self() - .ret("Vec"); + let to_elements_impl = scope.new_impl("PublicInputs").impl_trait("ToElements"); + let to_elements_fn = to_elements_impl.new_fn("to_elements").arg_ref_self().ret("Vec"); to_elements_fn.line("let mut elements = Vec::new();"); for public_input in ir.public_inputs() { match public_input { air_ir::PublicInput::Vector { .. 
} => { - to_elements_fn.line(format!( - "elements.extend_from_slice(&self.{});", - public_input.name() - )); - } + to_elements_fn + .line(format!("elements.extend_from_slice(&self.{});", public_input.name())); + }, air_ir::PublicInput::Table { .. } => { to_elements_fn.line(format!( "self.{}.iter().for_each(|row| elements.extend_from_slice(row));", public_input.name() )); - } + }, } } to_elements_fn.line("elements"); diff --git a/codegen/winterfell/src/air/transition_constraints.rs b/codegen/winterfell/src/air/transition_constraints.rs index c12f562ff..be2e951b6 100644 --- a/codegen/winterfell/src/air/transition_constraints.rs +++ b/codegen/winterfell/src/air/transition_constraints.rs @@ -22,11 +22,11 @@ pub(super) fn add_fn_evaluate_transition(impl_ref: &mut Impl, ir: &Air) { evaluate_transition.line("let main_next = frame.next();"); // output the constraints. - add_constraints(evaluate_transition, ir, 0); + add_constraints(evaluate_transition, ir, TraceSegmentId::Main); } -/// Adds an implementation of the "evaluate_aux_transition" method to the referenced Air implementation -/// based on the data in the provided IR. +/// Adds an implementation of the "evaluate_aux_transition" method to the referenced Air +/// implementation based on the data in the provided IR. pub(super) fn add_fn_evaluate_aux_transition(impl_ref: &mut Impl, ir: &Air) { // define the function. let evaluate_aux_transition = impl_ref @@ -48,7 +48,7 @@ pub(super) fn add_fn_evaluate_aux_transition(impl_ref: &mut Impl, ir: &Air) { evaluate_aux_transition.line("let aux_next = aux_frame.next();"); // output the constraints. 
- add_constraints(evaluate_aux_transition, ir, 1); + add_constraints(evaluate_aux_transition, ir, TraceSegmentId::Aux); } /// Iterates through the integrity constraints in the IR, and appends a line of generated code to @@ -58,9 +58,7 @@ fn add_constraints(func_body: &mut codegen::Function, ir: &Air, trace_segment: T func_body.line(format!( "result[{}] = {};", idx, - constraint - .node_index() - .to_string(ir, ElemType::Ext, trace_segment) + constraint.node_index().to_string(ir, ElemType::Ext, trace_segment) )); } } diff --git a/docs/README.md b/docs/README.md index 4e1a5031a..8ce82b490 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,6 +1,6 @@ # AirScript docs -This crate contains source files and assets for [AirScript documentation](https://0xpolygonmiden.github.io/air-script/). +This crate contains source files and assets for [AirScript documentation](https://0xmiden.github.io/air-script/). All doc files are written in Markdown and are converted into an online book using the [mdBook](https://github.com/rust-lang/mdBook) utility. 
diff --git a/docs/book.toml b/docs/book.toml index b454844d5..64a0f7779 100644 --- a/docs/book.toml +++ b/docs/book.toml @@ -5,12 +5,15 @@ multilingual = false src = "src" title = "AirScript by Miden" -[output.html] -git-repository-url = "https://github.com/0xMiden/air-script/" -mathjax-support = true +[build] +build-dir = "target/book" [preprocessor.katex] after = ["links"] +[output.html] +git-repository-url = "https://github.com/0xMiden/air-script/" +mathjax-support = true + [output.linkcheck] warning-policy = "ignore" diff --git a/docs/examples/boundary_constraints_buses.air b/docs/examples/boundary_constraints_buses.air new file mode 100644 index 000000000..f92e6aa93 --- /dev/null +++ b/docs/examples/boundary_constraints_buses.air @@ -0,0 +1,29 @@ +def BoundaryConstraintsExample_WithBuses + +trace_columns { + main: [a, b], +} + +public_inputs { + stack_inputs: [16], + stack_outputs: [16], +} + +buses { + multiset p, + logup q, +} + +boundary_constraints { + enf a.first = stack_inputs[0]; + enf a.last = stack_outputs[0]; + + enf p.first = null; + enf p.last = null; + enf q.first = null; + enf q.last = null; +} + +integrity_constraints { + enf b' = b + 1; +} diff --git a/docs/examples/boundary_constraints_simple.air b/docs/examples/boundary_constraints_simple.air new file mode 100644 index 000000000..23e5403ea --- /dev/null +++ b/docs/examples/boundary_constraints_simple.air @@ -0,0 +1,19 @@ +def BoundaryConstraintsExample_Simple + +trace_columns { + main: [a], +} + +public_inputs { + stack_inputs: [16], + stack_outputs: [16], +} + +boundary_constraints { + enf a.first = 0; + enf a.last = 10; +} + +integrity_constraints { + enf a' = a + 1; +} diff --git a/docs/examples/boundary_constraints_variables.air b/docs/examples/boundary_constraints_variables.air new file mode 100644 index 000000000..d66f08ae5 --- /dev/null +++ b/docs/examples/boundary_constraints_variables.air @@ -0,0 +1,19 @@ +def BoundaryConstraintsExample_Variables + +trace_columns { + main: [p1], +} + 
+public_inputs { + stack_inputs: [16], +} + +boundary_constraints { + let x = 3 + let y = 4 + enf p1.first = x * y +} + +integrity_constraints { + enf p1' = p1 + 1; +} diff --git a/docs/examples/fibonacci.air b/docs/examples/fibonacci.air new file mode 100644 index 000000000..279327569 --- /dev/null +++ b/docs/examples/fibonacci.air @@ -0,0 +1,23 @@ +def Fibonacci + +trace_columns { + main: [a, b], +} + +public_inputs { + inputs: [32], +} + +periodic_columns { + k0: [1, 1, 1, 1, 1, 1, 1, 0], +} + +boundary_constraints { + enf a.first = 0; + enf b.first = 1; +} + +integrity_constraints { + enf k0 * (a' - b) = 0; + enf k0 * (b' - a - b) = 0; +} diff --git a/docs/examples/integrity_constraints_buses.air b/docs/examples/integrity_constraints_buses.air new file mode 100644 index 000000000..55d50f798 --- /dev/null +++ b/docs/examples/integrity_constraints_buses.air @@ -0,0 +1,22 @@ +def IntegrityConstraintsExample_Buses + +trace_columns { + main: [a, s], +} + +public_inputs { + stack_inputs: [16], +} + +buses { + multiset p, +} + +boundary_constraints { + enf p.first = null; + enf p.last = null; +} + +integrity_constraints { + p.insert(a) when s; +} diff --git a/docs/examples/integrity_constraints_periodic.air b/docs/examples/integrity_constraints_periodic.air new file mode 100644 index 000000000..2feca8f2c --- /dev/null +++ b/docs/examples/integrity_constraints_periodic.air @@ -0,0 +1,22 @@ +def IntegrityConstraintsExample_Periodic + +trace_columns { + main: [a, b], +} + +public_inputs { + stack_inputs: [16], +} + +periodic_columns { + k: [1, 1, 1, 0], +} + +boundary_constraints { + enf a.first = stack_inputs[0]; + enf b.first = 0; +} + +integrity_constraints { + enf a' = k * a; +} diff --git a/docs/examples/integrity_constraints_simple.air b/docs/examples/integrity_constraints_simple.air new file mode 100644 index 000000000..6ad3ff13d --- /dev/null +++ b/docs/examples/integrity_constraints_simple.air @@ -0,0 +1,20 @@ +def IntegrityConstraintsExample_Simple + 
+trace_columns { + main: [a, b], +} + +public_inputs { + stack_inputs: [16], + stack_outputs: [16], +} + +boundary_constraints { + enf a.first = stack_inputs[0]; + enf b.first = stack_inputs[1]; +} + +integrity_constraints { + enf a' = a + 1; + enf b' - b - 1 = 0; +} diff --git a/docs/examples/integrity_constraints_variables.air b/docs/examples/integrity_constraints_variables.air new file mode 100644 index 000000000..5c8b8e0d8 --- /dev/null +++ b/docs/examples/integrity_constraints_variables.air @@ -0,0 +1,24 @@ +def IntegrityConstraintsExample_Variables + +trace_columns { + main: [a, b], +} + +public_inputs { + stack_inputs: [16], +} + +periodic_columns { + k: [1, 1, 1, 0] +} + +boundary_constraints { + enf a.first = stack_inputs[0]; + enf b.first = 0; +} + +integrity_constraints { + let x = a + 2 + let y = b + 5 + enf b' = k * x * y +} diff --git a/docs/examples/list_comprehension_example.air b/docs/examples/list_comprehension_example.air new file mode 100644 index 000000000..7d2f5fdfb --- /dev/null +++ b/docs/examples/list_comprehension_example.air @@ -0,0 +1,22 @@ +def ListComprehensionExample + +const Y = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + +const X = [1, 2]; + +trace_columns { + main: [result], +} + +public_inputs { + inputs: [32], +} + +boundary_constraints { + enf result.first = inputs[0]; + enf result.last = inputs[0]; +} + +integrity_constraints { + enf result' = result + 1; +} diff --git a/docs/examples/simple_addition.air b/docs/examples/simple_addition.air new file mode 100644 index 000000000..dc21a94b9 --- /dev/null +++ b/docs/examples/simple_addition.air @@ -0,0 +1,31 @@ +def SimpleAddition + +trace_columns { + main: [a, b, c], +} + +public_inputs { + inputs: [2], + outputs: [1], +} + +periodic_columns { + k0: [1, 1, 1, 1, 1, 1, 1, 0], +} + +boundary_constraints { + enf a.first = inputs[0]; + enf b.first = inputs[1]; + enf c.first = 0; + + enf a.last = outputs[0]; + enf b.last = outputs[0]; + enf c.last = 0; +} + +integrity_constraints { + enf k0 * 
(a' - a) = 0; + enf k0 * (b' - b) = 0; + + enf c' = c + a + b; +} diff --git a/docs/examples/variables_example.air b/docs/examples/variables_example.air new file mode 100644 index 000000000..dbc6099b8 --- /dev/null +++ b/docs/examples/variables_example.air @@ -0,0 +1,26 @@ +def VariablesExample + +const A = 1; +const B = 2; + +trace_columns { + main: [a, b, c, d], +} + +public_inputs { + stack_inputs: [16], +} + +boundary_constraints { + let x = stack_inputs[0] + stack_inputs[1]; + let y = [stack_inputs[2], stack_inputs[3]]; + enf a.first = x + y[0] + y[1]; +} + +integrity_constraints { + let z = [ + [a + b, c + d], + [A * a, B * b] + ]; + enf a' = z[0][0] + z[0][1] + z[1][0] + z[1][1]; +} diff --git a/docs/src/backends.md b/docs/src/backends.md index b6e686e7b..ce8b023e7 100644 --- a/docs/src/backends.md +++ b/docs/src/backends.md @@ -2,11 +2,13 @@ AirScript currently comes bundled with two backends: - [Winterfell backend](https://github.com/0xMiden/air-script/tree/main/codegen/winterfell) which outputs `Air` trait implementation for the [Winterfell prover](https://github.com/facebook/winterfell) (Rust). +- [ACE backend](https://github.com/0xMiden/air-script/tree/main/codegen/ace) which outputs arithmetic circuits for Miden VM's ACE (Arithmetic Circuit Evaluation) chiplet for recursive STARK proof verification. -These backends can be used programmatically as crates. They can also be used via AirScript CLI by specifying `--target` flag. +These backends can be used programmatically as crates. -For example, the following will output Winterfell `Air` trait implementation for AIR constraints described in `example.air` file: -``` +The Winterfell backend can also be used via AirScript CLI by specifying `--target` flag. 
For example, the following will output Winterfell `Air` trait implementation for AIR constraints described in `example.air` file: +```bash +# Make sure to run from the project root directory ./target/release/airc transpile examples/example.air --target winterfell ``` In both cases we assumed that the CLI has been compiled as described [here](./introduction.md#cli). diff --git a/docs/src/description/buses.md b/docs/src/description/buses.md index 837433c8f..0cbf6718e 100644 --- a/docs/src/description/buses.md +++ b/docs/src/description/buses.md @@ -5,18 +5,18 @@ A bus is a construct which aims to simplify description of non-local constraints ## Bus types - Multiset (`multiset`): Multiset-based buses can represent constraints which specify values that must have been inserted or removed from a column, in no particular order. -[Miden VM - Multiset Checks](https://0xpolygonmiden.github.io/miden-vm/design/lookups/multiset.html) +[Miden VM - Multiset Checks](https://0xmiden.github.io/miden-vm/design/lookups/multiset.html) [Incremental Multiset Hash Functions and Their Application to Memory Integrity Checking - Clarke et al. MIT CSAIL (2018)](https://people.csail.mit.edu/devadas/pubs/mhashes.pdf) - LogUp (`logup`): LogUp-based buses are more complex than multiset buses, and can encode the multiplicity of an element: an element can be inserted or removed multiple times. -[Miden VM - LogUp: multivariate lookups with logarithmic derivatives](https://0xpolygonmiden.github.io/miden-vm/design/lookups/logup.html) +[Miden VM - LogUp: multivariate lookups with logarithmic derivatives](https://0xmiden.github.io/miden-vm/design/lookups/logup.html) [Multivariate lookups based on logarithmic derivatives - Ulrich Haböck, Orbis Labs, Polygon Labs](https://eprint.iacr.org/2022/1530) ## Defining buses See the [declaring buses](./declarations.md#buses) for more details. 
-``` +```air buses { multiset p, logup q, @@ -27,7 +27,7 @@ buses { In the boundary constraints section, we can constrain the initial and final state of the bus. Currently, only constraining a bus to be empty (with the `null` keyword) is supported. -``` +```air boundary_constraints { enf p.first = null; enf p.last = null; @@ -38,9 +38,9 @@ The above example states that the bus `p` should be empty at the beginning and e ## Bus integrity constraints -In the integrity constraints section, we can insert and remove elements (as tuples of felts) into and from a bus. In the following examples, `p` and `q` are respectively multiset and LogUp based buses. +In the integrity constraints section or in evaluators, we can insert and remove elements (as tuples of felts) into and from a bus. In the following examples, `p` and `q` are respectively multiset and LogUp based buses. -``` +```air integrity_constraints { p.insert(a) when s1; p.remove(a, b) when 1 - s2; @@ -55,7 +55,7 @@ $$ p ′ \cdot ( ( \alpha_0 + \alpha_1 \cdot a + \alpha_2 \cdot b ) \cdot ( 1 − s2 ) + s2 ) = p \cdot ( ( \alpha_0 + \alpha_1 \cdot a ) \cdot s1 + 1 − s1 )) $$ -``` +```air integrity_constraints { q.remove(e, f, g) when s; q.insert(a, b, c) with d; diff --git a/docs/src/description/constraints.md b/docs/src/description/constraints.md index d845002db..0f531568e 100644 --- a/docs/src/description/constraints.md +++ b/docs/src/description/constraints.md @@ -18,26 +18,7 @@ A boundary constraint definition must: The following is a simple example of a valid `boundary_constraints` source section: -``` -def BoundaryConstraintsExample - -trace_columns { - main: [a], -} - -public_inputs { - -} - -boundary_constraints { - # these are main constraints. 
- enf a.first = 0; - enf a.last = 10; -} - -integrity_constraints { - -} +```{{#include ../../examples/boundary_constraints_simple.air}} ``` ### Public inputs @@ -50,38 +31,7 @@ To use public inputs, the public input must be declared in the `public_inputs` s The following is an example of a valid bus `boundary_constraints` source section that uses public inputs: -``` -def BoundaryConstraintsExample - -trace_columns { - main: [a, b], -} - -public_inputs { - stack_inputs: [16], - stack_outputs: [16], -} - -buses { - multiset p, - logup q, -} - -boundary_constraints { - # these are main constraints that use public input values. - enf a.first = stack_inputs[0]; - enf a.last = stack_outputs[0]; - - # these are bus constraints that specify that buses must be empty at the beginning and the end of the execution trace - enf p.first = null; - enf p.last = null; - enf q.first = null; - enf q.last = null; -} - -integrity_constraints { - -} +```{{#include ../../examples/boundary_constraints_buses.air}} ``` ### Intermediate variables @@ -92,27 +42,7 @@ Boundary constraints can use intermediate variables to express more complex cons The following is an example of a valid `boundary_constraints` source section that uses intermediate variables: -``` -def BoundaryConstraintsExample - -trace_columns { - main: [a, b], -} - -public_inputs { - -} - -boundary_constraints { - # this is a constraint that uses intermediate variables. 
- let x = 3 - let y = 4 - enf p1.first = x * y -} - -integrity_constraints { - -} +```{{#include ../../examples/boundary_constraints_variables.air}} ``` ## Integrity constraints (`integrity_constraints`) @@ -135,26 +65,7 @@ Integrity constraints have access to values in the "current" row of the trace to The following is a simple example of a valid `integrity_constraints` source section using values from the current and next rows of the main trace: -``` -def IntegrityConstraintsExample - -trace_columns { - main: [a, b], -} - -public_inputs { - -} - -boundary_constraints { - -} - -integrity_constraints { - # these are main constraints. they both express the same constraint. - enf a' = a + 1; - enf b' - b - 1 = 0; -} +```{{#include ../../examples/integrity_constraints_simple.air}} ``` ### Periodic columns @@ -167,29 +78,7 @@ To use periodic column values, the periodic column must be declared in the `peri The following is an example of a valid `integrity_constraints` source section that uses periodic columns: -``` -def IntegrityConstraintsExample - -trace_columns { - main: [a, b], -} - -public_inputs { - -} - -periodic_columns { - k: [1, 1, 1, 0], -} - -boundary_constraints { - -} - -integrity_constraints { - # this is a main constraint that uses a periodic column. 
- enf a' = k * a; -} +```{{#include ../../examples/integrity_constraints_periodic.air}} ``` ### Buses @@ -200,29 +89,7 @@ Integrity constraints can constrain insertions and removal of elements into / fr The following is an example of a valid `integrity_constraints` source section that uses buses: -``` -def IntegrityConstraintsExample - -trace_columns { - main: [a, s], -} - -public_inputs { - -} - -buses { - multiset p, -} - -boundary_constraints { - -} - -integrity_constraints { - # this is a bus constraint, inserting a into the bus p while s = 1 - p.insert(a) when s; -} +```{{#include ../../examples/integrity_constraints_buses.air}} ``` ### Intermediate variables @@ -233,29 +100,5 @@ Integrity constraints can use intermediate variables to express more complex con The following is an example of a valid `integrity_constraints` source section that uses intermediate variables: -``` -def IntegrityConstraintsExample - -trace_columns { - main: [a, b], -} - -public_inputs { - -} - -periodic_columns { - k: [1, 1, 1, 0] -} - -boundary_constraints { - -} - -integrity_constraints { - # this is a main constraint that uses intermediate variables. - let x = a + 2 - let y = b + 5 - enf b' = k * x * y -} +```{{#include ../../examples/integrity_constraints_variables.air}} ``` diff --git a/docs/src/description/convenience.md b/docs/src/description/convenience.md index 736c9a156..c7668f6a4 100644 --- a/docs/src/description/convenience.md +++ b/docs/src/description/convenience.md @@ -6,44 +6,38 @@ To make writing constraints easier, AirScript provides a number of syntactic con List comprehension provides a simple way to create new vectors. It is similar to the list comprehension syntax in Python. The following examples show how to use list comprehension in AirScript. -``` +```air let x = [a * 2 for a in b]; ``` This will create a new vector with the same length as `b` and the value of each element will be twice that of the corresponding element in `b`. 
-``` +```air let x = [a + b for (a, b) in (c, d)]; ``` This will create a new vector with the same length as `c` and `d` and the value of each element will be the sum of the corresponding elements in `c` and `d`. This will throw an error if `c` and `d` vectors are of unequal lengths. -``` +```air let x = [2^i * a for (i, a) in (0..5, b)]; ``` Ranges can also be used as iterables, which makes it easy to refer to an element and its index at the same time. This will create a new vector with length 5 and each element will be the corresponding element in `b` multiplied by 2 raised to the power of the element's index. This will throw an error if `b` is not of length 5. -``` +```air const MAX = 5; let x = [2^i * a for (i, a) in (0..MAX, b)]; ``` Ranges are defined with each bound being either an integer literal, or a [named constant](./declarations.md#constants-const) of type scalar. -``` +```air let x = [m + n + o for (m, n, o) in (a, 0..5, c[0..5])]; ``` Slices can also be used as iterables. This will create a new vector with length 5 and each element will be the sum of the corresponding elements in `a`, the range 0 to 5, and the first 5 elements of `c`. This will throw an error if `a` is not of length 5 or if `c` is of length less than 5. -``` -const Y = [ - [1, 2], - [3, 4], - [5, 6] -]; -const X = [1, 2]; -let c = [sum([x * y for (x, y) in (X, row_y)]) for row_y in Y]; +```air +{{#include ../../examples/list_comprehension_example.air}} ``` List comprehensions can be nested. The above example creates a new vector `c` where each element is the sum of the products of corresponding elements in `X` and each row of `Y`. The outer list comprehension iterates over each row of `Y`, while the inner list comprehension iterates over each element in `X` and the current row of `Y`. 
-``` +```ignore # The following will result in a parsing error: let c = [[x * y for (x, y) in (X, row_y)] for row_y in Y]; ``` @@ -53,7 +47,7 @@ List comprehensions can only have a scalar expression as their body. This means List folding provides syntactic convenience for folding vectors into expressions. It is similar to the list folding syntax in Python. List folding can be applied to vectors, list comprehension or identifiers referring to vectors and list comprehension. The following examples show how to use list folding in AirScript. -``` +```air trace_columns { main: [a[5], b, c], } @@ -67,7 +61,7 @@ integrity_constraints { In the above, `x` and `y` both represent the sum of all trace column values in the trace column group `a`. `z` represents the sum of all trace column values in the trace column group `a` multiplied by `2`. -``` +```air trace_columns { main: [a[5], b, c], } @@ -84,7 +78,7 @@ In the above, `x` and `y` both represent the product of all trace column values ## Constraint comprehension Constraint comprehension provides a way to enforce the same constraint on multiple values. Conceptually, it is very similar to the list comprehension described above. For example: -``` +```air trace_columns { main: [a[5], b, c], } @@ -94,7 +88,7 @@ integrity_constraints { } ``` The above will enforce $a_i^2 = a_i$ constraint for all columns in the trace column group `a`. Semantically, this is equivalent to: -``` +```air trace_columns { main: [a[5], b, c], } @@ -109,7 +103,7 @@ integrity_constraints { ``` Similar to list comprehension, constraints in constraint comprehension can involve values from multiple lists. For example: -``` +```air trace_columns { main: [a[5], b[5]], } @@ -123,7 +117,7 @@ The above will enforce that $a_i' = i \cdot b_i$ for $i \in [0, 5)$. If the leng ## Conditional constraints Frequently, we may want to enforce constraints based on some selectors. 
For example, let's say our trace has 4 columns: `a`, `b`, `c`, and `s`, and we want to enforce that $c' = a + b$ when $s = 1$ and $c' = a \cdot c$ when $s = 0$. We can write these constraints directly like so: -``` +```air trace_columns { main: [a, b, c, s], } @@ -136,7 +130,7 @@ integrity_constraints { Notice that we also need to enforce $s^2 = s$ to ensure that column $s$ can contain only binary values. While the above approach works, it gets more and more difficult to manage as selectors and constraints get more complicated. To simplify describing constraints for this use case, AirScript introduces `enf match` statement. The above constraints can be described using `enf match` statement as follows: -``` +```air trace_columns { main: [a, b, c, s], } @@ -156,7 +150,7 @@ In the above, the syntax of each "option" is `case : proc_macro2::TokenStream { let enum_wrapper = get_enum_wrapper(input); let name = &input.ident; - let (yes_name, no_name) = ( - format_ident!("{}BuilderYes", name), - format_ident!("{}BuilderNo", name), - ); + let (yes_name, no_name) = + (format_ident!("{}BuilderYes", name), format_ident!("{}BuilderNo", name)); let (fields, hidden_fields) = extract_fields(input); let (builder_struct_fields, builder_states, transition_table) = make_builder_struct_fields(&fields); @@ -81,7 +79,7 @@ fn extract_fields(data: &syn::DeriveInput) -> (Vec<(&syn::Ident, &syn::Type)>, V "_" => { hidden_fields.push(ident); None - } + }, _ => Some((ident, &field.ty)), } }) @@ -102,7 +100,7 @@ fn next_ty(ty: &syn::PathSegment) -> Option { }) => match args.first().unwrap() { syn::GenericArgument::Type(syn::Type::Path(syn::TypePath { path, .. 
})) => { Some(path.segments.first().unwrap().clone()) - } + }, _ => None, }, syn::PathArguments::None => Some(syn::PathSegment { @@ -291,11 +289,8 @@ fn make_builder_struct_fields<'a>(fields: &[(&'a syn::Ident, &'a syn::Type)]) -> let initial_state = initial_state.clone(); let states = make_states(&initial_state); - let reverse_states: HashMap, usize> = states - .iter() - .enumerate() - .map(|(i, state)| (state.clone(), i)) - .collect(); + let reverse_states: HashMap, usize> = + states.iter().enumerate().map(|(i, state)| (state.clone(), i)).collect(); let mut transition_table: Vec> = vec![]; for state in states.iter() { let mut row = vec![]; @@ -427,7 +422,7 @@ fn make_builder_struct<'a>( )], ) -> proc_macro2::TokenStream { let builder_struct_name = format_ident!("{}Builder", name); - let struct_fields = fields.iter().map(|(_, _, field, _, _, _)| field); + let struct_fields = fields.iter().map(|(_, _, field, ..)| field); let builder_struct = quote! { #[derive(Debug)] pub struct #builder_struct_name { @@ -521,10 +516,7 @@ fn make_builder_impls<'a>( let (ret, body_ret) = if next_state_name == state_name { (quote! { Self }, quote! { self }) } else { - ( - quote! { #next_state_name }, - quote! { unsafe { std::mem::transmute(self) } }, - ) + (quote! { #next_state_name }, quote! { unsafe { std::mem::transmute(self) } }) }; quote! { pub fn #ident(#arg) -> #ret { @@ -535,16 +527,8 @@ fn make_builder_impls<'a>( }) .collect::>(); if i == states.len() - 1 { - let builder = fields - .iter() - .map(|(_, _, _, _, _, builder)| builder) - .collect::>(); - methods.push(make_build_method( - name, - &builder, - enum_wrapper, - hidden_fields, - )); + let builder = fields.iter().map(|(_, _, _, _, _, builder)| builder).collect::>(); + methods.push(make_build_method(name, &builder, enum_wrapper, hidden_fields)); }; quote! 
{ impl #state_name { @@ -552,7 +536,7 @@ fn make_builder_impls<'a>( } } }); - let field_names = fields.iter().map(|(ident, _, _, _, _, _)| ident); + let field_names = fields.iter().map(|(ident, ..)| ident); quote! { impl Default for #empty_state { fn default() -> Self { @@ -572,11 +556,10 @@ fn make_build_method( enum_wrapper: &EnumWrapper, hidden_fields: &[&syn::Ident], ) -> proc_macro2::TokenStream { - let fields = builders.iter().map(|builder| quote! { #builder }).chain( - hidden_fields - .iter() - .map(|field| quote! { #field: Default::default() }), - ); + let fields = builders + .iter() + .map(|builder| quote! { #builder }) + .chain(hidden_fields.iter().map(|field| quote! { #field: Default::default() })); match enum_wrapper { EnumWrapper::Op => quote! { @@ -602,17 +585,15 @@ fn make_build_method( #[cfg(test)] mod tests { - use super::*; - use crate::helpers::fmt; use pretty_assertions::assert_eq; use syn::parse2; + use super::*; + use crate::helpers::fmt; + #[test] fn test_derive_builder() { - let (y, n) = ( - format_ident!("FooBuilderYes"), - format_ident!("FooBuilderNo"), - ); + let (y, n) = (format_ident!("FooBuilderYes"), format_ident!("FooBuilderNo")); let input = quote! { #[derive(Builder)] #[enum_wrapper(Op)] diff --git a/mir/derive-ir/src/lib.rs b/mir/derive-ir/src/lib.rs index f9f571555..090241f8a 100644 --- a/mir/derive-ir/src/lib.rs +++ b/mir/derive-ir/src/lib.rs @@ -5,16 +5,15 @@ use builder::impl_builder; use proc_macro::TokenStream; use syn::{DeriveInput, parse_macro_input}; -/// /// Derive the [Builder] trait for a struct. +/// /// Generates a type-level state machine for transitioning between states. /// -/// States correspond to which fields have been set or not. -/// It takes into account the type of the fields. -/// The following types are treated as optional fields: -/// - [BackLink] -/// - [Vec] -/// - [Link>] +/// States correspond to which fields have been set or not. It takes into account the type of the +/// fields. 
The following types are treated as optional fields: +/// - `BackLink` +/// - `Vec` +/// - `Link>` /// /// For example, given the following struct: /// ```ignore @@ -34,8 +33,7 @@ use syn::{DeriveInput, parse_macro_input}; /// - `a: BackLink` /// - `b: Vec>` /// - `e: Vec>` -/// - `f: Link>>` -/// and the required fields: +/// - `f: Link>>` and the required fields: /// - `c: i32` /// - `d: Link` /// @@ -59,10 +57,10 @@ use syn::{DeriveInput, parse_macro_input}; /// ]; /// ``` /// -/// We generate a transition table for each state, -/// which maps from the current state to the next state. -/// rows are indexed by the current state, -/// columns by the method thst transitions to the next state. +/// We generate a transition table for each state, which maps from the current state to the next +/// state. Rows are indexed by the current state, columns by the method that transitions to the next +/// state. +/// /// ```ignore /// let transitions = [ /// /* 0: */ [0, 0, 1, 2, 0, 0], @@ -72,14 +70,14 @@ use syn::{DeriveInput, parse_macro_input}; /// ]; /// ``` /// -/// We then generate an implementation for each state, -/// with a method for each field., which transitions to the next state. -/// The [enum_wrapper] attribute is used to automatically wrap the struct in a Link -/// to reduce boilerplate. -/// The only supported [enum_wrapper]s are `Op` and `Root`. +/// We then generate an implementation for each state, with a method for each field, which +/// transitions to the next state. The `enum_wrapper` attribute is used to automatically wrap the +/// struct in a `Link` to reduce boilerplate. The only supported `enum_wrapper`s are +/// `Op` and `Root`.
/// /// The following API is generated: -/// ```ignore +/// +/// ```text /// let a: Link = todo!(); /// let b0: Link = todo!(); /// let b1: Link = todo!(); @@ -110,7 +108,8 @@ use syn::{DeriveInput, parse_macro_input}; /// e: vec![e0, e1], /// f: Link::new(vec![f0, f1]), /// } -/// ); +/// ) +/// ); /// ``` #[proc_macro_derive(Builder, attributes(enum_wrapper))] pub fn derive_builder(input: TokenStream) -> TokenStream { diff --git a/mir/src/codegen.rs b/mir/src/codegen.rs deleted file mode 100644 index ee7a3d876..000000000 --- a/mir/src/codegen.rs +++ /dev/null @@ -1,8 +0,0 @@ -/// This trait should be implemented on types which handle generating code from AirScript MIR -pub trait CodeGenerator { - /// The type of the artifact produced by this codegen backend - type Output; - - /// Generates code using this generator, consuming it in the process - fn generate(&self, ir: &crate::ir::Mir) -> anyhow::Result; -} diff --git a/mir/src/ir/bus.rs b/mir/src/ir/bus.rs index c965176ab..57822c7f3 100644 --- a/mir/src/ir/bus.rs +++ b/mir/src/ir/bus.rs @@ -1,7 +1,6 @@ use std::ops::Deref; use air_parser::ast::{self, Identifier}; - use miden_diagnostics::{SourceSpan, Spanned}; use crate::{ @@ -12,7 +11,7 @@ use crate::{ /// A Mir struct to represent a Bus definition /// we have 2 cases: /// -/// - [BusType::Multiset]: multiset check +/// - BusType::Multiset: multiset check /// /// these constraints: /// ```air @@ -25,17 +24,17 @@ use crate::{ /// ``` /// with this bus definition: /// ```ignore -/// Bus { +/// struct Bus { /// bus_type: BusType::Multiset, /// columns: [a, b, c, d], /// latches: [s, 1 - s], /// } /// ``` /// with: -/// a, b, c, d, s being [Link] in the graph -/// s, 1 - s being [Link] representing booleans in the graph +/// a, b, c, d, s being [`Link`] in the graph +/// s, 1 - s being [`Link`] representing booleans in the graph /// -/// - [BusType::Logup]: LogUp bus +/// - BusType::Logup: LogUp bus /// /// these constraints: /// ```air @@ -55,8 +54,8 @@ use 
crate::{ /// } /// ``` /// with: -/// a, b, c, e, f, g being [Link] in the graph -/// d, s being [Link], s is boolean, d is a number. +/// a, b, c, e, f, g being [`Link`] in the graph +/// d, s being [`Link`], s is boolean, d is a number. #[derive(Default, Clone, Eq, Debug, Spanned)] pub struct Bus { /// Identifier of the bus diff --git a/mir/src/ir/graph.rs b/mir/src/ir/graph.rs index b3149d761..584ef862b 100644 --- a/mir/src/ir/graph.rs +++ b/mir/src/ir/graph.rs @@ -1,4 +1,3 @@ -use crate::{CompileError, ir}; use std::{ cell::{Ref, RefMut}, collections::BTreeMap, @@ -6,6 +5,8 @@ use std::{ use air_parser::ast::QualifiedIdentifier; +use crate::{CompileError, ir}; + /// The constraints graph for the Mir. /// /// We store constraints (boundary and integrity), as well as function and evaluator definitions. @@ -41,7 +42,7 @@ impl Graph { } else { Err(CompileError::Failed) } - } + }, } } @@ -62,9 +63,7 @@ impl Graph { ident: &QualifiedIdentifier, ) -> Option> { // Unwrap is safe as we ensure the type is correct before inserting - self.functions - .get_mut(ident) - .map(|n| n.as_function_mut().unwrap()) + self.functions.get_mut(ident).map(|n| n.as_function_mut().unwrap()) } /// Queries all function nodes @@ -72,8 +71,8 @@ impl Graph { self.functions.values().cloned().collect() } - /// Inserts an evaluator into the graph, returning an error if the root is not an [ir::Evaluator], - /// or if the evaluator already exists (declaration conflict). + /// Inserts an evaluator into the graph, returning an error if the root is not an + /// [ir::Evaluator], or if the evaluator already exists (declaration conflict). 
pub fn insert_evaluator( &mut self, ident: QualifiedIdentifier, @@ -90,7 +89,7 @@ impl Graph { } else { Err(CompileError::Failed) } - } + }, } } @@ -102,9 +101,7 @@ impl Graph { /// Queries a given evaluator as a mutable [ir::Evaluator] pub fn get_evaluator(&self, ident: &QualifiedIdentifier) -> Option> { // Unwrap is safe as we ensure the type is correct before inserting - self.evaluators - .get(ident) - .map(|n| n.as_evaluator().unwrap()) + self.evaluators.get(ident).map(|n| n.as_evaluator().unwrap()) } /// Queries a given evaluator as a mutable [ir::Evaluator] @@ -113,9 +110,7 @@ impl Graph { ident: &QualifiedIdentifier, ) -> Option> { // Unwrap is safe as we ensure the type is correct before inserting - self.evaluators - .get_mut(ident) - .map(|n| n.as_evaluator_mut().unwrap()) + self.evaluators.get_mut(ident).map(|n| n.as_evaluator_mut().unwrap()) } /// Queries all evaluator nodes @@ -126,33 +121,25 @@ impl Graph { /// Inserts a boundary constraint into the graph, if it does not already exist. pub fn insert_boundary_constraints_root(&mut self, root: ir::Link) { if !self.boundary_constraints_roots.borrow().contains(&root) { - self.boundary_constraints_roots - .borrow_mut() - .push(root.clone()); + self.boundary_constraints_roots.borrow_mut().push(root.clone()); } } /// Removes a boundary constraint from the graph. pub fn remove_boundary_constraints_root(&mut self, root: ir::Link) { - self.boundary_constraints_roots - .borrow_mut() - .retain(|n| *n != root); + self.boundary_constraints_roots.borrow_mut().retain(|n| *n != root); } /// Inserts an integrity constraint into the graph, if it does not already exist. pub fn insert_integrity_constraints_root(&mut self, root: ir::Link) { if !self.integrity_constraints_roots.borrow().contains(&root) { - self.integrity_constraints_roots - .borrow_mut() - .push(root.clone()); + self.integrity_constraints_roots.borrow_mut().push(root.clone()); } } /// Removes an integrity constraint from the graph. 
pub fn remove_integrity_constraints_root(&mut self, root: ir::Link) { - self.boundary_constraints_roots - .borrow_mut() - .retain(|n| *n != root); + self.integrity_constraints_roots.borrow_mut().retain(|n| *n != root); } /// Inserts a bus into the graph, returning an error @@ -162,9 +149,7 @@ impl Graph { ident: QualifiedIdentifier, bus: ir::Link, ) -> Result<(), CompileError> { - self.buses - .insert(ident, bus) - .map_or(Ok(()), |_| Err(CompileError::Failed)) + self.buses.insert(ident, bus).map_or(Ok(()), |_| Err(CompileError::Failed)) } /// Queries a given bus, returning a [ir::Link] if it exists. diff --git a/mir/src/ir/link.rs b/mir/src/ir/link.rs index 04eb896b2..db1cab648 100644 --- a/mir/src/ir/link.rs +++ b/mir/src/ir/link.rs @@ -1,8 +1,11 @@ +use std::{ + cell::RefCell, + fmt::Debug, + hash::Hash, + rc::{Rc, Weak}, +}; + use miden_diagnostics::{SourceSpan, Spanned}; -use std::cell::RefCell; -use std::fmt::Debug; -use std::hash::Hash; -use std::rc::{Rc, Weak}; /// A wrapper around a `Rc>` to allow custom trait implementations. pub struct Link @@ -14,9 +17,7 @@ where impl Link { pub fn new(data: T) -> Self { - Self { - link: Rc::new(RefCell::new(data)), - } + Self { link: Rc::new(RefCell::new(data)) } } /// Returns a `std::cell::Ref` to the inner value. pub fn borrow(&self) -> std::cell::Ref<'_, T> { @@ -58,9 +59,7 @@ where impl Clone for Link { fn clone(&self) -> Self { - Self { - link: self.link.clone(), - } + Self { link: self.link.clone() } } } @@ -149,9 +148,7 @@ impl Debug for BackLink { impl Clone for BackLink { fn clone(&self) -> Self { - Self { - link: self.link.clone(), - } + Self { link: self.link.clone() } } } @@ -167,18 +164,14 @@ impl Eq for BackLink {} /// Converts a `Link` into a `BackLink` by downgrading the strong reference. impl From> for BackLink { fn from(parent: Link) -> Self { - Self { - link: Some(Rc::downgrade(&parent.link)), - } + Self { link: Some(Rc::downgrade(&parent.link)) } } } /// Converts a `Rc>` into a `BackLink`. 
impl From>> for BackLink { fn from(parent: Rc>) -> Self { - Self { - link: Some(Rc::downgrade(&parent)), - } + Self { link: Some(Rc::downgrade(&parent)) } } } @@ -198,7 +191,7 @@ where } } -/// A wrapper around a [Link] to block recursive implementations of [PartialEq] and [Hash]. +/// A wrapper around a [`Link`] to block recursive implementations of [PartialEq] and [Hash]. /// A [Singleton] is used when the following properties are desired: /// - The reference count of the field needs to be kept at >1 once instantiated. /// - The field should be ignored in comparisons and hashing. diff --git a/mir/src/ir/mir.rs b/mir/src/ir/mir.rs index 211204140..a1bb8a45c 100644 --- a/mir/src/ir/mir.rs +++ b/mir/src/ir/mir.rs @@ -1,11 +1,10 @@ +use std::collections::BTreeMap; + use air_parser::ast::TraceSegment; pub use air_parser::{ Symbol, ast::{Identifier, PeriodicColumn, PublicInput, QualifiedIdentifier}, }; - -use std::collections::BTreeMap; - use miden_diagnostics::{SourceSpan, Spanned}; use super::Graph; @@ -41,10 +40,7 @@ pub struct Mir { } impl Default for Mir { fn default() -> Self { - Self::new(Identifier::new( - SourceSpan::UNKNOWN, - Symbol::intern("unnamed"), - )) + Self::new(Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("unnamed"))) } } impl Mir { @@ -71,13 +67,13 @@ impl Mir { self.name.as_str() } - /// Return a reference to the raw [AlgebraicGraph] corresponding to the constraints + /// Return a reference to the raw AlgebraicGraph corresponding to the constraints #[inline] pub fn constraint_graph(&self) -> &Graph { &self.graph } - /// Return a mutable reference to the raw [AlgebraicGraph] corresponding to the constraints + /// Return a mutable reference to the raw AlgebraicGraph corresponding to the constraints #[inline] pub fn constraint_graph_mut(&mut self) -> &mut Graph { &mut self.graph diff --git a/mir/src/ir/mod.rs b/mir/src/ir/mod.rs index 98144b5d5..a23623eb2 100644 --- a/mir/src/ir/mod.rs +++ b/mir/src/ir/mod.rs @@ -5,6 +5,7 @@ mod mir; mod node; 
mod nodes; mod owner; +mod quad_eval; mod utils; pub extern crate derive_ir; @@ -16,6 +17,9 @@ pub use mir::Mir; pub use node::Node; pub use nodes::*; pub use owner::Owner; +pub use quad_eval::{ + QuadFelt, RandomInputs, const_quad_felt, query_indexed_eval, query_mapped_eval, +}; pub use utils::*; /// A trait for nodes that can have children /// This is used with the Child trait to allow for easy traversal and manipulation of the graph @@ -101,7 +105,8 @@ where } /// A trait implemented by all nodes. -/// Will be derivable later. The implementation and type-safe builder is currently manual while we tweak the design +/// Will be derivable later. The implementation and type-safe builder is currently manual while we +/// tweak the design pub trait Builder { type Empty; type Full; diff --git a/mir/src/ir/node.rs b/mir/src/ir/node.rs index 294f728fd..9bf47a696 100644 --- a/mir/src/ir/node.rs +++ b/mir/src/ir/node.rs @@ -1,7 +1,8 @@ -use crate::ir::{BackLink, Child, Link, Op, Owner, Parent, Root}; +use std::ops::Deref; + use miden_diagnostics::{SourceSpan, Spanned}; -use std::ops::Deref; +use crate::ir::{BackLink, Child, Link, Op, Owner, Parent, Root}; /// All the nodes that can be in the MIR Graph /// Combines all [Root] and [Op] variants @@ -11,8 +12,7 @@ use std::ops::Deref; /// so it can be updated to the correct variant when the inner struct is updated /// Note: The [None] variant is used to represent a [Node] that: /// - is not yet initialized -/// - no longer exists (due to its ref-count dropping to 0). -/// We refer to those as "stale" nodes. +/// - no longer exists (due to its ref-count dropping to 0). We refer to those as "stale" nodes. 
#[derive(Clone, Eq, Debug, Spanned)] pub enum Node { Function(BackLink), @@ -229,7 +229,8 @@ impl Link { Root::None(span) => Node::None(*span), }; } else { - unreachable!(); + // If the [Node] is stale, we set it to None + to_update = Node::None(self.span()); } *self.borrow_mut() = to_update; diff --git a/mir/src/ir/nodes/mod.rs b/mir/src/ir/nodes/mod.rs index c6f761bed..5ddf29733 100644 --- a/mir/src/ir/nodes/mod.rs +++ b/mir/src/ir/nodes/mod.rs @@ -3,18 +3,20 @@ mod ops; mod root; mod roots; +use std::cell::{Ref, RefMut}; + pub use op::Op; pub use ops::*; pub use root::Root; pub use roots::*; -use std::cell::{Ref, RefMut}; -/// Apply a getter function to a Ref and return a Ref if the getter succeeds +/// Apply a getter function to a `Ref` and return a `Ref` if the getter succeeds. pub fn get_inner(obj: Ref, getter: impl Fn(&T) -> Option<&U>) -> Option> { Ref::filter_map(obj, getter).ok() } -/// Apply a mutable getter function to a RefMut and return a RefMut if the getter succeeds +/// Apply a mutable getter function to a `RefMut` and return a `RefMut` if the getter +/// succeeds. 
pub fn get_inner_mut( obj: RefMut, getter: impl Fn(&mut T) -> Option<&mut U>, diff --git a/mir/src/ir/nodes/op.rs b/mir/src/ir/nodes/op.rs index 0ea312fc9..4dc7f4385 100644 --- a/mir/src/ir/nodes/op.rs +++ b/mir/src/ir/nodes/op.rs @@ -1,17 +1,19 @@ +use std::{ + cell::{Ref, RefMut}, + ops::{Deref, DerefMut}, +}; + +use miden_diagnostics::{SourceSpan, Spanned}; + use crate::ir::{ Accessor, Add, BackLink, Boundary, BusOp, Call, Child, ConstantValue, Enf, Exp, Fold, For, If, Link, Matrix, MirValue, Mul, Node, Owner, Parameter, Parent, Singleton, SpannedMirValue, Sub, Value, Vector, get_inner, get_inner_mut, }; -use miden_diagnostics::{SourceSpan, Spanned}; - -use std::{ - cell::{Ref, RefMut}, - ops::{Deref, DerefMut}, -}; -/// The combined [Op]s and leaves of the MIR Graph -/// These represent the operations that can be present in [Root] bodies +/// The combined [Op]s and leaves of the MIR Graph. +/// +/// These represent the operations that can be present in Root bodies /// The [Op] enum owns it's inner struct to allow conversion between variants #[derive(Clone, PartialEq, Eq, Debug, Hash, Spanned)] pub enum Op { @@ -106,7 +108,7 @@ impl Child for Op { Op::BusOp(b) => b.add_parent(parent), Op::Parameter(p) => p.add_parent(parent), Op::Value(v) => v.add_parent(parent), - Op::None(_) => {} + Op::None(_) => {}, } } fn remove_parent(&mut self, parent: Link) { @@ -127,7 +129,7 @@ impl Child for Op { Op::BusOp(b) => b.remove_parent(parent), Op::Parameter(p) => p.remove_parent(parent), Op::Value(v) => v.remove_parent(parent), - Op::None(_) => {} + Op::None(_) => {}, } } } @@ -165,11 +167,11 @@ impl Link { other_node.update(&self_node); other.update_inner_node(&self_node); - if let Some(other_owner) = other.as_owner() { - if let Some(self_owner) = self.as_owner() { - other_owner.update(&self_owner); - other.update_inner_owner(&self_owner); - } + if let Some(other_owner) = other.as_owner() + && let Some(self_owner) = self.as_owner() + { + other_owner.update(&self_owner); + 
other.update_inner_owner(&self_owner); } self.update(other); @@ -182,53 +184,53 @@ impl Link { match self.clone().borrow_mut().deref_mut() { Op::Enf(enf) => { enf._node = Singleton::from(node.clone()); - } + }, Op::Boundary(boundary) => { boundary._node = Singleton::from(node.clone()); - } + }, Op::Add(add) => { add._node = Singleton::from(node.clone()); - } + }, Op::Sub(sub) => { sub._node = Singleton::from(node.clone()); - } + }, Op::Mul(mul) => { mul._node = Singleton::from(node.clone()); - } + }, Op::Exp(exp) => { exp._node = Singleton::from(node.clone()); - } + }, Op::If(if_op) => { if_op._node = Singleton::from(node.clone()); - } + }, Op::For(for_op) => { for_op._node = Singleton::from(node.clone()); - } + }, Op::Call(call) => { call._node = Singleton::from(node.clone()); - } + }, Op::Fold(fold) => { fold._node = Singleton::from(node.clone()); - } + }, Op::Vector(vector) => { vector._node = Singleton::from(node.clone()); - } + }, Op::Matrix(matrix) => { matrix._node = Singleton::from(node.clone()); - } + }, Op::Accessor(accessor) => { accessor._node = Singleton::from(node.clone()); - } + }, Op::BusOp(bus_op) => { bus_op._node = Singleton::from(node.clone()); - } + }, Op::Parameter(parameter) => { parameter._node = Singleton::from(node.clone()); - } + }, Op::Value(value) => { value._node = Singleton::from(node.clone()); - } - Op::None(_) => {} + }, + Op::None(_) => {}, } } @@ -236,49 +238,49 @@ impl Link { match self.clone().borrow_mut().deref_mut() { Op::Enf(enf) => { enf._owner = Singleton::from(owner.clone()); - } + }, Op::Boundary(boundary) => { boundary._owner = Singleton::from(owner.clone()); - } + }, Op::Add(add) => { add._owner = Singleton::from(owner.clone()); - } + }, Op::Sub(sub) => { sub._owner = Singleton::from(owner.clone()); - } + }, Op::Mul(mul) => { mul._owner = Singleton::from(owner.clone()); - } + }, Op::Exp(exp) => { exp._owner = Singleton::from(owner.clone()); - } + }, Op::If(if_op) => { if_op._owner = Singleton::from(owner.clone()); - } + 
}, Op::For(for_op) => { for_op._owner = Singleton::from(owner.clone()); - } + }, Op::Call(call) => { call._owner = Singleton::from(owner.clone()); - } + }, Op::Fold(fold) => { fold._owner = Singleton::from(owner.clone()); - } + }, Op::Vector(vector) => { vector._owner = Singleton::from(owner.clone()); - } + }, Op::Matrix(matrix) => { matrix._owner = Singleton::from(owner.clone()); - } + }, Op::Accessor(accessor) => { accessor._owner = Singleton::from(owner.clone()); - } + }, Op::BusOp(bus_op) => { bus_op._owner = Singleton::from(owner.clone()); - } - Op::Parameter(_parameter) => {} - Op::Value(_value) => {} - Op::None(_) => {} + }, + Op::Parameter(_parameter) => {}, + Op::Value(_value) => {}, + Op::None(_) => {}, } } @@ -287,150 +289,102 @@ impl Link { pub fn as_node(&self) -> Link { let back: BackLink = self.clone().into(); match self.clone().borrow_mut().deref_mut() { - Op::Enf(Enf { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + Op::Enf(Enf { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Enf(enf) => { let node: Link = Node::Enf(back).into(); enf._node = Singleton::from(node.clone()); node - } - Op::Boundary(Boundary { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Boundary(Boundary { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Boundary(boundary) => { let node: Link = Node::Boundary(back).into(); boundary._node = Singleton::from(node.clone()); node - } - Op::Add(Add { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Add(Add { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Add(add) => { let node: Link = Node::Add(back).into(); add._node = Singleton::from(node.clone()); node - } - Op::Sub(Sub { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Sub(Sub { _node: Singleton(Some(link)), .. 
}) => link.clone(), Op::Sub(sub) => { let node: Link = Node::Sub(back).into(); sub._node = Singleton::from(node.clone()); node - } - Op::Mul(Mul { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Mul(Mul { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Mul(mul) => { let node: Link = Node::Mul(back).into(); mul._node = Singleton::from(node.clone()); node - } - Op::Exp(Exp { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Exp(Exp { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Exp(exp) => { let node: Link = Node::Exp(back).into(); exp._node = Singleton::from(node.clone()); node - } - Op::If(If { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::If(If { _node: Singleton(Some(link)), .. }) => link.clone(), Op::If(if_op) => { let node: Link = Node::If(back).into(); if_op._node = Singleton::from(node.clone()); node - } - Op::For(For { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::For(For { _node: Singleton(Some(link)), .. }) => link.clone(), Op::For(for_op) => { let node: Link = Node::For(back).into(); for_op._node = Singleton::from(node.clone()); node - } - Op::Call(Call { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Call(Call { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Call(call) => { let node: Link = Node::Call(back).into(); call._node = Singleton::from(node.clone()); node - } - Op::Fold(Fold { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Fold(Fold { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Fold(fold) => { let node: Link = Node::Fold(back).into(); fold._node = Singleton::from(node.clone()); node - } - Op::Vector(Vector { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Vector(Vector { _node: Singleton(Some(link)), .. 
}) => link.clone(), Op::Vector(vector) => { let node: Link = Node::Vector(back).into(); vector._node = Singleton::from(node.clone()); node - } - Op::Matrix(Matrix { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Matrix(Matrix { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Matrix(matrix) => { let node: Link = Node::Matrix(back).into(); matrix._node = Singleton::from(node.clone()); node - } - Op::Accessor(Accessor { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Accessor(Accessor { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Accessor(accessor) => { let node: Link = Node::Accessor(back).into(); accessor._node = Singleton::from(node.clone()); node - } - Op::BusOp(BusOp { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::BusOp(BusOp { _node: Singleton(Some(link)), .. }) => link.clone(), Op::BusOp(bus_op) => { let node: Link = Node::BusOp(back).into(); bus_op._node = Singleton::from(node.clone()); node - } - Op::Parameter(Parameter { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Parameter(Parameter { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Parameter(parameter) => { let node: Link = Node::Parameter(back).into(); parameter._node = Singleton::from(node.clone()); node - } - Op::Value(Value { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Op::Value(Value { _node: Singleton(Some(link)), .. }) => link.clone(), Op::Value(value) => { let node: Link = Node::Value(back).into(); value._node = Singleton::from(node.clone()); node - } + }, Op::None(span) => Node::None(*span).into(), } } @@ -439,132 +393,90 @@ impl Link { pub fn as_owner(&self) -> Option> { let back: BackLink = self.clone().into(); match self.clone().borrow_mut().deref_mut() { - Op::Enf(Enf { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + Op::Enf(Enf { _owner: Singleton(Some(link)), .. 
}) => Some(link.clone()), Op::Enf(enf) => { let owner: Link = Owner::Enf(back).into(); enf._owner = Singleton::from(owner.clone()); enf._owner.0.clone() - } - Op::Boundary(Boundary { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::Boundary(Boundary { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::Boundary(boundary) => { let owner: Link = Owner::Boundary(back).into(); boundary._owner = Singleton::from(owner.clone()); boundary._owner.0.clone() - } - Op::Add(Add { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::Add(Add { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::Add(add) => { let owner: Link = Owner::Add(back).into(); add._owner = Singleton::from(owner.clone()); add._owner.0.clone() - } - Op::Sub(Sub { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::Sub(Sub { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::Sub(sub) => { let owner: Link = Owner::Sub(back).into(); sub._owner = Singleton::from(owner.clone()); sub._owner.0.clone() - } - Op::Mul(Mul { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::Mul(Mul { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::Mul(mul) => { let owner: Link = Owner::Mul(back).into(); mul._owner = Singleton::from(owner.clone()); mul._owner.0.clone() - } - Op::Exp(Exp { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::Exp(Exp { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::Exp(exp) => { let owner: Link = Owner::Exp(back).into(); exp._owner = Singleton::from(owner.clone()); exp._owner.0.clone() - } - Op::If(If { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::If(If { _owner: Singleton(Some(link)), .. 
}) => Some(link.clone()), Op::If(if_op) => { let owner: Link = Owner::If(back).into(); if_op._owner = Singleton::from(owner.clone()); if_op._owner.0.clone() - } - Op::For(For { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::For(For { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::For(for_op) => { let owner: Link = Owner::For(back).into(); for_op._owner = Singleton::from(owner.clone()); for_op._owner.0.clone() - } - Op::Call(Call { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::Call(Call { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::Call(call) => { let owner: Link = Owner::Call(back).into(); call._owner = Singleton::from(owner.clone()); call._owner.0.clone() - } - Op::Fold(Fold { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::Fold(Fold { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::Fold(fold) => { let owner: Link = Owner::Fold(back).into(); fold._owner = Singleton::from(owner.clone()); fold._owner.0.clone() - } - Op::Vector(Vector { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::Vector(Vector { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::Vector(vector) => { let owner: Link = Owner::Vector(back).into(); vector._owner = Singleton::from(owner.clone()); vector._owner.0.clone() - } - Op::Matrix(Matrix { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::Matrix(Matrix { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::Matrix(matrix) => { let owner: Link = Owner::Matrix(back).into(); matrix._owner = Singleton::from(owner.clone()); matrix._owner.0.clone() - } - Op::Accessor(Accessor { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::Accessor(Accessor { _owner: Singleton(Some(link)), .. 
}) => Some(link.clone()), Op::Accessor(accessor) => { let owner: Link = Owner::Accessor(back).into(); accessor._owner = Singleton::from(owner.clone()); accessor._owner.0.clone() - } - Op::BusOp(BusOp { - _owner: Singleton(Some(link)), - .. - }) => Some(link.clone()), + }, + Op::BusOp(BusOp { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), Op::BusOp(bus_op) => { let owner: Link = Owner::BusOp(back).into(); bus_op._owner = Singleton::from(owner.clone()); bus_op._owner.0.clone() - } + }, Op::Parameter(_) => None, Op::Value(_) => None, Op::None(_) => None, @@ -828,15 +740,22 @@ impl Link { } } -impl From for Link { - fn from(value: i64) -> Self { - Op::Value(Value { - value: SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(value as u64)), - ..Default::default() - }, - ..Default::default() - }) - .into() - } +macro_rules! impl_from_integer_for_op { + ($($ty:ty),*) => { + $( + impl From<$ty> for Link { + fn from(value: $ty) -> Self { + Op::Value(Value { + value: SpannedMirValue { + value: MirValue::Constant(ConstantValue::Felt(value as u64)), + ..Default::default() + }, + ..Default::default() + }) + .into() + } + } + )* + }; } +impl_from_integer_for_op!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize); diff --git a/mir/src/ir/nodes/ops/accessor.rs b/mir/src/ir/nodes/ops/accessor.rs index bf5178a78..5a5d8318f 100644 --- a/mir/src/ir/nodes/ops/accessor.rs +++ b/mir/src/ir/nodes/ops/accessor.rs @@ -1,18 +1,19 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; -use air_parser::ast::AccessType; -use miden_diagnostics::{SourceSpan, Spanned}; use std::hash::Hash; +use miden_diagnostics::{SourceSpan, Spanned}; + +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; + /// A MIR operation to represent accessing a given op, `indexable`, in two different ways: -/// - access_type: AccessType, which describes for example how to access a given index for a Vector (e.g. 
`v[0]`) +/// - access_type: AccessType, which describes for example how to access a given index for a Vector +/// (e.g. `v[0]`) /// - offset: usize, which describes the row offset for a trace column access (e.g. `a'`) -/// #[derive(Hash, Clone, PartialEq, Eq, Debug, Builder, Spanned, Default)] #[enum_wrapper(Op)] pub struct Accessor { pub parents: Vec>, pub indexable: Link, - pub access_type: AccessType, + pub access_type: MirAccessType, pub offset: usize, pub _node: Singleton, pub _owner: Singleton, @@ -20,10 +21,18 @@ pub struct Accessor { pub span: SourceSpan, } +#[derive(Hash, Clone, PartialEq, Eq, Debug, Default)] +pub enum MirAccessType { + #[default] + Default, + Index(Link), + Matrix(Link, Link), +} + impl Accessor { pub fn create( indexable: Link, - access_type: AccessType, + access_type: MirAccessType, offset: usize, span: SourceSpan, ) -> Link { @@ -41,7 +50,14 @@ impl Accessor { impl Parent for Accessor { type Child = Op; fn children(&self) -> Link>> { - Link::new(vec![self.indexable.clone()]) + let vec = match self.access_type { + MirAccessType::Default => vec![self.indexable.clone()], + MirAccessType::Index(ref idx) => vec![self.indexable.clone(), idx.clone()], + MirAccessType::Matrix(ref row, ref col) => { + vec![self.indexable.clone(), row.clone(), col.clone()] + }, + }; + Link::new(vec) } } diff --git a/mir/src/ir/nodes/ops/add.rs b/mir/src/ir/nodes/ops/add.rs index 6affe1cde..e5bfe1ebf 100644 --- a/mir/src/ir/nodes/ops/add.rs +++ b/mir/src/ir/nodes/ops/add.rs @@ -1,8 +1,8 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; + /// A MIR operation to represent the addition of two MIR ops, `lhs` and `rhs` -/// #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Add { @@ -17,13 +17,7 @@ pub struct Add { impl Add { pub fn create(lhs: 
Link, rhs: Link, span: SourceSpan) -> Link { - Op::Add(Self { - lhs, - rhs, - span, - ..Default::default() - }) - .into() + Op::Add(Self { lhs, rhs, span, ..Default::default() }).into() } } diff --git a/mir/src/ir/nodes/ops/boundary.rs b/mir/src/ir/nodes/ops/boundary.rs index d96cf20b2..8785fa018 100644 --- a/mir/src/ir/nodes/ops/boundary.rs +++ b/mir/src/ir/nodes/ops/boundary.rs @@ -1,12 +1,13 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use std::hash::Hash; + use air_parser::ast::Boundary as BoundaryKind; use miden_diagnostics::{SourceSpan, Spanned}; -use std::hash::Hash; + +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent bounding a given op, `expr`, to access either the first or last row /// /// Note: Boundary ops are only valid to describe boundary constraints, not integrity constraints -/// #[derive(Clone, PartialEq, Default, Eq, Debug, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Boundary { @@ -31,13 +32,7 @@ impl Hash for Boundary { impl Boundary { pub fn create(expr: Link, kind: BoundaryKind, span: SourceSpan) -> Link { - Op::Boundary(Self { - expr, - kind, - span, - ..Default::default() - }) - .into() + Op::Boundary(Self { expr, kind, span, ..Default::default() }).into() } } diff --git a/mir/src/ir/nodes/ops/bus_op.rs b/mir/src/ir/nodes/ops/bus_op.rs index ee895e03d..a868275ca 100644 --- a/mir/src/ir/nodes/ops/bus_op.rs +++ b/mir/src/ir/nodes/ops/bus_op.rs @@ -1,7 +1,9 @@ -use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Parent, Singleton}; -use miden_diagnostics::{SourceSpan, Spanned}; use std::hash::Hash; +use miden_diagnostics::{SourceSpan, Spanned}; + +use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Parent, Singleton}; + #[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] pub enum BusOpKind { #[default] diff --git a/mir/src/ir/nodes/ops/call.rs b/mir/src/ir/nodes/ops/call.rs index 
854261dfa..efa169328 100644 --- a/mir/src/ir/nodes/ops/call.rs +++ b/mir/src/ir/nodes/ops/call.rs @@ -1,12 +1,15 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Root, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; -/// A MIR operation to represent a call to a given function, a `Root` that represents either a `Function` or an `Evaluator` +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Root, Singleton}; + +/// A MIR operation to represent a call to a given function, a `Root` that represents either a +/// `Function` or an `Evaluator` /// /// Notes: -/// - The `arguments` are the arguments to the function call, and will replace the encountered `Parameter` nodes in the function's body by the corresponding argument during the Inlining pass +/// - The `arguments` are the arguments to the function call, and will replace the encountered +/// `Parameter` nodes in the function's body by the corresponding argument during the Inlining +/// pass /// - After the Inlining pass, no Call ops should be present in the graph -/// #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Call { @@ -35,7 +38,8 @@ impl Call { impl Parent for Call { type Child = Op; fn children(&self) -> Link>> { - // Here, we do not include the function as a child to make it easier to implement the visitor pattern for the passes + // Here, we do not include the function as a child to make it easier to implement the + // visitor pattern for the passes self.arguments.clone() } } diff --git a/mir/src/ir/nodes/ops/enf.rs b/mir/src/ir/nodes/ops/enf.rs index 6abdeb25c..0206a4023 100644 --- a/mir/src/ir/nodes/ops/enf.rs +++ b/mir/src/ir/nodes/ops/enf.rs @@ -1,8 +1,8 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; + /// A MIR 
operation to enforce that a given MIR op, `expr` equals zero -/// #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Enf { @@ -16,12 +16,7 @@ pub struct Enf { impl Enf { pub fn create(expr: Link, span: SourceSpan) -> Link { - Op::Enf(Self { - expr, - span, - ..Default::default() - }) - .into() + Op::Enf(Self { expr, span, ..Default::default() }).into() } } diff --git a/mir/src/ir/nodes/ops/exp.rs b/mir/src/ir/nodes/ops/exp.rs index b317cc159..c888d1cf6 100644 --- a/mir/src/ir/nodes/ops/exp.rs +++ b/mir/src/ir/nodes/ops/exp.rs @@ -1,10 +1,10 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; + /// A MIR operation to represent the exponentiation of a MIR op, `lhs` by another, `rhs` /// /// Note: `rhs` should be a constant integer after all the passes -/// #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Exp { @@ -19,13 +19,7 @@ pub struct Exp { impl Exp { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Exp(Self { - lhs, - rhs, - span, - ..Default::default() - }) - .into() + Op::Exp(Self { lhs, rhs, span, ..Default::default() }).into() } } diff --git a/mir/src/ir/nodes/ops/fold.rs b/mir/src/ir/nodes/ops/fold.rs index 03dad99a6..0ca0a7c9d 100644 --- a/mir/src/ir/nodes/ops/fold.rs +++ b/mir/src/ir/nodes/ops/fold.rs @@ -1,13 +1,15 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; -/// A MIR operation to represent folding a given Vector operator according to a given operator and initial value +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; + +/// A MIR operation to represent folding a given Vector operator according to a given operator and +/// initial value 
/// /// Notes: /// - operators, of type FoldOperator, can either represent an Addition or a Multiplication -/// - the Fold operation will be unrolled during the Unrolling pass (as a chain of Add or Mul operations) +/// - the Fold operation will be unrolled during the Unrolling pass (as a chain of Add or Mul +/// operations) /// - After the Unrolling pass, no Fold ops should be present in the graph -/// #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Fold { diff --git a/mir/src/ir/nodes/ops/for_op.rs b/mir/src/ir/nodes/ops/for_op.rs index 7b67d8b90..9ab9a200d 100644 --- a/mir/src/ir/nodes/ops/for_op.rs +++ b/mir/src/ir/nodes/ops/for_op.rs @@ -5,11 +5,10 @@ use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singlet /// A MIR operation to represent list comprehensions. /// /// Notes: -/// - the For operation will be unrolled into a Vector during the Unrolling pass, -/// each element of the Vector will be the result of the expression expr` for the given iterators indices +/// - the For operation will be unrolled into a Vector during the Unrolling pass, each element of +/// the Vector will be the result of the expression expr` for the given iterators indices /// - Optionally, a selector can be provided (useful to represent conditional enforcements) /// - After the Unrolling pass, no For ops should be present in the graph -/// #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct For { diff --git a/mir/src/ir/nodes/ops/if_op.rs b/mir/src/ir/nodes/ops/if_op.rs index 02ae44da2..023f65634 100644 --- a/mir/src/ir/nodes/ops/if_op.rs +++ b/mir/src/ir/nodes/ops/if_op.rs @@ -1,37 +1,41 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; + /// A MIR operation to represent conditional 
constraints /// /// Notes: -/// - the If operation will be unrolled into a Vector during the Unrolling pass, combining the then and else branches. -/// For example, If(s, vec![a, b], vec![c]) will be unrolled into: vec![s * a, s * b, (1 - s) * c] +/// - the If operation will be unrolled into a Vector during the Unrolling pass, combining the then +/// and else branches. For example, If(s, vec![a, b], vec![c]) will be unrolled into: vec![s * a, +/// s * b, (1 - s) * c] /// - After the Unrolling pass, no If ops should be present in the graph -/// #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct If { pub parents: Vec>, - pub condition: Link, - pub then_branch: Link, - pub else_branch: Link, + pub match_arms: Link>, pub _node: Singleton, pub _owner: Singleton, #[span] pub span: SourceSpan, } +#[derive(Default, Clone, PartialEq, Eq, Debug, Hash)] +pub struct MatchArm { + pub condition: Link, + pub expr: Link, +} + +impl MatchArm { + pub fn new(expr: Link, condition: Link) -> Self { + Self { condition, expr } + } +} + impl If { - pub fn create( - condition: Link, - then_branch: Link, - else_branch: Link, - span: SourceSpan, - ) -> Link { + pub fn create(match_arms: Vec, span: SourceSpan) -> Link { Op::If(Self { - condition, - then_branch, - else_branch, + match_arms: match_arms.into(), span, ..Default::default() }) @@ -42,11 +46,13 @@ impl If { impl Parent for If { type Child = Op; fn children(&self) -> Link>> { - Link::new(vec![ - self.condition.clone(), - self.then_branch.clone(), - self.else_branch.clone(), - ]) + Link::new( + self.match_arms + .borrow() + .iter() + .flat_map(|arm| vec![arm.condition.clone(), arm.expr.clone()]) + .collect(), + ) } } diff --git a/mir/src/ir/nodes/ops/matrix.rs b/mir/src/ir/nodes/ops/matrix.rs index ba539c5f3..277783ca4 100644 --- a/mir/src/ir/nodes/ops/matrix.rs +++ b/mir/src/ir/nodes/ops/matrix.rs @@ -1,8 +1,8 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, 
Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; + /// A MIR operation to represent a matrix of MIR ops of a given size -/// #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Matrix { diff --git a/mir/src/ir/nodes/ops/mod.rs b/mir/src/ir/nodes/ops/mod.rs index 5c05d46aa..c45120d47 100644 --- a/mir/src/ir/nodes/ops/mod.rs +++ b/mir/src/ir/nodes/ops/mod.rs @@ -15,7 +15,7 @@ mod sub; mod value; mod vector; -pub use accessor::Accessor; +pub use accessor::{Accessor, MirAccessType}; pub use add::Add; pub use boundary::Boundary; pub use bus_op::{BusOp, BusOpKind}; @@ -24,7 +24,7 @@ pub use enf::Enf; pub use exp::Exp; pub use fold::{Fold, FoldOperator}; pub use for_op::For; -pub use if_op::If; +pub use if_op::{If, MatchArm}; pub use matrix::Matrix; pub use mul::Mul; pub use parameter::Parameter; diff --git a/mir/src/ir/nodes/ops/mul.rs b/mir/src/ir/nodes/ops/mul.rs index 05954c3af..247123755 100644 --- a/mir/src/ir/nodes/ops/mul.rs +++ b/mir/src/ir/nodes/ops/mul.rs @@ -1,8 +1,8 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; + /// A MIR operation to represent the multiplication of two MIR ops, `lhs` and `rhs` -/// #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Mul { @@ -17,13 +17,7 @@ pub struct Mul { impl Mul { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Mul(Self { - lhs, - rhs, - span, - ..Default::default() - }) - .into() + Op::Mul(Self { lhs, rhs, span, ..Default::default() }).into() } } diff --git a/mir/src/ir/nodes/ops/parameter.rs b/mir/src/ir/nodes/ops/parameter.rs index 7dbdc6b3b..939215d27 100644 --- a/mir/src/ir/nodes/ops/parameter.rs +++ 
b/mir/src/ir/nodes/ops/parameter.rs @@ -1,7 +1,9 @@ +use std::hash::{Hash, Hasher}; + +use miden_diagnostics::{SourceSpan, Spanned}; + use super::MirType; use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Singleton}; -use miden_diagnostics::{SourceSpan, Spanned}; -use std::hash::{Hash, Hasher}; /// A MIR operation to represent a `Parameter` in a function or evaluator. /// Also used in If and For loops to represent declared parameters. diff --git a/mir/src/ir/nodes/ops/sub.rs b/mir/src/ir/nodes/ops/sub.rs index 99b49a353..94e2c00ad 100644 --- a/mir/src/ir/nodes/ops/sub.rs +++ b/mir/src/ir/nodes/ops/sub.rs @@ -1,8 +1,8 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; -/// A MIR operation to represent the substraction of two MIR ops, `lhs` and `rhs` -/// +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; + +/// A MIR operation to represent the subtraction of two MIR ops, `lhs` and `rhs` #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Sub { @@ -17,13 +17,7 @@ pub struct Sub { impl Sub { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Sub(Self { - lhs, - rhs, - span, - ..Default::default() - }) - .into() + Op::Sub(Self { lhs, rhs, span, ..Default::default() }).into() } } diff --git a/mir/src/ir/nodes/ops/value.rs b/mir/src/ir/nodes/ops/value.rs index 9e117b135..dbd7791be 100644 --- a/mir/src/ir/nodes/ops/value.rs +++ b/mir/src/ir/nodes/ops/value.rs @@ -1,10 +1,13 @@ -use air_parser::ast::{self, Identifier, QualifiedIdentifier, TraceColumnIndex, TraceSegmentId}; +use air_parser::ast::{ + self, BusType, Identifier, QualifiedIdentifier, TraceColumnIndex, TraceSegmentId, +}; use miden_diagnostics::{SourceSpan, Spanned}; use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Singleton}; /// A MIR operation to represent a known value, [Value]. 
-/// Wraps a [SpannedMirValue] to represent a known value in the [MIR]. +/// +/// Wraps a [SpannedMirValue] to represent a known value in the MIR. #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Value { @@ -16,11 +19,13 @@ pub struct Value { impl Value { pub fn create(value: SpannedMirValue) -> Link { - Op::Value(Self { - value, - ..Default::default() - }) - .into() + Op::Value(Self { value, ..Default::default() }).into() + } + pub fn get_inner_const(&self) -> Option { + match &self.value.value { + MirValue::Constant(ConstantValue::Felt(v)) => Some(*v), + _ => None, + } } } @@ -49,7 +54,7 @@ impl Child for Value { } } -/// Represents a known value in the [MIR]. +/// Represents a known value in the MIR. /// /// Values are either constant, or evaluated at runtime using the context /// provided to an AirScript program (i.e. public inputs, etc.). @@ -69,7 +74,8 @@ pub enum MirValue { PublicInputTable(PublicInputTableAccess), /// A reference to a specific index in the random values array. /// - /// Random values are not provided by the user in the AirScript program, but are used to expand Bus constraints. + /// Random values are not provided by the user in the AirScript program, but are used to expand + /// Bus constraints. RandomValue(usize), /// A binding to a set of consecutive trace columns of a given size. TraceAccessBinding(TraceAccessBinding), @@ -81,7 +87,7 @@ pub enum MirValue { Unconstrained, } -/// [BusAccess] is like [SymbolAccess], but is used to describe an access to a specific bus. +/// [BusAccess] is like SymbolAccess, but is used to describe an access to a specific bus. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct BusAccess { /// The trace segment being accessed @@ -109,7 +115,8 @@ pub enum ConstantValue { Matrix(Vec>), } -/// [TraceAccess] is like [SymbolAccess], but is used to describe an access to a specific trace column or columns. 
+/// [TraceAccess] is like SymbolAccess, but is used to describe an access to a specific trace +/// column or columns. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct TraceAccess { /// The trace segment being accessed @@ -127,11 +134,7 @@ pub struct TraceAccess { impl TraceAccess { /// Creates a new [TraceAccess]. pub const fn new(segment: TraceSegmentId, column: TraceColumnIndex, row_offset: usize) -> Self { - Self { - segment, - column, - row_offset, - } + Self { segment, column, row_offset } } } @@ -144,7 +147,7 @@ pub struct TraceAccessBinding { pub size: usize, } -/// Represents a typed value in the [MIR] +/// Represents a typed value in the MIR. #[derive(Debug, Eq, PartialEq, Clone, Hash, Spanned)] pub struct SpannedMirValue { #[span] @@ -170,8 +173,8 @@ impl From for MirType { } } -/// Represents an access of a [PeriodicColumn], similar in nature to [TraceAccess]. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +/// Represents an access of a PeriodicColumn, similar in nature to [TraceAccess]. +#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)] pub struct PeriodicColumnAccess { pub name: QualifiedIdentifier, pub cycle: usize, @@ -182,8 +185,8 @@ impl PeriodicColumnAccess { } } -/// Represents an access of a [PublicInput], similar in nature to [TraceAccess]. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +/// Represents an access of a PublicInput, similar in nature to [TraceAccess]. +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)] pub struct PublicInputAccess { /// The name of the public input to access pub name: Identifier, @@ -203,28 +206,23 @@ impl PublicInputAccess { pub struct PublicInputTableAccess { /// The name of the public input to bind pub table_name: Identifier, - /// The name of the bus to bind - /// The bus name is not always known at the time of instantiation, - /// making it an Option allows setting it later. 
- bus_name: Option, /// The number of columns in the table pub num_cols: usize, + /// The type of bus to bind (multiset or logUp). + /// The bus type is not always known at the time of instantiation, + /// making it an Option allows setting it later. + bus_type: Option, } impl PublicInputTableAccess { pub const fn new(table_name: Identifier, num_cols: usize) -> Self { - Self { - table_name, - bus_name: None, - num_cols, - } + Self { table_name, num_cols, bus_type: None } } - pub fn set_bus_name(&mut self, bus_name: Identifier) { - self.bus_name = Some(bus_name); + pub fn set_bus_type(&mut self, bus_type: BusType) { + self.bus_type = Some(bus_type); } - pub fn bus_name(&self) -> Identifier { - self.bus_name - .expect("Bus name should have already been set") + pub fn bus_type(&self) -> BusType { + self.bus_type.expect("Bus type should have already been set") } } diff --git a/mir/src/ir/nodes/ops/vector.rs b/mir/src/ir/nodes/ops/vector.rs index 7bccb34fa..fc53c3d67 100644 --- a/mir/src/ir/nodes/ops/vector.rs +++ b/mir/src/ir/nodes/ops/vector.rs @@ -1,8 +1,8 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; + /// A MIR operation to represent a vector of MIR ops of a given size -/// #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Vector { diff --git a/mir/src/ir/nodes/root.rs b/mir/src/ir/nodes/root.rs index accd7fcbd..0c8a4c9d5 100644 --- a/mir/src/ir/nodes/root.rs +++ b/mir/src/ir/nodes/root.rs @@ -58,24 +58,18 @@ impl Link { pub fn as_node(&self) -> Link { let back: BackLink = self.clone().into(); match self.borrow_mut().deref_mut() { - Root::Function(Function { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + Root::Function(Function { _node: Singleton(Some(link)), .. 
}) => link.clone(), Root::Function(f) => { let node: Link = Node::Function(back).into(); f._node = Singleton::from(node.clone()); node - } - Root::Evaluator(Evaluator { - _node: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Root::Evaluator(Evaluator { _node: Singleton(Some(link)), .. }) => link.clone(), Root::Evaluator(e) => { let node: Link = Node::Evaluator(back).into(); e._node = Singleton::from(node.clone()); node - } + }, Root::None(span) => Node::None(*span).into(), } } @@ -84,24 +78,18 @@ impl Link { pub fn as_owner(&self) -> Link { let back: BackLink = self.clone().into(); match self.borrow_mut().deref_mut() { - Root::Function(Function { - _owner: Singleton(Some(link)), - .. - }) => link.clone(), + Root::Function(Function { _owner: Singleton(Some(link)), .. }) => link.clone(), Root::Function(f) => { let owner: Link = Owner::Function(back).into(); f._owner = Singleton::from(owner.clone()); owner - } - Root::Evaluator(Evaluator { - _owner: Singleton(Some(link)), - .. - }) => link.clone(), + }, + Root::Evaluator(Evaluator { _owner: Singleton(Some(link)), .. 
}) => link.clone(), Root::Evaluator(e) => { let owner: Link = Owner::Evaluator(back).into(); e._owner = Singleton::from(owner.clone()); owner - } + }, Root::None(span) => Owner::None(*span).into(), } } diff --git a/mir/src/ir/nodes/roots/evaluator.rs b/mir/src/ir/nodes/roots/evaluator.rs index 49df90cc9..35e0fd6e7 100644 --- a/mir/src/ir/nodes/roots/evaluator.rs +++ b/mir/src/ir/nodes/roots/evaluator.rs @@ -1,6 +1,7 @@ -use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; +use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root, Singleton}; + /// A MIR Root to represent a Evaluator definition #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Root)] diff --git a/mir/src/ir/nodes/roots/function.rs b/mir/src/ir/nodes/roots/function.rs index 9fa3fa150..12051f466 100644 --- a/mir/src/ir/nodes/roots/function.rs +++ b/mir/src/ir/nodes/roots/function.rs @@ -1,6 +1,7 @@ -use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; +use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root, Singleton}; + /// A MIR Root to represent a Function definition #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Root)] diff --git a/mir/src/ir/owner.rs b/mir/src/ir/owner.rs index 77edc614c..d1a2c513b 100644 --- a/mir/src/ir/owner.rs +++ b/mir/src/ir/owner.rs @@ -1,6 +1,7 @@ -use miden_diagnostics::{SourceSpan, Spanned}; use std::ops::Deref; +use miden_diagnostics::{SourceSpan, Spanned}; + use crate::ir::{BackLink, Child, Link, Node, Op, Parent, Root}; /// The nodes that can own [Op] nodes @@ -10,8 +11,7 @@ use crate::ir::{BackLink, Child, Link, Node, Op, Parent, Root}; /// so it can be updated to the correct variant when the inner struct is updated /// Note: The [None] variant is used to represent a [Owner] that: /// - is not yet initialized -/// - no longer exists (due to its 
ref-count dropping to 0). -/// We refer to those as "stale" nodes. +/// - no longer exists (due to its ref-count dropping to 0). We refer to those as "stale" nodes. #[derive(Clone, Eq, Debug, Spanned)] pub enum Owner { Function(BackLink), @@ -208,7 +208,8 @@ impl Link { Root::None(span) => Owner::None(*span), }; } else { - unreachable!(); + // If the [Owner] is stale, we set it to None + to_update = Owner::None(self.span()); } *self.borrow_mut() = to_update; diff --git a/mir/src/ir/quad_eval.rs b/mir/src/ir/quad_eval.rs new file mode 100644 index 000000000..d5b0774e6 --- /dev/null +++ b/mir/src/ir/quad_eval.rs @@ -0,0 +1,225 @@ +extern crate alloc; +use alloc::collections::BTreeMap; +use std::ops::Deref; + +use air_parser::ast::TraceSegmentId; +use miden_core::{Felt, QuadExtension}; +use rand::{distr::Uniform, prelude::*}; +use winter_math::{FieldElement, StarkField}; + +use crate::{ + CompileError, + ir::{ConstantValue, Link, MirValue, Op, PeriodicColumnAccess, PublicInputAccess}, +}; + +pub type QuadFelt = QuadExtension; + +/// Returns a random [QuadFelt] value. +fn rand_quad_felt(rng: &mut R) -> QuadFelt { + // Note: using a uniform distribution over all u64 values would lead to a non-uniform + // distribution of Felt values. + let distr = Uniform::new(0, Felt::MODULUS).unwrap(); // Unwrap is safe as Felt::MODULUS is > 0 + + QuadFelt::new(Felt::new(rng.sample(distr)), Felt::new(rng.sample(distr))) +} + +/// Returns a [QuadFelt] corresponding to a given base element. +pub fn const_quad_felt(felt: Felt) -> QuadFelt { + QuadFelt::new(felt, Felt::ZERO) +} + +/// Helper function to either query an existing evaluation or create a new random one if the index +/// is out of bounds. 
+pub fn query_indexed_eval( + rng: &mut R, + evaluations: &mut Vec, + index: usize, +) -> QuadFelt { + if evaluations.len() <= index { + evaluations.resize_with(index + 1, || rand_quad_felt(rng)); + } + evaluations[index] +} + +/// Helper function to either query an existing evaluation or create a new random one if the element +/// is not present in the map. +pub fn query_mapped_eval( + rng: &mut R, + evaluation_map: &mut BTreeMap, + element: &K, +) -> QuadFelt { + *evaluation_map.entry(element.clone()).or_insert_with(|| rand_quad_felt(rng)) +} + +/// Holds the random inputs taken by leaf nodes, in order to persist them across different node +/// evaluations. +#[derive(Debug, Clone, Default)] +pub struct RandomInputs { + rng: ThreadRng, + // A vector to hold the random values taken for the main trace, indexed in the following way: + // $main[0], $main[0]', $main[1], $main[1]', $main[2], ... + main_trace: Vec, + rand_values: Vec, + public_inputs: BTreeMap, + periodic_columns: BTreeMap, +} + +impl RandomInputs { + /// Evaluates a given MIR node at random points. + /// + /// Note that we currently assume this will be called only during the unrolling phase, some + /// operation types are not handled. 
+ pub fn eval(&mut self, op: Link) -> Result { + match op.borrow().deref() { + Op::Enf(e) => { + let expr = self.eval(e.expr.clone())?; + Ok(expr) + }, + Op::Add(a) => { + let lhs = self.eval(a.lhs.clone())?; + let rhs = self.eval(a.rhs.clone())?; + Ok(lhs + rhs) + }, + Op::Sub(s) => { + let lhs = self.eval(s.lhs.clone())?; + let rhs = self.eval(s.rhs.clone())?; + Ok(lhs - rhs) + }, + Op::Mul(m) => { + let lhs = self.eval(m.lhs.clone())?; + let rhs = self.eval(m.rhs.clone())?; + Ok(lhs * rhs) + }, + Op::Exp(e) => { + let lhs = self.eval(e.lhs.clone())?; + let rhs = self.eval(e.rhs.clone())?; + let base_elems = rhs.to_base_elements(); + if base_elems[1].as_int() != 0 { + // If the second base element is not zero, we cannot evaluate the exponentiation + // This should not happen because either the non-constant powers would have been + // caught in the parser or it is in the body of a list + // comprehension which has not been expanded yet (and we would have a + // Op::Parameter instead) + Err(CompileError::Failed) + } else { + let power = rhs.to_base_elements()[0].as_int(); + Ok(lhs.exp(power)) + } + }, + Op::Parameter(_) => { + // We cannot easily detect that two parameters refer to the same For, so we consider + // them to be all different + Ok(rand_quad_felt(&mut self.rng)) + }, + Op::Value(v) => { + match &v.value.value { + MirValue::Constant(ConstantValue::Felt(c)) => { + let felt = Felt::new(*c); + Ok(const_quad_felt(felt)) + }, + // We associate a random value to each trace access of the main trace, + // indexed in the following way, each column having two + // distinct evaluations to account for the two possible row offsets: + // $main[0], $main[0]', $main[1], $main[1]', $main[2], ... + // Note: if we encounter a trace access corresponding to an index we have not + // yet evaluated, we will randomly generate values for + // this trace access, but also for all previous indices. 
+ MirValue::TraceAccess(trace_access) => match trace_access.segment { + TraceSegmentId::Main => { + let index = trace_access.column * 2 + trace_access.row_offset; + Ok(query_indexed_eval(&mut self.rng, &mut self.main_trace, index)) + }, + _ => { + println!( + "Unexpected trace_access segment in RandomInputs::eval: {}. This segment should only be used for buses and should be handled separately.", + trace_access.segment + ); + Err(CompileError::Failed) + }, + }, + MirValue::RandomValue(u) => { + Ok(query_indexed_eval(&mut self.rng, &mut self.rand_values, *u)) + }, + // For PublicInput and PeriodicColumn, we use the Hash of the element to + // associate a unique random value or each public input and + // each periodic column access + MirValue::PublicInput(pi) => { + Ok(query_mapped_eval(&mut self.rng, &mut self.public_inputs, pi)) + }, + MirValue::PeriodicColumn(pc) => { + Ok(query_mapped_eval(&mut self.rng, &mut self.periodic_columns, pc)) + }, + MirValue::Null + | MirValue::BusAccess(_) + | MirValue::Unconstrained + | MirValue::PublicInputTable(_) => { + // Bus related values, we expect these to be handled separately + println!( + "Unexpected values in RandomInputs::eval, the following op should only be used for Bus expressions and should be handled separately: {op:?}" + ); + Err(CompileError::Failed) + }, + MirValue::TraceAccessBinding(_) | MirValue::Constant(_) => { + // These should have been unrolled as Vector and handled beforehand + println!( + "Unexpected values in RandomInputs::eval, the following op should already be Unrolled at this stage: {op:?}" + ); + Err(CompileError::Failed) + }, + } + }, + Op::Accessor(a) => { + if let Op::Value(v) = a.indexable.borrow().deref() + && let MirValue::TraceAccess(trace_access) = v.value.value + { + // Use accessor offset instead of the trace_access row_offset + let index = trace_access.column * 2 + a.offset; + match trace_access.segment { + TraceSegmentId::Main => { + return Ok(query_indexed_eval( + &mut self.rng, + &mut 
self.main_trace, + index, + )); + }, + _ => { + println!( + "Unexpected trace_access segment in RandomInputs::eval: {}. This segment should only be used for buses and should be handled separately.", + trace_access.segment + ); + return Err(CompileError::Failed); + }, + } + } + let indexable = self.eval(a.indexable.clone())?; + Ok(indexable) + }, + Op::Call(_) => { + // We expect Inlining to have already been done before Unrolling + println!( + "Unexpected operation in RandomInputs::eval, Calls should have been Inlined in a previous pass: {op:?}" + ); + Err(CompileError::Failed) + }, + Op::Fold(_) | Op::Vector(_) | Op::Matrix(_) | Op::For(_) | Op::If(_) | Op::None(_) => { + println!( + "Unexpected operation in RandomInputs::eval, the following operation should already be Unrolled at this stage: {op:?}" + ); + Err(CompileError::Failed) + }, + Op::Boundary(_) => { + println!( + "Unexpected operation in RandomInputs::eval, the following operation is not valid in integrity constraints: {op:?}" + ); + Err(CompileError::Failed) + }, + Op::BusOp(_) => { + // Bus related operation, we expect these to be handled separately + println!( + "Unexpected operation in RandomInputs::eval, the following op should only be used for Bus expressions and should be handled separately: {op:?}" + ); + Err(CompileError::Failed) + }, + } + } +} diff --git a/mir/src/ir/utils.rs b/mir/src/ir/utils.rs index e9fbb32c0..742f224c7 100644 --- a/mir/src/ir/utils.rs +++ b/mir/src/ir/utils.rs @@ -13,10 +13,10 @@ pub fn strip_spans(mir: &mut Mir) { let graph = mir.constraint_graph_mut(); let mut visitor = StripSpansVisitor::default(); match visitor.run(graph) { - Ok(_) => {} + Ok(_) => {}, Err(e) => { panic!("Error stripping spans: {e:?}"); - } + }, } } @@ -47,14 +47,10 @@ pub fn extract_roots( } if include_bus { let buses = graph.get_bus_nodes(); - let bus_columns = buses - .iter() - .flat_map(|b| b.borrow().columns.clone()) - .map(|n| n.as_node()); - let bus_latches = buses - .iter() - .flat_map(|b| 
b.borrow().latches.clone()) - .map(|n| n.as_node()); + let bus_columns = + buses.iter().flat_map(|b| b.borrow().columns.clone()).map(|n| n.as_node()); + let bus_latches = + buses.iter().flat_map(|b| b.borrow().latches.clone()).map(|n| n.as_node()); nodes.extend(bus_columns); nodes.extend(bus_latches); } @@ -227,35 +223,35 @@ impl Visitor for StripSpansVisitor { let mut value = value.as_value_mut().unwrap(); value.value.span = Default::default(); match &mut value.value.value { - MirValue::Constant(_) => {} - MirValue::TraceAccess(_) => {} + MirValue::Constant(_) => {}, + MirValue::TraceAccess(_) => {}, MirValue::PeriodicColumn(v) => { - v.name.module.0 = Span::new(SourceSpan::default(), v.name.module.0.item); + v.name.module.0 = Span::new(SourceSpan::default(), v.name.module.0.item.clone()); match v.name.item { NamespacedIdentifier::Function(f) => { v.name.item = NamespacedIdentifier::Function(Identifier::new( SourceSpan::default(), f.0.item, )); - } + }, NamespacedIdentifier::Binding(b) => { v.name.item = NamespacedIdentifier::Binding(Identifier::new( SourceSpan::default(), b.0.item, )); - } + }, }; - } + }, MirValue::PublicInput(v) => { v.name.0 = Span::new(SourceSpan::default(), v.name.0.item); - } + }, MirValue::PublicInputTable(v) => { v.table_name.0 = Span::new(SourceSpan::default(), v.table_name.0.item); - } - MirValue::RandomValue(_) => {} - MirValue::TraceAccessBinding(_) => {} - MirValue::BusAccess(_) => {} - MirValue::Null | MirValue::Unconstrained => {} + }, + MirValue::RandomValue(_) => {}, + MirValue::TraceAccessBinding(_) => {}, + MirValue::BusAccess(_) => {}, + MirValue::Null | MirValue::Unconstrained => {}, } Ok(()) } diff --git a/mir/src/lib.rs b/mir/src/lib.rs index 5135360c2..b8c6c5c81 100644 --- a/mir/src/lib.rs +++ b/mir/src/lib.rs @@ -1,13 +1,38 @@ -mod codegen; - pub mod ir; pub mod passes; #[cfg(test)] mod tests; -pub use self::codegen::CodeGenerator; +use air_parser::ast::Program; +use air_pass::Pass; +use miden_diagnostics::{Diagnostic, 
DiagnosticsHandler, ToDiagnostic}; + +use crate::ir::Mir; + +/// Abstracts the various passes done on the MIR representation of the program. +pub struct MirPasses<'a> { + diagnostics: &'a DiagnosticsHandler, +} + +impl<'a> MirPasses<'a> { + pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { + Self { diagnostics } + } +} + +impl Pass for MirPasses<'_> { + type Input<'a> = Program; + type Output<'a> = Mir; + type Error = CompileError; -use miden_diagnostics::{Diagnostic, ToDiagnostic}; + fn run<'a>(&mut self, input: Self::Input<'a>) -> Result, Self::Error> { + let mut passes = passes::AstToMir::new(self.diagnostics) + .chain(passes::Inlining::new(self.diagnostics)) + .chain(passes::Unrolling::new(self.diagnostics)) + .chain(passes::ConstantPropagation::new(self.diagnostics)); + passes.run(input) + } +} /// Error type that can be returned during the Mir passes #[derive(Debug, thiserror::Error)] diff --git a/mir/src/passes/constant_propagation.rs b/mir/src/passes/constant_propagation.rs index 4744bf4bf..c0a08505e 100644 --- a/mir/src/passes/constant_propagation.rs +++ b/mir/src/passes/constant_propagation.rs @@ -1,18 +1,18 @@ +use std::ops::Deref; + use air_pass::Pass; -use miden_diagnostics::DiagnosticsHandler; +use miden_diagnostics::{DiagnosticsHandler, SourceSpan, Spanned}; use super::visitor::Visitor; use crate::{ - ir::{Link, Mir, Node}, CompileError, + ir::{ + ConstantValue, Graph, Link, Mir, MirAccessType, MirValue, Node, Op, Parent, + SpannedMirValue, Value, + }, + passes::handle_accessor_visit, }; -/// TODO MIR: -/// If needed, implement constant propagation / folding pass on MIR -/// Run through every operation in the graph -/// If we can deduce the resulting value based on the constants of the operands, -/// replace the operation itself with a constant -/// pub struct ConstantPropagation<'a> { #[allow(unused)] diagnostics: &'a DiagnosticsHandler, @@ -31,35 +31,222 @@ impl Pass for ConstantPropagation<'_> { } impl<'a> ConstantPropagation<'a> { - 
#[allow(unused)] pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { - Self { - diagnostics, - work_stack: vec![], + Self { diagnostics, work_stack: vec![] } + } +} + +// For the ConstantPropagation, we use a tweaked version of the Visitor trait, +// each visit_*_bis function returns an Option> instead of Result<(), CompileError>, +// to mutate the nodes (e.g. modifying a Add(lhs, rhs) to Value(lhs + rhs)). +impl ConstantPropagation<'_> { + fn visit_add_bis(&mut self, add: Link) -> Result>, CompileError> { + // safe to unwrap because we just dispatched on it + let add_ref = add.as_add().unwrap(); + let lhs = add_ref.lhs.clone(); + let rhs = add_ref.rhs.clone(); + + if let Some(0) = get_inner_const(&lhs) { + Ok(Some(rhs)) + } else if let Some(0) = get_inner_const(&rhs) { + Ok(Some(lhs)) + } else { + try_fold_const_binary_op(lhs, rhs, add.clone(), add_ref.span()) + } + } + + fn visit_sub_bis(&mut self, sub: Link) -> Result>, CompileError> { + // safe to unwrap because we just dispatched on it + let sub_ref = sub.as_sub().unwrap(); + let lhs = sub_ref.lhs.clone(); + let rhs = sub_ref.rhs.clone(); + + if let Some(0) = get_inner_const(&rhs) { + Ok(Some(lhs)) + } else { + try_fold_const_binary_op(lhs, rhs, sub.clone(), sub_ref.span()) } } + + fn visit_mul_bis(&mut self, mul: Link) -> Result>, CompileError> { + // safe to unwrap because we just dispatched on it + let mul_ref = mul.as_mul().unwrap(); + let lhs = mul_ref.lhs.clone(); + let rhs = mul_ref.rhs.clone(); + + match (get_inner_const(&lhs), get_inner_const(&rhs)) { + (Some(0), _) | (_, Some(0)) => Ok(Some(Value::create(SpannedMirValue { + value: MirValue::Constant(ConstantValue::Felt(0)), + span: mul_ref.span, + }))), + (Some(1), _) => Ok(Some(rhs)), + (_, Some(1)) => Ok(Some(lhs)), + _ => try_fold_const_binary_op(lhs, rhs, mul.clone(), mul_ref.span()), + } + } + + fn visit_exp_bis(&mut self, exp: Link) -> Result>, CompileError> { + // safe to unwrap because we just dispatched on it + let exp_ref = 
exp.as_exp().unwrap(); + let lhs = exp_ref.lhs.clone(); + let rhs = exp_ref.rhs.clone(); + + // x^0 = 1 (must take precedence, including 0^0) + if let Some(0) = get_inner_const(&rhs) { + Ok(Some(Value::create(SpannedMirValue { + value: MirValue::Constant(ConstantValue::Felt(1)), + span: exp_ref.span, + }))) + } else if let Some(0) = get_inner_const(&lhs) { + // 0^k = 0, but only when k is known and non-zero + if get_inner_const(&rhs).is_some() { + return Ok(Some(Value::create(SpannedMirValue { + value: MirValue::Constant(ConstantValue::Felt(0)), + span: exp_ref.span, + }))); + } + Ok(None) + } else { + try_fold_const_binary_op(lhs, rhs, exp.clone(), exp_ref.span()) + } + } + + fn visit_accessor_bis(&mut self, accessor: Link) -> Result>, CompileError> { + handle_accessor_visit(accessor.clone(), true, self.diagnostics) + } } impl Visitor for ConstantPropagation<'_> { fn work_stack(&mut self) -> &mut Vec> { &mut self.work_stack } - fn root_nodes_to_visit( - &self, - graph: &crate::ir::Graph, - ) -> Vec> { + + // We visit all boundary constraints and all integrity constraints + // No need to visit the functions or evaluators, as they should have been inlined before this + // pass + fn root_nodes_to_visit(&self, graph: &Graph) -> Vec> { let boundary_constraints_roots_ref = graph.boundary_constraints_roots.borrow(); let integrity_constraints_roots_ref = graph.integrity_constraints_roots.borrow(); + let bus_roots: Vec<_> = graph + .buses + .values() + .flat_map(|b| b.borrow().clone().columns.into_iter().collect::>()) + .collect(); let combined_roots = boundary_constraints_roots_ref .clone() .into_iter() .map(|bc| bc.as_node()) - .chain( - integrity_constraints_roots_ref - .clone() - .into_iter() - .map(|ic| ic.as_node()), - ); + .chain(integrity_constraints_roots_ref.clone().into_iter().map(|ic| ic.as_node())) + .chain(bus_roots.into_iter().map(|b| b.as_node())); combined_roots.collect() } + + fn visit_node(&mut self, _graph: &mut Graph, node: Link) -> Result<(), 
CompileError> { + if node.is_stale() { + return Ok(()); + } + + // In this pass, we both need to dispatch the visitor depending on the node type, + // and also mutate the node if needed. We implement custom visit_*_bis methods + // that returns a Some(updated_node) if we need to update the node's value. + let updated_op: Option> = match node.borrow().deref() { + Node::Add(a) => a.to_link().map_or(Ok(None), |el| self.visit_add_bis(el))?, + Node::Sub(s) => s.to_link().map_or(Ok(None), |el| self.visit_sub_bis(el))?, + Node::Mul(m) => m.to_link().map_or(Ok(None), |el| self.visit_mul_bis(el))?, + Node::Exp(e) => e.to_link().map_or(Ok(None), |el| self.visit_exp_bis(el))?, + Node::Accessor(e) => e.to_link().map_or(Ok(None), |el| self.visit_accessor_bis(el))?, + Node::Vector(_) + | Node::Matrix(_) + | Node::Enf(_) + | Node::Boundary(_) + | Node::BusOp(_) + | Node::Value(_) + | Node::None(_) => None, + _ => { + unreachable!( + "Unexpected node during Mir's ConstantPropagation: Function, Evaluators, Calls, If, For, Fold and Parameter should have been inlined and unrolled before this pass. 
Found: {:?}", + node + ); + }, + }; + + // We update the node if needed + if let Some(updated_op) = updated_op { + node.as_op().unwrap().set(&updated_op); + } + + Ok(()) + } +} + +// HELPERS FUNCTIONS +// ================================================================================================ + +pub fn get_inner_const(value: &Link) -> Option { + match value.borrow().deref() { + Op::Value(v) => v.get_inner_const(), + Op::Accessor(accessor) => { + match (accessor.access_type.clone(), accessor.indexable.borrow().deref()) { + (MirAccessType::Default, _) => get_inner_const(&accessor.indexable), + (MirAccessType::Index(index), Op::Vector(vector)) => { + let index = get_inner_const(&index).expect("Expected constant index") as usize; + + let vec_children = vector.children(); + let vec_ref = vec_children.borrow(); + vec_ref.get(index).and_then(get_inner_const) + }, + (MirAccessType::Matrix(row, col), Op::Matrix(matrix)) => { + let row = get_inner_const(&row).expect("Expected constant row") as usize; + let col = get_inner_const(&col).expect("Expected constant column") as usize; + + let mat_children = matrix.children(); + let mat_ref = mat_children.borrow(); + mat_ref.get(row).and_then(|row| { + if let Op::Vector(row_vector) = row.borrow().deref() { + let row_children = row_vector.children(); + let row_ref = row_children.borrow(); + row_ref.get(col).and_then(get_inner_const) + } else { + None + } + }) + }, + _ => None, + } + }, + _ => None, + } +} + +/// Helper function to fold constant binary operations (Add, Sub, Mul, Exp) +/// into their resulting value if both operands are constant values. 
+fn try_fold_const_binary_op( + lhs: Link, + rhs: Link, + parent: Link, + span: SourceSpan, +) -> Result>, CompileError> { + let mut updated_binary_op = None; + + if let (Some(lhs_const), Some(rhs_const)) = (get_inner_const(&lhs), get_inner_const(&rhs)) { + let folded = match parent.borrow().deref() { + Op::Add(_) => lhs_const.checked_add(rhs_const), + Op::Sub(_) => lhs_const.checked_sub(rhs_const), + Op::Mul(_) => lhs_const.checked_mul(rhs_const), + Op::Exp(_) => { + let rhs_const = rhs_const.try_into().map_err(|_| CompileError::Failed)?; + lhs_const.checked_pow(rhs_const) + }, + _ => unreachable!("Unexpected parent operation: {:?}", parent), + }; + if let Some(folded) = folded { + let new_value = Value::create(SpannedMirValue { + value: MirValue::Constant(crate::ir::ConstantValue::Felt(folded)), + span, + }); + updated_binary_op = Some(new_value); + } + } + + Ok(updated_binary_op) } diff --git a/mir/src/passes/inlining.rs b/mir/src/passes/inlining.rs index 01c6f1e3e..9849222d5 100644 --- a/mir/src/passes/inlining.rs +++ b/mir/src/passes/inlining.rs @@ -3,33 +3,87 @@ use std::{collections::HashMap, ops::Deref}; use air_pass::Pass; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Spanned}; +use super::{duplicate_node_or_replace, visitor::Visitor}; use crate::{ CompileError, ir::{ - Accessor, Graph, Link, Mir, MirType, MirValue, Node, Op, Parameter, Parent, Root, + Graph, Link, Mir, MirAccessType, MirType, MirValue, Node, Op, Parameter, Parent, Root, SpannedMirValue, TraceAccessBinding, Value, Vector, }, + passes::constant_propagation::get_inner_const, }; -use super::{duplicate_node_or_replace, visitor::Visitor}; - -/// This pass handles inlining of Call nodes at there call sites. +/// This pass handles inlining of `Call` nodes at their call sites. /// /// It works in three steps: /// * Firstly, we visit the graph to build the call dependency graph. 
-/// * This dependency graph is then used to compute the wanted inlining order -/// (we first replace calls to callees that do not have Calls in their body). -/// If it is not possible create this order, this means there is a circular dependency. -/// * Then, we visit the graph again at each Call nodes, building a duplicate of the body -/// (with Parameter replaced by call arguments), and replacing the Call node by this duplicate body. +/// * This dependency graph is then used to compute the wanted inlining order (we first replace +/// calls to callees that do not have `Call` in their body). If it is not possible to create this +/// order, this means there is a circular dependency. +/// * Then, we visit the graph again at each `Call` nodes, building a duplicate of the body (with +/// Parameter replaced by call arguments), and replacing the `Call` node by this duplicate body. /// pub struct Inlining<'a> { diagnostics: &'a DiagnosticsHandler, } + impl<'a> Inlining<'a> { pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { Self { diagnostics } } + + /// Runs the Inlining pass once (with both InliningFirstPass and InliningSecondPass) + /// + /// Returns true if any calls were inlined, false otherwise, to let the caller know if any + /// changes were made or if we reached a fixed point. 
+ fn run_once(&mut self, ir: &mut Mir) -> Result { + // The first pass only identifies the call graph dependencies and the needed calls to inline + let mut first_pass = InliningFirstPass::new(self.diagnostics); + Visitor::run(&mut first_pass, ir.constraint_graph_mut())?; + + // We then create the inlining order (inlining first the functions and evaluators that do + // not call other functions or evaluators) + let func_eval_inlining_order = + create_inlining_order(self.diagnostics, first_pass.func_eval_dependency_graph.clone())?; + + // The second pass actually inlines the calls + let mut second_pass = InliningSecondPass::new( + self.diagnostics, + func_eval_inlining_order.clone(), + first_pass.func_eval_nodes_where_called.clone(), + ); + Visitor::run(&mut second_pass, ir.constraint_graph_mut())?; + + Ok(second_pass.had_calls) + } +} + +// If we have to run the inlining algorithm `INLINING_LIMIT`, we return an error +const INLINING_LIMIT: usize = 10; + +impl Pass for Inlining<'_> { + type Input<'a> = Mir; + type Output<'a> = Mir; + type Error = CompileError; + + fn run<'a>(&mut self, mut ir: Self::Input<'a>) -> Result, Self::Error> { + let mut had_calls = true; + let mut iterations = 0; + + while had_calls && iterations < INLINING_LIMIT { + had_calls = self.run_once(&mut ir)?; + iterations += 1; + } + + if had_calls { + self.diagnostics.error( + "Inlining call depth limit reached, some calls may not have been inlined. 
Aborting.".to_string(), + ); + return Err(CompileError::Failed); + } + + Ok(ir) + } } pub struct InliningFirstPass<'a> { @@ -39,13 +93,15 @@ pub struct InliningFirstPass<'a> { // general context work_stack: Vec>, in_func_or_eval: bool, - // When encountering a call, we store it here to construct the dependency graph once reaching the root node + // When encountering a call, we store it here to construct the dependency graph once reaching + // the root node current_callees_encountered: Vec>, // HashMap func_eval_dependency_graph: HashMap, Vec>)>, // HashMap> func_eval_nodes_where_called: HashMap, Vec>)>, // Op is a Call here } + impl<'a> InliningFirstPass<'a> { pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { Self { @@ -59,7 +115,8 @@ impl<'a> InliningFirstPass<'a> { } } -/// This structure is used to keep track of what is needed to inline a call to a given function or evaluator. +/// This structure is used to keep track of what is needed to inline a call to a given function or +/// evaluator. 
#[derive(Clone, Debug)] pub struct CallInliningContext { body: Link>>, @@ -67,7 +124,6 @@ pub struct CallInliningContext { pure_function: bool, ref_node: Link, } -impl CallInliningContext {} pub struct InliningSecondPass<'a> { diagnostics: &'a DiagnosticsHandler, @@ -84,7 +140,10 @@ pub struct InliningSecondPass<'a> { // HashMap)> func_eval_nodes_where_called: HashMap, Vec>)>, // Op is a Call here + had_calls: bool, + seen_root_call: bool, } + impl<'a> InliningSecondPass<'a> { pub fn new( diagnostics: &'a DiagnosticsHandler, @@ -99,38 +158,12 @@ impl<'a> InliningSecondPass<'a> { params_for_ref_node: HashMap::new(), func_eval_nodes_where_called, func_eval_inlining_order, + had_calls: false, + seen_root_call: false, } } } -impl Pass for Inlining<'_> { - type Input<'a> = Mir; - type Output<'a> = Mir; - type Error = CompileError; - - fn run<'a>(&mut self, mut ir: Self::Input<'a>) -> Result, Self::Error> { - // The first pass only identifies the call graph dependencies and the needed calls to inline - let mut first_pass = InliningFirstPass::new(self.diagnostics); - Visitor::run(&mut first_pass, ir.constraint_graph_mut())?; - - // We then create the inlining order (inlining first the functions and evaluators that do not call other functions or evaluators) - let func_eval_inlining_order = create_inlining_order( - self.diagnostics, - first_pass.func_eval_dependency_graph.clone(), - )?; - - // The second pass actually inlines the calls - let mut second_pass = InliningSecondPass::new( - self.diagnostics, - func_eval_inlining_order.clone(), - first_pass.func_eval_nodes_where_called.clone(), - ); - Visitor::run(&mut second_pass, ir.constraint_graph_mut())?; - - Ok(ir) - } -} - /// Helper function to create the inlining order depending on the dependency graph /// /// Raises an error if a circular dependency is detected @@ -143,33 +176,27 @@ fn create_inlining_order( // Note: we remove an element at each iteration (or raise diag), so this will terminate while 
!func_eval_dependency_graph.is_empty() { // Find a function without dependency - match func_eval_dependency_graph - .clone() - .iter() - .find(|(_, (_k, v))| v.is_empty()) - { + match func_eval_dependency_graph.clone().iter().find(|(_, (_k, v))| v.is_empty()) { Some((f_ptr, (f, _))) => { func_eval_inlining_order.push(f.clone()); // Remove the entry of dependency graph corresponding to the next function to inline func_eval_dependency_graph.remove(f_ptr); - } + }, _ => { diagnostics .diagnostic(Severity::Error) .with_message("Circular dependency detected") .emit(); return Err(CompileError::Failed); - } + }, } let removed_fn = func_eval_inlining_order.last().unwrap(); // Remove the function from the list of dependencies of all other functions - func_eval_dependency_graph - .iter_mut() - .for_each(|(_, (_, v))| { - v.retain(|x| x != removed_fn); - }); + func_eval_dependency_graph.iter_mut().for_each(|(_, (_, v))| { + v.retain(|x| x != removed_fn); + }); } Ok(func_eval_inlining_order) } @@ -214,29 +241,23 @@ impl Visitor for InliningFirstPass<'_> { .clone() .into_iter() .map(|bc| bc.as_node()) - .chain( - integrity_constraints_roots_ref - .clone() - .into_iter() - .map(|ic| ic.as_node()), - ) + .chain(integrity_constraints_roots_ref.clone().into_iter().map(|ic| ic.as_node())) .chain(bus_roots.into_iter().map(|b| b.as_node())) .chain(evaluators.into_iter().map(|e| e.as_node())) .chain(functions.into_iter().map(|f| f.as_node())); combined_roots.collect() } - // When visiting a function or an evaluator, we have just finished visiting their bodies so we can update the dependency graph for this root - // We then clear the current_callees_encountered vec to prepare to visit the next function or evaluator + // When visiting a function or an evaluator, we have just finished visiting their bodies so we + // can update the dependency graph for this root We then clear the + // current_callees_encountered vec to prepare to visit the next function or evaluator fn visit_function( 
&mut self, _graph: &mut Graph, function: Link, ) -> Result<(), CompileError> { - self.func_eval_dependency_graph.insert( - function.get_ptr(), - (function, self.current_callees_encountered.clone()), - ); + self.func_eval_dependency_graph + .insert(function.get_ptr(), (function, self.current_callees_encountered.clone())); self.current_callees_encountered.clear(); Ok(()) } @@ -254,7 +275,8 @@ impl Visitor for InliningFirstPass<'_> { } // When visiting a call, we: - // - add it to the current list of callees if we're currently visiting the bodies of functions or evaluators (to build the dependency graph) + // - add it to the current list of callees if we're currently visiting the bodies of functions + // or evaluators (to build the dependency graph) // - add it to the list of calls to inline with the func_eval_nodes_where_called map fn visit_call(&mut self, _graph: &mut Graph, call: Link) -> Result<(), CompileError> { // safe to unwrap because we just dispatched on it @@ -292,7 +314,10 @@ impl Visitor for InliningSecondPass<'_> { call_nodes_to_inline_in_order } fn run(&mut self, graph: &mut Graph) -> Result<(), CompileError> { - for root_node in self.root_nodes_to_visit(graph).iter() { + let root_nodes_to_visit = self.root_nodes_to_visit(graph); + self.had_calls = !root_nodes_to_visit.is_empty(); + + for root_node in root_nodes_to_visit.iter() { let mut updated_op = None; if let Some(op) = root_node.as_op() { @@ -323,42 +348,52 @@ impl Visitor for InliningSecondPass<'_> { self.call_inlining_context = Some(context.clone()); self.nodes_to_replace.clear(); self.params_for_ref_node.clear(); + self.seen_root_call = false; self.scan_node(graph, root_node.clone())?; - while let Some(node) = self.work_stack().pop() { self.visit_node(graph, node.clone())?; } if context.pure_function { - // We have finished inlining the body, we can now replace the Call node with the last expression of the body + // We have finished inlining the body, we can now replace the Call node with the + 
// last expression of the body let last_child_of_body = context.body.borrow().last().unwrap().clone(); - let (_, new_node) = self - .nodes_to_replace - .get(&last_child_of_body.get_ptr()) - .unwrap() - .clone(); + let (_, new_node) = + self.nodes_to_replace.get(&last_child_of_body.get_ptr()).unwrap().clone(); updated_op = Some(new_node); } else { - // We have finished inlining the body, we can now replace the Call node with all the body + // We have finished inlining the body, we can now replace the Call node with all + // the body let mut new_nodes = Vec::new(); for body_node in context.body.borrow().iter() { // FIXME: Maybe we should only push nodes that are Enf()? // Depends if additional nodes change things (e.g. the Vector size..) // For now I think we can keep all nodes, and just ignore the non-Enf nodes // When building the constraints during lowering Mir -> Air - new_nodes.push( - self.nodes_to_replace - .get(&body_node.get_ptr()) + + let new_node = + self.nodes_to_replace.get(&body_node.get_ptr()).unwrap().1.clone(); + + if let Some(bus_op) = new_node.clone().as_bus_op() { + let latch = bus_op.latch.clone(); + + bus_op + .bus + .to_link() .unwrap() - .1 - .clone(), - ); + .borrow_mut() + .columns + .push(new_node.clone()); + bus_op.bus.to_link().unwrap().borrow_mut().latches.push(latch.clone()); + } + + if new_node.clone().as_enf().is_some() { + new_nodes.push(new_node); + } } - let span = new_nodes - .iter() - .map(|n| n.span()) - .fold(SourceSpan::UNKNOWN, |acc, s| { + let span = + new_nodes.iter().map(|n| n.span()).fold(SourceSpan::UNKNOWN, |acc, s| { acc.merge(s).unwrap_or(SourceSpan::UNKNOWN) }); let new_nodes_vector = Vector::create(new_nodes, span); @@ -370,8 +405,9 @@ impl Visitor for InliningSecondPass<'_> { self.call_inlining_context = None; } - // Effectively replace the Call node with the updated op - // Note: We also update the references of Parameters that referenced the node we are replacing + // Effectively replace the `Call` node with the 
updated op + // Note: We also update the references of Parameters that referenced the node we are + // replacing if let Some(updated_op) = updated_op { // let prev_owner_ptr = updated_op.as_owner().unwrap().get_ptr(); @@ -380,30 +416,22 @@ impl Visitor for InliningSecondPass<'_> { root_node.as_op().unwrap().set(&updated_op); if let Some(params) = params { - let new_owner = root_node - .clone() - .as_op() - .unwrap() - .clone() - .as_owner() - .unwrap(); + let new_owner = root_node.clone().as_op().unwrap().clone().as_owner().unwrap(); for param in params.iter() { - param - .as_parameter_mut() - .unwrap() - .set_ref_node(new_owner.clone()); + param.as_parameter_mut().unwrap().set_ref_node(new_owner.clone()); } } } } Ok(()) } + fn scan_node(&mut self, _graph: &Graph, node: Link) -> Result<(), CompileError> { self.work_stack().push(node.clone()); if let Some(op) = node.clone().as_op() { - // If we scan a Call node, we do not visit its children (the call's arguments) + // If we scan a `Call` node, we do not visit its children (the call's arguments) // TODO INLINING: Check whether this is the wanted behavior - if op.as_call().is_some() { + if op.as_call().is_some() && !self.seen_root_call { return Ok(()); }; for child in node.children().borrow().iter() { @@ -414,16 +442,14 @@ impl Visitor for InliningSecondPass<'_> { } fn visit_call(&mut self, _graph: &mut Graph, _call: Link) -> Result<(), CompileError> { - let Some(context) = self.call_inlining_context.clone() else { - unreachable!("InliningSecondPass::visit_node: call_inlining_context is None"); - }; + let context = self + .call_inlining_context + .clone() + .expect("InliningSecondPass::visit_node: call_inlining_context is None"); if context.pure_function { // Instead of scanning all the body, we only scan the last node, // which represents the return value of the function - self.scan_node( - _graph, - context.body.borrow().last().unwrap().clone().as_node(), - )?; + self.scan_node(_graph, 
context.body.borrow().last().unwrap().clone().as_node())?; } else { // We scan all the nodes related to the body for body_node in context.body.borrow().iter() { @@ -443,63 +469,61 @@ impl Visitor for InliningSecondPass<'_> { panic!("InliningSecondPass::visit_node on a non-Op node: {node:?}") }); - // First, check if it's a known Call to inline, + // First, check if it's a known `Call` to inline, // if so, set the context and scan its body - if call_op.clone().as_call().is_some() { + if call_op.clone().as_call().is_some() && !self.seen_root_call { + self.seen_root_call = true; self.visit_call(graph, call_op.clone())?; + // NOTE: Once the call has been processed, we return early to avoid + // an infinite loop. + // If there are nested calls, they will be handled via fixed-point + // compilation (see [ as Pass>::run]). + return Ok(()); + } + + // Else, we are currently visiting the body of a function or an evaluator of a call + // we want to inline. We use our helper duplicate_node_or_replace to + // duplicate the body, while replacing the `Function` or `Evaluator` parameters with + // the `Call` arguments + if self.call_inlining_context.clone().unwrap().pure_function { + duplicate_node_or_replace( + &mut self.nodes_to_replace, + call_op.clone(), + self.call_inlining_context.clone().unwrap().arguments.borrow().clone(), + self.call_inlining_context.clone().unwrap().ref_node, + None, + &mut self.params_for_ref_node, + ); } else { - // Else, we are currently visiting the body of a function or an evaluator of a call we want to inline - // We use our helper duplicate_node_or_replace to duplicate the body, while replacing the Function or Evaluator parameters with the Call arguments - if self.call_inlining_context.clone().unwrap().pure_function { - duplicate_node_or_replace( - &mut self.nodes_to_replace, - call_op.clone(), - self.call_inlining_context - .clone() - .unwrap() - .arguments - .borrow() - .clone(), - self.call_inlining_context.clone().unwrap().ref_node, - None, - 
&mut self.params_for_ref_node, - ); - } else { - // If we're inside the body of an evaluator, we first need to unpack the arguments of the call to have a Vector of Trace columns, and not - // bindings to multiple columns - let args = self - .call_inlining_context - .clone() - .unwrap() - .arguments - .borrow() - .clone(); + // If we're inside the body of an evaluator, we first need to unpack the + // arguments of the call to have a `Vector` of `TraceColumn`, and not + // bindings to multiple columns + let args = self.call_inlining_context.clone().unwrap().arguments.borrow().clone(); - let callee_params = self - .call_inlining_context - .clone() - .unwrap() - .ref_node - .as_root() - .unwrap() - .as_evaluator() - .unwrap() - .parameters - .clone(); - - check_evaluator_argument_sizes(&args, callee_params, self.diagnostics)?; - - let args_unpacked = unpack_evaluator_arguments(&args); - - duplicate_node_or_replace( - &mut self.nodes_to_replace, - call_op.clone(), - args_unpacked, - self.call_inlining_context.clone().unwrap().ref_node, - None, - &mut self.params_for_ref_node, - ); - } + let callee_params = self + .call_inlining_context + .clone() + .unwrap() + .ref_node + .as_root() + .unwrap() + .as_evaluator() + .unwrap() + .parameters + .clone(); + + check_evaluator_argument_sizes(&args, callee_params, self.diagnostics)?; + + let args_unpacked = unpack_evaluator_arguments(&args); + + duplicate_node_or_replace( + &mut self.nodes_to_replace, + call_op.clone(), + args_unpacked, + self.call_inlining_context.clone().unwrap().ref_node, + None, + &mut self.params_for_ref_node, + ); } } @@ -507,6 +531,52 @@ impl Visitor for InliningSecondPass<'_> { } } +/// Helper function to recursively extract the outermost accessors +/// if the op is an accessor, otherwise return the op unchanged +fn extract_accessor(op: Link) -> Link { + let Some(accessor) = op.as_accessor() else { + return op; + }; + let mut indexable = accessor.indexable.clone(); + match &accessor.access_type { + 
MirAccessType::Default => { + // Recursively extract accessors from the indexable + extract_accessor(accessor.indexable.clone()) + }, + MirAccessType::Index(idx_op) => { + let idx = get_inner_const(idx_op) + .unwrap_or_else(|| panic!("expected constant index, got {:#?}", idx_op)) + as usize; + while indexable.clone().as_accessor().is_some() { + indexable = extract_accessor(indexable.clone()); + } + match indexable.borrow().deref() { + Op::Vector(v) => v.children().borrow()[idx].clone(), + Op::Matrix(m) => m.children().borrow()[idx].clone(), + _ => unreachable!("expected vector or matrix, got {:#?}", indexable), + } + }, + MirAccessType::Matrix(row, col) => { + let row = get_inner_const(row) + .unwrap_or_else(|| panic!("expected constant row, got {:#?}", row)) + as usize; + let col = get_inner_const(col) + .unwrap_or_else(|| panic!("expected constant column, got {:#?}", col)) + as usize; + while indexable.clone().as_accessor().is_some() { + indexable = extract_accessor(indexable.clone()); + } + match indexable.borrow().deref() { + Op::Matrix(m) => { + m.children().borrow()[row].clone().as_matrix().unwrap().children().borrow()[col] + .clone() + }, + _ => unreachable!("expected matrix, got {:#?}", indexable), + } + }, + } +} + /// Helper function to check, for each trace segment, that the total size of arguments is correct fn check_evaluator_argument_sizes( args: &[Link], @@ -516,63 +586,11 @@ fn check_evaluator_argument_sizes( for ((trace_segment_id, trace_segments_params), trace_segments_arg) in callee_params.iter().enumerate().zip(args.iter()) { - let Some(trace_segments_arg_vector) = trace_segments_arg.as_vector() else { - unreachable!("expected vector, got {:?}", trace_segments_arg); - }; + let trace_segments_arg_vector = trace_segments_arg + .as_vector() + .unwrap_or_else(|| panic!("expected vector, got {trace_segments_arg:?}")); let children = trace_segments_arg_vector.children(); - let mut trace_segments_arg_vector_len = 0; - for child in 
children.borrow().deref() { - if let Some(value) = child.as_value() { - let Value { - value: SpannedMirValue { value, .. }, - .. - } = value.deref(); - - let param_size = match value { - MirValue::TraceAccessBinding(tab) => tab.size, - MirValue::TraceAccess(_) => 1, - _ => unreachable!("expected trace access binding, got {:?}", value), - }; - trace_segments_arg_vector_len += param_size; - } else if let Some(parameter) = child.as_parameter() { - let Parameter { ty, .. } = parameter.deref(); - let size = match ty { - MirType::Felt => 1, - MirType::Vector(len) => *len, - _ => unreachable!("expected felt or vector, got {:?}", ty), - }; - trace_segments_arg_vector_len += size; - } else if let Some(accessor) = child.as_accessor() { - let Accessor { indexable, .. } = accessor.deref(); - - if let Some(value) = indexable.as_value() { - let Value { - value: SpannedMirValue { value, .. }, - .. - } = value.deref(); - - let param_size = match value { - MirValue::TraceAccessBinding(tab) => tab.size, - MirValue::TraceAccess(_) => 1, - _ => unreachable!("expected trace access binding, got {:?}", value), - }; - trace_segments_arg_vector_len += param_size; - } else if let Some(parameter) = indexable.as_parameter() { - let Parameter { ty, .. 
} = parameter.deref(); - let size = match ty { - MirType::Felt => 1, - MirType::Vector(len) => *len, - _ => unreachable!("expected felt or vector, got {:?}", ty), - }; - trace_segments_arg_vector_len += size; - } else { - unreachable!("expected value or parameter, got {:?}", child); - } - } else { - unreachable!("expected value or parameter, got {:?}", child); - } - } - + let (trace_segments_arg_vector_len, _) = process_evaluator_arg_children(children.clone()); if trace_segments_params.len() != trace_segments_arg_vector_len { diagnostics .diagnostic(Severity::Error) @@ -605,85 +623,102 @@ fn check_evaluator_argument_sizes( fn unpack_evaluator_arguments(args: &[Link]) -> Vec> { let mut args_unpacked = Vec::new(); for args_for_trace_segment in args.iter() { - let Some(trace_segment_vec) = args_for_trace_segment.as_vector() else { - unreachable!( - "Arguments of a Call node to Evaluator should be a Vectors for each trace segment" - ); - }; + let trace_segment_vec = args_for_trace_segment.as_vector().expect( + "Arguments of a Call node to Evaluator should be a Vector for each trace segment", + ); let children = trace_segment_vec.children(); - for arg in children.borrow().deref() { - if let Some(value) = arg.as_value() { - let Value { - value: SpannedMirValue { span, value, .. }, - .. 
- } = value.deref(); - - match value { - MirValue::TraceAccessBinding(tab) => { - if tab.size > 1 { - for index in 0..tab.size { - let new_arg = Value::create(SpannedMirValue { - value: MirValue::TraceAccessBinding(TraceAccessBinding { - size: 1, - segment: tab.segment, - offset: tab.offset + index, - }), - span: *span, - }); - args_unpacked.push(new_arg); - } - } else { - args_unpacked.push(arg.clone()); + let (_, unpacked_children) = process_evaluator_arg_children(children.clone()); + args_unpacked.extend(unpacked_children); + } + args_unpacked +} + +fn process_evaluator_arg_children(children: Link>>) -> (usize, Vec>) { + let mut trace_segments_arg_vector_len = 0; + let mut args_unpacked = Vec::new(); + for child in children.borrow().deref() { + // We need to extract the outermost accessors, if any. + // This can happen when slices are used to rebind + // trace bindings in nested evaluator calls + let child = extract_accessor(child.clone()); + if let Some(value) = child.as_value() { + let Value { + value: SpannedMirValue { value, span }, .. + } = value.deref(); + + let param_size = match value { + MirValue::TraceAccessBinding(tab) => { + if tab.size > 1 { + for index in 0..tab.size { + let new_arg = Value::create(SpannedMirValue { + value: MirValue::TraceAccessBinding(TraceAccessBinding { + size: 1, + segment: tab.segment, + offset: tab.offset + index, + }), + span: *span, + }); + args_unpacked.push(new_arg); } + } else { + args_unpacked.push(child.clone()); } - MirValue::TraceAccess(_ta) => { - args_unpacked.push(arg.clone()); - } - _ => unreachable!( - "expected trace access binding or trace access, got {:?}", - value - ), - }; - } else if let Some(_parameter) = arg.as_parameter() { - args_unpacked.push(arg.clone()); - } else if let Some(accessor) = arg.as_accessor() { - let Accessor { indexable, .. } = accessor.deref(); - - if let Some(value) = indexable.as_value() { - let Value { - value: SpannedMirValue { value, .. 
}, + tab.size + }, + MirValue::TraceAccess(_ta) => { + args_unpacked.push(child.clone()); + 1 + }, + _ => unreachable!("expected trace access binding or trace access, got {:?}", value), + }; + trace_segments_arg_vector_len += param_size; + } else if let Some(parameter) = child.as_parameter() { + args_unpacked.push(child.clone()); + let Parameter { ty, .. } = parameter.deref(); + let size = match ty { + MirType::Felt => 1, + MirType::Vector(len) => *len, + _ => unreachable!("expected felt or vector, got {:?}", ty), + }; + trace_segments_arg_vector_len += size; + } else if let Some(vector) = child.as_vector() { + // When slices are used to rebind trace bindings in nested evaluator calls, + // translation from Ast to Mir inserts a Vector of + // Accessor(MirAccessType::Index), representing the unwrapped + // slice. + // We need to unpack all the elements of this vector and add them + // to the arguments + // We also need to make sure that this vector is a vector of Felts + // and to add its length to the total length of arguments + for c in vector.children().borrow().iter() { + let c = extract_accessor(c.clone()); + if let Some(param) = c.as_parameter() { + assert!( + param.ty == MirType::Felt, + "expected parameter of type Felt, got {:#?}", + param + ); + args_unpacked.push(c.clone()); + } else if let Some(value) = c.as_value() { + let SpannedMirValue { + value: MirValue::Constant(_) | MirValue::TraceAccess(_), .. - } = value.deref(); - - let _param_size = match value { - MirValue::TraceAccessBinding(tab) => tab.size, - MirValue::TraceAccess(_) => 1, - _ => unreachable!("expected trace access binding, got {:?}", value), + } = value.value + else { + unreachable!("expected constant or trace access, got {:?}", value); }; - - args_unpacked.push(indexable.clone()); - } else if let Some(parameter) = indexable.as_parameter() { - let Parameter { ty, .. 
} = parameter.deref(); - let _size = match ty { - MirType::Felt => 1, - MirType::Vector(len) => *len, - _ => unreachable!("expected felt or vector, got {:?}", ty), - }; - - args_unpacked.push(indexable.clone()); + args_unpacked.push(c.clone()); } else { - unreachable!( - "expected value or parameter (or accessor on one), got {:?}", - arg - ); + unreachable!("expected value or parameter (or accessor on one), got {:?}", c); } - } else { - unreachable!( - "expected value or parameter (or accessor on one), got {:?}", - arg - ); } + let vector_len = vector.children().borrow().len(); + trace_segments_arg_vector_len += vector_len; + } else if let Some(_accessor) = child.as_accessor() { + unreachable!("accessors should have been extracted already"); + } else { + unreachable!("expected value, parameter, or vector of them, got {:#?}", child); } } - args_unpacked + (trace_segments_arg_vector_len, args_unpacked) } diff --git a/mir/src/passes/mod.rs b/mir/src/passes/mod.rs index dcef4eeed..789a74017 100644 --- a/mir/src/passes/mod.rs +++ b/mir/src/passes/mod.rs @@ -1,32 +1,33 @@ +mod constant_propagation; mod inlining; mod translate; mod unrolling; mod visitor; +use std::{collections::HashMap, ops::Deref}; + +pub use constant_propagation::ConstantPropagation; pub use inlining::Inlining; +use miden_diagnostics::{DiagnosticsHandler, Spanned}; pub use translate::AstToMir; pub use unrolling::Unrolling; pub use visitor::Visitor; -// Note: ConstantPropagation and ValueNumbering are not implemented yet in the MIR -//mod constant_propagation; -//mod value_numbering; -//pub use constant_propagation::ConstantPropagation; -//pub use value_numbering::ValueNumbering; - -use std::collections::HashMap; -use std::ops::Deref; -use miden_diagnostics::Spanned; - -use crate::ir::{ - Accessor, Add, Boundary, BusOp, Call, Enf, Exp, Fold, For, If, Link, Matrix, Mul, Node, Op, - Owner, Parameter, Parent, Sub, Value, Vector, +use crate::{ + CompileError, + ir::{ + Accessor, Add, Boundary, BusOp, Call, 
Enf, Exp, Fold, For, If, Link, MatchArm, Matrix, + MirAccessType, MirValue, Mul, Node, Op, Owner, Parameter, Parent, PublicInputAccess, + SpannedMirValue, Sub, TraceAccess, Value, Vector, + }, }; /// Helper to duplicate a MIR node and its children recursively -/// It should be used when we want to reference the same node multiple times in the MIR graph (e.g. referencing let bound variables) +/// It should be used when we want to reference the same node multiple times in the MIR graph (e.g. +/// referencing let bound variables) /// -/// Note: the current_replace_map is only used to keep track of For nodes, that can be referenced by Parameters inside their bodies -/// Then, duplicated Parameters should reference the new For node, not the original one +/// Note: the current_replace_map is only used to keep track of `For` nodes, that can be referenced +/// by `Parameters` inside their bodies Then, duplicated Parameters should reference the new `For` +/// node, not the original one pub fn duplicate_node( node: Link, current_replace_map: &mut HashMap, Link)>, @@ -36,62 +37,58 @@ pub fn duplicate_node( let expr = enf.expr.clone(); let new_expr = duplicate_node(expr, current_replace_map); Enf::create(new_expr, enf.span()) - } + }, Op::Boundary(boundary) => { let expr = boundary.expr.clone(); let kind = boundary.kind; let new_expr = duplicate_node(expr, current_replace_map); Boundary::create(new_expr, kind, boundary.span()) - } + }, Op::Add(add) => { let lhs = add.lhs.clone(); let rhs = add.rhs.clone(); let new_lhs_node = duplicate_node(lhs, current_replace_map); let new_rhs_node = duplicate_node(rhs, current_replace_map); Add::create(new_lhs_node, new_rhs_node, add.span()) - } + }, Op::Sub(sub) => { let lhs = sub.lhs.clone(); let rhs = sub.rhs.clone(); let new_lhs_node = duplicate_node(lhs, current_replace_map); let new_rhs_node = duplicate_node(rhs, current_replace_map); Sub::create(new_lhs_node, new_rhs_node, sub.span()) - } + }, Op::Mul(mul) => { let lhs = 
mul.lhs.clone(); let rhs = mul.rhs.clone(); let new_lhs_node = duplicate_node(lhs, current_replace_map); let new_rhs_node = duplicate_node(rhs, current_replace_map); Mul::create(new_lhs_node, new_rhs_node, mul.span()) - } + }, Op::Exp(exp) => { let lhs = exp.lhs.clone(); let rhs = exp.rhs.clone(); let new_lhs_node = duplicate_node(lhs, current_replace_map); let new_rhs_node = duplicate_node(rhs, current_replace_map); Exp::create(new_lhs_node, new_rhs_node, exp.span()) - } + }, Op::If(if_node) => { - let condition = if_node.condition.clone(); - let then_branch = if_node.then_branch.clone(); - let else_branch = if_node.else_branch.clone(); - let new_condition = duplicate_node(condition, current_replace_map); - let new_then_branch = duplicate_node(then_branch, current_replace_map); - let new_else_branch = duplicate_node(else_branch, current_replace_map); - If::create( - new_condition, - new_then_branch, - new_else_branch, - if_node.span(), - ) - } + let match_arms = if_node.match_arms.clone(); + let new_match_arms = match_arms + .borrow() + .iter() + .cloned() + .map(|arm| { + let new_expr = duplicate_node(arm.expr, current_replace_map); + let new_cond = duplicate_node(arm.condition, current_replace_map); + MatchArm::new(new_expr, new_cond) + }) + .collect::>(); + If::create(new_match_arms, if_node.span()) + }, Op::For(for_node) => { - let new_for_node: Link = For::create( - Link::default(), - Link::default(), - Link::default(), - for_node.span(), - ); + let new_for_node: Link = + For::create(Link::default(), Link::default(), Link::default(), for_node.span()); current_replace_map.insert(node.get_ptr(), (node, new_for_node.clone())); let iterators = for_node.iterators.clone(); @@ -112,7 +109,7 @@ pub fn duplicate_node( *new_for_node.as_for_mut().unwrap().expr.borrow_mut() = new_body.borrow().clone(); new_for_node - } + }, Op::Call(call) => { let arguments = call.arguments.clone(); let function = call.function.clone(); @@ -123,7 +120,7 @@ pub fn duplicate_node( .map(|x| 
duplicate_node(x, current_replace_map)) .collect::>(); Call::create(function, new_arguments, call.span()) - } + }, Op::Fold(fold) => { let iterator = fold.iterator.clone(); let operator = fold.operator.clone(); @@ -131,7 +128,7 @@ pub fn duplicate_node( let new_iterator = duplicate_node(iterator, current_replace_map); let new_initial_value = duplicate_node(initial_value, current_replace_map); Fold::create(new_iterator, operator, new_initial_value, fold.span()) - } + }, Op::Vector(vector) => { let children_link = vector.children().clone(); let children_ref = children_link.borrow(); @@ -142,7 +139,7 @@ pub fn duplicate_node( .map(|x| duplicate_node(x, current_replace_map)) .collect(); Vector::create(new_children, vector.span()) - } + }, Op::Matrix(matrix) => { let mut new_matrix = Vec::new(); let children_link = matrix.children().clone(); @@ -166,14 +163,23 @@ pub fn duplicate_node( new_matrix.push(new_row); } Matrix::create(new_matrix, matrix.span()) - } + }, Op::Accessor(accessor) => { let indexable = accessor.indexable.clone(); - let access_type = accessor.access_type.clone(); + let new_access_type = match accessor.access_type.clone() { + MirAccessType::Default => MirAccessType::Default, + MirAccessType::Index(index) => { + MirAccessType::Index(duplicate_node(index, current_replace_map)) + }, + MirAccessType::Matrix(row, col) => MirAccessType::Matrix( + duplicate_node(row, current_replace_map), + duplicate_node(col, current_replace_map), + ), + }; let offset = accessor.offset; let new_indexable = duplicate_node(indexable, current_replace_map); - Accessor::create(new_indexable, access_type, offset, accessor.span()) - } + Accessor::create(new_indexable, new_access_type, offset, accessor.span()) + }, Op::BusOp(bus_op) => { let bus = bus_op.bus.clone(); let kind = bus_op.kind; @@ -183,7 +189,7 @@ pub fn duplicate_node( .map(|x| duplicate_node(x.clone(), current_replace_map)) .collect(); BusOp::create(bus, kind, args, bus_op.span()) - } + }, Op::Parameter(parameter) => 
{ let owner_ref = parameter .ref_node @@ -193,40 +199,38 @@ pub fn duplicate_node( Parameter::create(parameter.position, parameter.ty.clone(), parameter.span()); if let Some(_root_ref) = owner_ref.as_root() { - new_param - .as_parameter_mut() - .unwrap() - .set_ref_node(owner_ref); - } else if let Some((_replaced_node, replaced_by)) = - current_replace_map.get(&owner_ref.as_op().unwrap().get_ptr()) + new_param.as_parameter_mut().unwrap().set_ref_node(owner_ref); + } else if let Some(op_ref) = owner_ref.as_op() + && let Some((_replaced_node, replaced_by)) = + current_replace_map.get(&op_ref.get_ptr()) { new_param .as_parameter_mut() .unwrap() .set_ref_node(replaced_by.clone().as_owner().unwrap()); } else { - new_param - .as_parameter_mut() - .unwrap() - .set_ref_node(owner_ref); + new_param.as_parameter_mut().unwrap().set_ref_node(owner_ref); } new_param - } + }, Op::Value(value) => Value::create(value.value.clone()), Op::None(span) => Op::None(*span).into(), } } -/// Helper used to duplicate nodes and their children recursively, used during Inlining and Unrolling -/// Additionally, if a Leaf is a Parameter that references the given ref_owner (if set to Some()) or ref_node, -/// it is replaced with the corresponding item of the replace_parameter_list Vec. +/// Helper used to duplicate nodes and their children recursively, used during Inlining and +/// Unrolling Additionally, if a Leaf is a `Parameter` that references the given ref_owner (if set +/// to Some()) or ref_node, it is replaced with the corresponding item of the replace_parameter_list +/// Vec. /// -/// This is useful for inlining function calls (and replacing their parameters with the arguments of the call) for Inlining, -/// and for Unrolling loops (and replacing their parameters with the iterator values) for Unrolling. 
-/// Inlining: replace_parameter_list = arguments should be the arguments from the Call() -/// Unrolling: replace_parameter_list = self.for_inlining_context.unwrap().iterators +/// This is useful for inlining function calls (and replacing their parameters with the arguments of +/// the call) for Inlining, and for Unrolling loops (and replacing their parameters with the +/// iterator values) for Unrolling. Inlining: replace_parameter_list = arguments should be the +/// arguments from the `Call` Unrolling: replace_parameter_list = +/// self.for_inlining_context.unwrap().iterators /// -/// Note: The params_for_ref_node parameters is used to keep track of Parameters, and update their ref_node as needed. +/// Note: The params_for_ref_node parameters is used to keep track of `Parameters`, and update their +/// ref_node as needed. pub fn duplicate_node_or_replace( current_replace_map: &mut HashMap, Link)>, node: Link, @@ -240,67 +244,64 @@ pub fn duplicate_node_or_replace( match node.borrow().deref() { Op::Enf(enf) => { let expr = enf.expr.clone(); - let new_expr = current_replace_map.get(&expr.get_ptr()).unwrap().1.clone(); + let new_expr = current_replace_map[&expr.get_ptr()].1.clone(); let new_node = Enf::create(new_expr, enf.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::Boundary(boundary) => { let expr = boundary.expr.clone(); let kind = boundary.kind; - let new_expr = current_replace_map.get(&expr.get_ptr()).unwrap().1.clone(); + let new_expr = current_replace_map[&expr.get_ptr()].1.clone(); let new_node = Boundary::create(new_expr, kind, boundary.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::Add(add) => { let lhs = add.lhs.clone(); let rhs = add.rhs.clone(); - let new_lhs_node = current_replace_map.get(&lhs.get_ptr()).unwrap().1.clone(); - let new_rhs_node = current_replace_map.get(&rhs.get_ptr()).unwrap().1.clone(); + let new_lhs_node = 
current_replace_map[&lhs.get_ptr()].1.clone(); + let new_rhs_node = current_replace_map[&rhs.get_ptr()].1.clone(); let new_node = Add::create(new_lhs_node, new_rhs_node, add.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::Sub(sub) => { let lhs = sub.lhs.clone(); let rhs = sub.rhs.clone(); - let new_lhs_node = current_replace_map.get(&lhs.get_ptr()).unwrap().1.clone(); - let new_rhs_node = current_replace_map.get(&rhs.get_ptr()).unwrap().1.clone(); + let new_lhs_node = current_replace_map[&lhs.get_ptr()].1.clone(); + let new_rhs_node = current_replace_map[&rhs.get_ptr()].1.clone(); let new_node = Sub::create(new_lhs_node, new_rhs_node, sub.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::Mul(mul) => { let lhs = mul.lhs.clone(); let rhs = mul.rhs.clone(); - let new_lhs_node = current_replace_map.get(&lhs.get_ptr()).unwrap().1.clone(); - let new_rhs_node = current_replace_map.get(&rhs.get_ptr()).unwrap().1.clone(); + let new_lhs_node = current_replace_map[&lhs.get_ptr()].1.clone(); + let new_rhs_node = current_replace_map[&rhs.get_ptr()].1.clone(); let new_node = Mul::create(new_lhs_node, new_rhs_node, mul.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::Exp(exp) => { let lhs = exp.lhs.clone(); let rhs = exp.rhs.clone(); - let new_lhs_node = current_replace_map.get(&lhs.get_ptr()).unwrap().1.clone(); - let new_rhs_node = current_replace_map.get(&rhs.get_ptr()).unwrap().1.clone(); + let new_lhs_node = current_replace_map[&lhs.get_ptr()].1.clone(); + let new_rhs_node = current_replace_map[&rhs.get_ptr()].1.clone(); let new_node = Exp::create(new_lhs_node, new_rhs_node, exp.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::If(if_node) => { - let cond = if_node.condition.clone(); - let then_branch = if_node.then_branch.clone(); - let else_branch = if_node.else_branch.clone(); - let new_cond = 
current_replace_map.get(&cond.get_ptr()).unwrap().1.clone(); - let new_then_branch = current_replace_map - .get(&then_branch.get_ptr()) - .unwrap() - .1 - .clone(); - let new_else_branch = current_replace_map - .get(&else_branch.get_ptr()) - .unwrap() - .1 - .clone(); - let new_node = If::create(new_cond, new_then_branch, new_else_branch, if_node.span()); + let match_arms = if_node.match_arms.clone(); + let new_match_arms = match_arms + .borrow() + .iter() + .cloned() + .map(|arm| { + let new_expr = current_replace_map[&arm.expr.get_ptr()].1.clone(); + let new_cond = current_replace_map[&arm.condition.get_ptr()].1.clone(); + MatchArm::new(new_expr, new_cond) + }) + .collect::>(); + let new_node = If::create(new_match_arms, if_node.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::For(for_node) => { let iterators = for_node.iterators.clone(); let body = for_node.expr.clone(); @@ -309,16 +310,10 @@ pub fn duplicate_node_or_replace( .borrow() .iter() .cloned() - .map(|iterator| { - current_replace_map - .get(&iterator.get_ptr()) - .unwrap() - .1 - .clone() - }) + .map(|iterator| current_replace_map[&iterator.get_ptr()].1.clone()) .collect::>() .into(); - let new_body = current_replace_map.get(&body.get_ptr()).unwrap().1.clone(); + let new_body = current_replace_map[&body.get_ptr()].1.clone(); let new_selector = current_replace_map .get(&selector.get_ptr()) .map(|selector| selector.1.clone()) @@ -330,10 +325,7 @@ pub fn duplicate_node_or_replace( if let Some(params) = params_for_ref_node.get(&prev_owner_ptr.unwrap()).cloned() { let new_owner = new_node.clone().as_owner().unwrap(); for param in params.iter() { - param - .as_parameter_mut() - .unwrap() - .set_ref_node(new_owner.clone()); + param.as_parameter_mut().unwrap().set_ref_node(new_owner.clone()); } params_for_ref_node @@ -341,7 +333,7 @@ pub fn duplicate_node_or_replace( .or_default() .extend(params.clone()); } - } + }, Op::Call(call) => { let arguments = 
call.arguments.clone(); let function = call.function.clone(); @@ -349,34 +341,20 @@ pub fn duplicate_node_or_replace( .borrow() .iter() .cloned() - .map(|argument| { - current_replace_map - .get(&argument.get_ptr()) - .unwrap() - .1 - .clone() - }) + .map(|argument| current_replace_map[&argument.get_ptr()].1.clone()) .collect::>(); let new_node = Call::create(function, new_arguments, call.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::Fold(fold) => { let iterator = fold.iterator.clone(); let operator = fold.operator.clone(); let initial_value = fold.initial_value.clone(); - let new_iterator = current_replace_map - .get(&iterator.get_ptr()) - .unwrap() - .1 - .clone(); - let new_initial_value = current_replace_map - .get(&initial_value.get_ptr()) - .unwrap() - .1 - .clone(); + let new_iterator = current_replace_map[&iterator.get_ptr()].1.clone(); + let new_initial_value = current_replace_map[&initial_value.get_ptr()].1.clone(); let new_node = Fold::create(new_iterator, operator, new_initial_value, fold.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::Vector(vector) => { let children_link = vector.children().clone(); let children_ref = children_link.borrow(); @@ -384,11 +362,11 @@ pub fn duplicate_node_or_replace( let new_children = children .iter() .cloned() - .map(|child| current_replace_map.get(&child.get_ptr()).unwrap().1.clone()) + .map(|child| current_replace_map[&child.get_ptr()].1.clone()) .collect(); let new_node = Vector::create(new_children, vector.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::Matrix(matrix) => { let mut new_matrix = Vec::new(); let children_link = matrix.children().clone(); @@ -406,33 +384,56 @@ pub fn duplicate_node_or_replace( let new_row_as_vec = row_children .iter() .cloned() - .map(|child| current_replace_map.get(&child.get_ptr()).unwrap().1.clone()) + .map(|child| 
current_replace_map[&child.get_ptr()].1.clone()) .collect::>(); let new_row = Vector::create(new_row_as_vec, row.span()); new_matrix.push(new_row); } let new_node = Matrix::create(new_matrix, matrix.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::Accessor(accessor) => { let indexable = accessor.indexable.clone(); - let access_type = accessor.access_type.clone(); + let new_access_type = match accessor.access_type.clone() { + MirAccessType::Default => MirAccessType::Default, + MirAccessType::Index(index) => { + MirAccessType::Index(current_replace_map[&index.get_ptr()].1.clone()) + }, + MirAccessType::Matrix(row, col) => MirAccessType::Matrix( + current_replace_map[&row.get_ptr()].1.clone(), + current_replace_map[&col.get_ptr()].1.clone(), + ), + }; let offset = accessor.offset; - let new_indexable = current_replace_map - .get(&indexable.get_ptr()) - .unwrap() - .1 - .clone(); - let new_node = Accessor::create(new_indexable, access_type, offset, accessor.span()); + let new_indexable = current_replace_map[&indexable.get_ptr()].1.clone(); + let new_node = + Accessor::create(new_indexable, new_access_type, offset, accessor.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::BusOp(bus_op) => { let bus = bus_op.bus.clone(); let kind = bus_op.kind; let args = bus_op.args.clone(); - let new_node = BusOp::create(bus, kind, args, bus_op.span()); + let latch = bus_op.latch.clone(); + + let new_args = args + .iter() + .cloned() + .map(|arg| current_replace_map[&arg.get_ptr()].1.clone()) + .collect(); + let new_latch = current_replace_map[&latch.get_ptr()].1.clone(); + let new_node = BusOp::create(bus.clone(), kind, new_args, bus_op.span()); + + // Update latch of cloned bus_op + new_node + .as_bus_op_mut() + .unwrap() + .latch + .borrow_mut() + .clone_from(&new_latch.borrow()); + current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); - } + }, Op::Parameter(parameter) => { let 
owner_ref = parameter .ref_node @@ -445,22 +446,17 @@ pub fn duplicate_node_or_replace( }; if owner_ref == ref_owner { - let new_node = replace_parameter_list[parameter.position].clone(); + let replace_by_node = replace_parameter_list[parameter.position].clone(); + let new_node = duplicate_node(replace_by_node, &mut Default::default()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); } else { let new_param = Parameter::create(parameter.position, parameter.ty.clone(), parameter.span()); if let Some(_root_ref) = owner_ref.as_root() { - new_param - .as_parameter_mut() - .unwrap() - .set_ref_node(owner_ref.clone()); + new_param.as_parameter_mut().unwrap().set_ref_node(owner_ref.clone()); } else { - new_param - .as_parameter_mut() - .unwrap() - .set_ref_node(owner_ref.clone()); + new_param.as_parameter_mut().unwrap().set_ref_node(owner_ref.clone()); params_for_ref_node .entry(owner_ref.get_ptr()) .or_default() @@ -469,11 +465,268 @@ pub fn duplicate_node_or_replace( current_replace_map.insert(node.get_ptr(), (node.clone(), new_param)); } - } + }, Op::Value(value) => { let new_node = Value::create(value.value.clone()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); + }, + Op::None(_) => {}, + } +} + +/// Helper function to extract the constant felt value from a Link if it is one. +pub fn get_inner_const(value: &Link) -> Option { + match value.borrow().deref() { + Op::Value(v) => v.get_inner_const(), + _ => None, + } +} + +/// Handle the visit of an accessor node, used for both Unrolling and ConstantPropagation passes +/// The `expect_constant_indices` bool indicates whether the indices need to be constant at +/// this stage. 
+pub fn handle_accessor_visit( + accessor: Link, + expect_constant_indices: bool, + diagnostics: &DiagnosticsHandler, +) -> Result>, CompileError> { + let accessor_ref = accessor.as_accessor().unwrap(); + let indexable = accessor_ref.indexable.clone(); + let mir_access_type = accessor_ref.access_type.clone(); + let offset = accessor_ref.offset; + + match mir_access_type { + // If we have a Default accessor, we add the row offset if needed, otherwise we just return + // the indexable + MirAccessType::Default => Ok(Some(add_row_offset_if_trace_access(&indexable, offset))), + // If we have an Index accessor, we compute the index and query the index-th element of the + // indexable. If the index is not a constant and computed_indices is true, we raise + // a diagnostic. If the index is not a constant and computed_indices is false, we + // keep the node as is. If the index is an out-of-bound constant, we raise a + // diagnostic. + MirAccessType::Index(index) => unroll_accessor_index_access_type( + indexable, + index, + offset, + expect_constant_indices, + diagnostics, + ), + // If we have an Matrix accessor, we compute both the corresponding row and column, and + // query the indexable accordingly. If either of row or col is not a constant and + // computed_indices is true, we raise a diagnostic. If either of row or col is not a + // constant and computed_indices is false, we keep the node as is. If either of row + // or col is an out-of-bound constant, we raise a diagnostic. + MirAccessType::Matrix(row, col) => unroll_accessor_matrix_access_type( + indexable, + row, + col, + expect_constant_indices, + diagnostics, + ), + } +} + +/// Helper function to unroll an Index accessor +fn unroll_accessor_index_access_type( + indexable: Link, + index: Link, + accessor_offset: usize, + expect_constant_indices: bool, + diagnostics: &DiagnosticsHandler, +) -> Result>, CompileError> { + let Some(index_usize) = extract_index_value(&index, expect_constant_indices, diagnostics)? 
+ else { + return Ok(None); + }; + if let Op::Vector(indexable_vector) = indexable.borrow().deref() { + let indexable_vec = indexable_vector.children().borrow().deref().clone(); + let child_accessed = match indexable_vec.get(index_usize) { + Some(child_accessed) => child_accessed, + None => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access an index which is out of bounds") + .with_primary_label(index.span(), "index out of bounds") + .emit(); + return Err(CompileError::Failed); + }, + }; + Ok(Some(add_row_offset_if_trace_access(child_accessed, accessor_offset))) + } else if let Some(value) = indexable.clone().as_value() { + // If the indexable is either a PublicInput or a TraceAccess, we treat the index as + // an offset + let mir_value = value.value.value.clone(); + match mir_value { + MirValue::PublicInput(public_input_access) => { + let new_node = Value::create(SpannedMirValue { + span: value.value.span(), + value: MirValue::PublicInput(PublicInputAccess { + name: public_input_access.name, + index: public_input_access.index + index_usize, + }), + }); + Ok(Some(new_node)) + }, + MirValue::TraceAccess(trace_access) => { + // We also need to account for the row offset + let new_node = Value::create(SpannedMirValue { + span: value.value.span(), + value: MirValue::TraceAccess(TraceAccess { + segment: trace_access.segment, + column: trace_access.column + index_usize, + row_offset: trace_access.row_offset, + }), + }); + Ok(Some(new_node)) + }, + _ => { + unreachable!( + "Unexpected accessor, cannot have MirAccessType::Index with indexable {:?}", + indexable + ); + }, + } + } else { + unreachable!( + "Unexpected accessor, cannot have MirAccessType::Index with indexable {:?}", + indexable + ); + } +} + +/// Helper function to unroll a Matrix accessor +fn unroll_accessor_matrix_access_type( + indexable: Link, + row: Link, + col: Link, + expect_constant_indices: bool, + diagnostics: &DiagnosticsHandler, +) -> Result>, 
CompileError> { + let Some(row_usize) = extract_index_value(&row, expect_constant_indices, diagnostics)? else { + return Ok(None); + }; + let Some(col_usize) = extract_index_value(&col, expect_constant_indices, diagnostics)? else { + return Ok(None); + }; + // Replace the current node by the index-th element of the vector + // Raise diag if index is out of bounds + if let Op::Vector(indexable_vector) = indexable.borrow().deref() { + let indexable_vec = indexable_vector.children().borrow().deref().clone(); + let row_accessed = match indexable_vec.get(row_usize) { + Some(row_accessed) => row_accessed, + None => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access a row which is out of bounds") + .with_primary_label(row.span(), "row out of bounds") + .emit(); + return Err(CompileError::Failed); + }, + }; + if let Op::Vector(row_accessed_vector) = row_accessed.borrow().deref() { + let row_accessed_vec = row_accessed_vector.children().borrow().deref().clone(); + let child_accessed = match row_accessed_vec.get(col_usize) { + Some(child_accessed) => child_accessed, + None => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access a col which is out of bounds") + .with_primary_label(col.span(), "col out of bounds") + .emit(); + return Err(CompileError::Failed); + }, + }; + Ok(Some(child_accessed.clone())) + } else { + unreachable!( + "Unexpected accessor, cannot have MirAccessType::Matrix with indexable {:?}", + indexable + ); + } + } else if let Op::Matrix(indexable_matrix) = indexable.borrow().deref() { + let indexable_vec = indexable_matrix.children().borrow().deref().clone(); + let row_accessed = match indexable_vec.get(row_usize) { + Some(row_accessed) => row_accessed, + None => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access a row which is out of bounds") + .with_primary_label(row.span(), "row out of bounds") + 
.emit(); + return Err(CompileError::Failed); + }, + }; + if let Op::Vector(row_accessed_vector) = row_accessed.borrow().deref() { + let row_accessed_vec = row_accessed_vector.children().borrow().deref().clone(); + let child_accessed = match row_accessed_vec.get(col_usize) { + Some(child_accessed) => child_accessed, + None => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access a col which is out of bounds") + .with_primary_label(col.span(), "col out of bounds") + .emit(); + return Err(CompileError::Failed); + }, + }; + Ok(Some(child_accessed.clone())) + } else { + unreachable!( + "Unexpected accessor, cannot have MirAccessType::Matrix with indexable {:?}", + indexable + ); + } + } else { + unreachable!( + "Unexpected accessor, cannot have MirAccessType::Matrix with indexable {:?}", + indexable + ); + } +} + +/// Helper function to extract a usize value from an index expression with consistent error handling +/// +/// Returns: +/// - `Ok(Some(usize))` - Successfully extracted constant value +/// - `Ok(None)` - Not a constant value but not required (expect_constant_indices=false) +/// - `Err(CompileError)` - Not a constant value when required (expect_constant_indices=true) +fn extract_index_value( + index: &Link, + expect_constant_indices: bool, + diagnostics: &DiagnosticsHandler, +) -> Result, CompileError> { + match (get_inner_const(index), expect_constant_indices) { + (Some(value), _) => Ok(Some(value as usize)), + (None, true) => { + diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("the index is not constant during constant propagation") + .with_primary_label(index.span(), "index is not constant") + .emit(); + Err(CompileError::Failed) + }, + (None, false) => Ok(None), + } +} + +/// Helper function to add a row offset to a `TraceAccess` value, and return the node unchanged +/// otherwise. 
+fn add_row_offset_if_trace_access(node: &Link, offset: usize) -> Link { + if let Some(value) = node.clone().as_value() { + let mir_value = value.value.value.clone(); + if let MirValue::TraceAccess(trace_access) = mir_value { + Value::create(SpannedMirValue { + span: value.value.span(), + value: MirValue::TraceAccess(TraceAccess { + segment: trace_access.segment, + column: trace_access.column, + row_offset: trace_access.row_offset + offset, + }), + }) + } else { + node.clone() } - Op::None(_) => {} + } else { + node.clone() } } diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 6a5f11f40..db342ecb7 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -1,19 +1,22 @@ use core::panic; use std::ops::Deref; -use air_parser::ast::AccessType; -use air_parser::{LexicalScope, ast, symbols}; +use air_parser::{ + LexicalScope, + ast::{self, AccessType, ScalarExpr, TraceSegmentId}, + symbols, +}; use air_pass::Pass; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; -use crate::ir::BusAccess; use crate::{ CompileError, ir::{ - Accessor, Add, Boundary, Builder, Bus, BusOp, BusOpKind, Call, ConstantValue, Enf, - Evaluator, Exp, Fold, FoldOperator, For, Function, Link, Matrix, Mir, MirType, MirValue, - Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, Root, - SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Value, Vector, + Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, ConstantValue, + Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, Matrix, Mir, + MirAccessType, MirType, MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, + PublicInputTableAccess, Root, SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Value, + Vector, }, passes::duplicate_node, }; @@ -25,7 +28,8 @@ use crate::{ /// * has had constant propagation already applied /// /// Notes: -/// * During this step, we unpack parameters and arguments of 
evaluators, in order to make it easier to inline them +/// * During this step, we unpack parameters and arguments of evaluators, in order to make it easier +/// to inline them /// /// TODO: /// - [ ] Implement diagnostics for better error handling @@ -58,6 +62,9 @@ pub struct MirBuilder<'a> { mir: Mir, trace_columns: &'a Vec, bindings: LexicalScope<&'a ast::Identifier, Link>, + // The root node is either the evaluator or function we're currently translating the body of, + // or None if we're not inside a function or evaluator (e.g. translating boundary / integrity + // constraints) root: Link, root_name: Option<&'a ast::QualifiedIdentifier>, in_boundary: bool, @@ -89,9 +96,7 @@ impl<'a> MirBuilder<'a> { self.mir.public_inputs = self.program.public_inputs.clone(); for (qual_ident, ast_bus) in buses.iter() { let bus = self.translate_bus_definition(ast_bus)?; - self.mir - .constraint_graph_mut() - .insert_bus(*qual_ident, bus)?; + self.mir.constraint_graph_mut().insert_bus(qual_ident.clone(), bus)?; } for (ident, function) in &self.program.functions { @@ -117,16 +122,16 @@ impl<'a> MirBuilder<'a> { } for bus in self.mir.constraint_graph().buses.values() { - let bus_name = bus.borrow().name(); - if let Some(ref mut mirvalue) = bus.borrow().get_first().as_value_mut() { - if let MirValue::PublicInputTable(ref mut first) = mirvalue.value.value { - first.set_bus_name(bus_name); - } + let bus_type = bus.borrow().bus_type; + if let Some(ref mut mirvalue) = bus.borrow().get_first().as_value_mut() + && let MirValue::PublicInputTable(ref mut first) = mirvalue.value.value + { + first.set_bus_type(bus_type); } - if let Some(ref mut mirvalue) = bus.borrow().get_last().as_value_mut() { - if let MirValue::PublicInputTable(ref mut last) = mirvalue.value.value { - last.set_bus_name(bus_name); - } + if let Some(ref mut mirvalue) = bus.borrow().get_last().as_value_mut() + && let MirValue::PublicInputTable(ref mut last) = mirvalue.value.value + { + last.set_bus_type(bus_type); } } Ok(()) 
@@ -167,9 +172,7 @@ impl<'a> MirBuilder<'a> { set_all_ref_nodes(all_params_flatten.clone(), ev.as_owner()); - self.mir - .constraint_graph_mut() - .insert_evaluator(*ident, ev.clone())?; + self.mir.constraint_graph_mut().insert_evaluator(ident.clone(), ev.clone())?; Ok(ev) } @@ -207,12 +210,12 @@ impl<'a> MirBuilder<'a> { } let vector_node = Vector::create(params_vec, span); self.bindings.insert(name.unwrap(), vector_node.clone()); - } + }, ast::Type::Felt => { let param = all_params_flatten_for_trace_segment[i].clone(); i += 1; self.bindings.insert(name.unwrap(), param.clone()); - } + }, _ => unreachable!(), }; } @@ -241,19 +244,13 @@ impl<'a> MirBuilder<'a> { func = func.parameters(param.clone()); } i += 1; - let ret = Parameter::create( - i, - self.translate_type(&ast_func.return_type), - ast_func.span(), - ); + let ret = Parameter::create(i, self.translate_type(&ast_func.return_type), ast_func.span()); params.push(ret.clone()); let func = func.return_type(ret).build(); set_all_ref_nodes(params.clone(), func.as_owner()); - self.mir - .constraint_graph_mut() - .insert_function(*ident, func.clone())?; + self.mir.constraint_graph_mut().insert_function(ident.clone(), func.clone())?; Ok(func) } @@ -291,7 +288,7 @@ impl<'a> MirBuilder<'a> { let param = Parameter::create(*i, MirType::Felt, span); *i += 1; Ok(vec![param]) - } + }, ast::Type::Vector(size) => { let mut params = Vec::new(); for _ in 0..*size { @@ -300,7 +297,7 @@ impl<'a> MirBuilder<'a> { params.push(param); } Ok(params) - } + }, ast::Type::Matrix(_rows, _cols) => { let span = if let Some(name) = name { name.span() @@ -313,7 +310,7 @@ impl<'a> MirBuilder<'a> { .with_primary_label(span, "expected this to be a felt or vector") .emit(); Err(CompileError::Failed) - } + }, } } @@ -329,12 +326,12 @@ impl<'a> MirBuilder<'a> { let param = Parameter::create(*i, MirType::Felt, span); *i += 1; Ok(param) - } + }, ast::Type::Vector(size) => { let param = Parameter::create(*i, MirType::Vector(*size), span); *i += 1; 
Ok(param) - } + }, ast::Type::Matrix(_rows, _cols) => { let span = if let Some(name) = name { name.span() @@ -347,19 +344,21 @@ impl<'a> MirBuilder<'a> { .with_primary_label(span, "expected this to be a felt or vector") .emit(); Err(CompileError::Failed) - } + }, } } + /// Translates each statement in the body of a function or evaluator `func` fn translate_body( &mut self, _ident: &ast::QualifiedIdentifier, func: Link, body: &'a Vec, - ) -> Result, CompileError> { + ) -> Result<(), CompileError> { + // First, the root field sets the context that we are currently translating the body of + // `func`. It is for instance needed to correctly translate let statements and + // attach the statements of their bodies to the body of `func`. self.root = func.clone(); - self.bindings.enter(); - let func = func; for stmt in body { let op = self.translate_statement(stmt)?; match func.clone().borrow().deref() { @@ -367,12 +366,10 @@ impl<'a> MirBuilder<'a> { Root::Evaluator(e) => e.body.borrow_mut().push(op.clone()), Root::None(_span) => { unreachable!("expected function or evaluator, got None") - } + }, }; - self.root = func.clone(); } - self.bindings.exit(); - Ok(func) + Ok(()) } fn translate_type(&mut self, ty: &ast::Type) -> MirType { @@ -383,24 +380,47 @@ impl<'a> MirBuilder<'a> { } } + /// Translates a statement and returns the operation. + /// Note: for `let` statements, the returned operation is the last operation of its body. 
fn translate_statement(&mut self, stmt: &'a ast::Statement) -> Result, CompileError> { match stmt { ast::Statement::Let(let_stmt) => self.translate_let(let_stmt), ast::Statement::Expr(expr) => self.translate_expr(expr), ast::Statement::Enforce(enf) => self.translate_enforce(enf), - ast::Statement::EnforceIf(enf, cond) => self.translate_enforce_if(enf, cond), + ast::Statement::EnforceIf(match_expr) => self.translate_enforce_if(match_expr), ast::Statement::EnforceAll(list_comp) => self.translate_enforce_all(list_comp), ast::Statement::BusEnforce(list_comp) => self.translate_bus_enforce(list_comp), } } + + /// Translates a let statement by: + /// - binding its value to its name for the scope of its body, + /// - adding the statements of its body to the root's body if necessary, + /// - returning the operation of the last statement of its body, which is the value of the whole + /// block. + /// + /// Note: as we already return the operation of the last statement, we do not need to add it to + /// the root's body here. This should be handled by the caller. fn translate_let(&mut self, let_stmt: &'a ast::Let) -> Result, CompileError> { let name = &let_stmt.name; let value: Link = self.translate_expr(&let_stmt.value)?; let mut ret_value = value.clone(); self.bindings.enter(); self.bindings.insert(name, value.clone()); - for stmt in let_stmt.body.iter() { - ret_value = self.translate_statement(stmt)?; + for (i, stmt) in let_stmt.body.iter().enumerate() { + let new_stmt = self.translate_statement(stmt)?; + // Skip the last statement as it is returned + if i < let_stmt.body.len() - 1 { + match self.root.borrow().deref() { + Root::Function(f) => f.body.borrow_mut().push(new_stmt.clone()), + Root::Evaluator(e) => e.body.borrow_mut().push(new_stmt.clone()), + // Root::None means we are translating statements of boundary / integrity + // constraints, for which statements are inserted on enforce, nothing to do + // here. 
+ Root::None(_span) => {}, + } + } + ret_value = new_stmt; } self.bindings.exit(); Ok(ret_value) @@ -416,10 +436,9 @@ impl<'a> MirBuilder<'a> { ast::Expr::Call(c) => self.translate_call(c), ast::Expr::ListComprehension(lc) => self.translate_list_comprehension(lc), ast::Expr::Let(l) => self.translate_let(l), - ast::Expr::Null(_) => Ok(Value::create(SpannedMirValue { - span: expr.span(), - value: MirValue::Null, - })), + ast::Expr::Null(_) => { + Ok(Value::create(SpannedMirValue { span: expr.span(), value: MirValue::Null })) + }, ast::Expr::Unconstrained(_) => Ok(Value::create(SpannedMirValue { span: expr.span(), value: MirValue::Unconstrained, @@ -435,10 +454,18 @@ impl<'a> MirBuilder<'a> { fn translate_enforce_if( &mut self, - _enf: &ast::ScalarExpr, - _cond: &ast::ScalarExpr, + match_expr: &'a ast::Match, ) -> Result, CompileError> { - unreachable!("all EnforceIf should have been transformed into EnforceAll") + let mut match_arms = Vec::new(); + + for match_arm in match_expr.match_arms.iter() { + let cond_node = self.translate_scalar_expr(&match_arm.condition)?; + let expr_node = self.translate_scalar_expr(&match_arm.expr)?; + match_arms.push(MatchArm::new(expr_node, cond_node)); + } + + let if_node = If::create(match_arms, match_expr.span()); + self.insert_enforce(if_node) } fn translate_enforce_all( @@ -474,12 +501,7 @@ impl<'a> MirBuilder<'a> { } else { Link::default() }; - for_node - .as_for_mut() - .unwrap() - .expr - .borrow_mut() - .clone_from(&body_node.borrow()); + for_node.as_for_mut().unwrap().expr.borrow_mut().clone_from(&body_node.borrow()); for_node .as_for_mut() .unwrap() @@ -554,7 +576,7 @@ impl<'a> MirBuilder<'a> { .emit(); return Err(CompileError::Failed); }; - } + }, _ => unimplemented!(), }; let sel = match list_comp.selector.as_ref() { @@ -572,22 +594,12 @@ impl<'a> MirBuilder<'a> { ) .emit(); return Err(CompileError::Failed); - } + }, }; // Note: safe to unwrap because we checked that bus_op is a BusOp above - bus_op - .as_bus_op_mut() - 
.unwrap() - .latch - .borrow_mut() - .clone_from(&sel.borrow()); - let bus_op_clone = bus_op.clone(); - let bus_op_ref = bus_op_clone.as_bus_op_mut().unwrap(); - let bus_link = bus_op_ref.bus.to_link().unwrap(); - let mut bus = bus_link.borrow_mut(); - bus.latches.push(sel.clone()); - bus.columns.push(bus_op.clone()); - Ok(bus_op) + bus_op.as_bus_op_mut().unwrap().latch.borrow_mut().clone_from(&sel.borrow()); + let enf_node = self.insert_enforce(bus_op.clone())?; + Ok(enf_node) } fn insert_enforce(&mut self, node: Link) -> Result, CompileError> { @@ -613,7 +625,7 @@ impl<'a> MirBuilder<'a> { .constraint_graph_mut() .insert_integrity_constraints_root(node_to_add.clone()); }; - } + }, }; Ok(node_to_add) } @@ -632,9 +644,9 @@ impl<'a> MirBuilder<'a> { } fn translate_vector_expr(&mut self, v: &'a [ast::Expr]) -> Result, CompileError> { - let span = v.iter().fold(SourceSpan::UNKNOWN, |acc, expr| { - acc.merge(expr.span()).unwrap_or(acc) - }); + let span = v + .iter() + .fold(SourceSpan::UNKNOWN, |acc, expr| acc.merge(expr.span()).unwrap_or(acc)); let mut node = Vector::builder().size(v.len()).span(span); for value in v.iter() { let value_node = self.translate_expr(value)?; @@ -647,9 +659,9 @@ impl<'a> MirBuilder<'a> { &mut self, v: &'a [ast::ScalarExpr], ) -> Result, CompileError> { - let span = v.iter().fold(SourceSpan::UNKNOWN, |acc, expr| { - acc.merge(expr.span()).unwrap_or(acc) - }); + let span = v + .iter() + .fold(SourceSpan::UNKNOWN, |acc, expr| acc.merge(expr.span()).unwrap_or(acc)); let mut node = Vector::builder().size(v.len()).span(span); for value in v.iter() { let value_node = self.translate_scalar_expr(value)?; @@ -662,9 +674,10 @@ impl<'a> MirBuilder<'a> { &mut self, m: &'a Span>>, ) -> Result, CompileError> { - let span = m.iter().flatten().fold(SourceSpan::UNKNOWN, |acc, expr| { - acc.merge(expr.span()).unwrap_or(acc) - }); + let span = m + .iter() + .flatten() + .fold(SourceSpan::UNKNOWN, |acc, expr| acc.merge(expr.span()).unwrap_or(acc)); let mut 
node = Matrix::builder().size(m.len()).span(span); for row in m.iter() { let row_node = self.translate_vector_scalar_expr(row)?; @@ -676,9 +689,9 @@ impl<'a> MirBuilder<'a> { fn translate_symbol_access( &mut self, - access: &ast::SymbolAccess, + access: &'a ast::SymbolAccess, ) -> Result, CompileError> { - match access.name { + match &access.name { // At this point during compilation, fully-qualified identifiers can only possibly refer // to a periodic column, as all functions have been inlined, and constants propagated. ast::ResolvableIdentifier::Resolved(qual_ident) => { @@ -687,7 +700,7 @@ impl<'a> MirBuilder<'a> { .value(SpannedMirValue { span: access.span(), value: MirValue::PeriodicColumn(crate::ir::PeriodicColumnAccess::new( - qual_ident, + qual_ident.clone(), pc.period(), )), }) @@ -717,21 +730,20 @@ impl<'a> MirBuilder<'a> { format!("in this access expression `{access:#?}`"), ) .emit(); - //unreachable!("expected reference to periodic column in `{:#?}`", access); Err(CompileError::Failed) } - } + }, // This must be one of public inputs or trace columns ast::ResolvableIdentifier::Global(ident) | ast::ResolvableIdentifier::Local(ident) => { self.translate_symbol_access_global_or_local(&ident, access) - } + }, // These should have been eliminated by previous compiler passes ast::ResolvableIdentifier::Unresolved(_ident) => { unreachable!( "expected fully-qualified or global reference, got `{:?}` instead", &access.name ); - } + }, } } @@ -765,7 +777,7 @@ impl<'a> MirBuilder<'a> { .emit(); CompileError::Failed })?; - } + }, ast::Boundary::Last => { bus.borrow_mut().set_last(rhs.clone()).map_err(|_| { self.diagnostics @@ -778,7 +790,7 @@ impl<'a> MirBuilder<'a> { .emit(); CompileError::Failed })?; - } + }, } return Ok(Op::None(bin_op.span()).into()); } @@ -789,23 +801,23 @@ impl<'a> MirBuilder<'a> { ast::BinaryOp::Add => { let node = Add::builder().lhs(lhs).rhs(rhs).span(bin_op.span()).build(); Ok(node) - } + }, ast::BinaryOp::Sub => { let node = 
Sub::builder().lhs(lhs).rhs(rhs).span(bin_op.span()).build(); Ok(node) - } + }, ast::BinaryOp::Mul => { let node = Mul::builder().lhs(lhs).rhs(rhs).span(bin_op.span()).build(); Ok(node) - } + }, ast::BinaryOp::Exp => { let node = Exp::builder().lhs(lhs).rhs(rhs).span(bin_op.span()).build(); Ok(node) - } + }, ast::BinaryOp::Eq => { let sub_node = Sub::builder().lhs(lhs).rhs(rhs).span(bin_op.span()).build(); Ok(Enf::builder().expr(sub_node).span(bin_op.span()).build()) - } + }, } } @@ -828,7 +840,7 @@ impl<'a> MirBuilder<'a> { .initial_value(accumulator_node) .build(); Ok(node) - } + }, symbols::Prod => { assert_eq!(call.args.len(), 1); let iterator_node = self.translate_expr(call.args.first().unwrap())?; @@ -841,7 +853,7 @@ impl<'a> MirBuilder<'a> { .initial_value(accumulator_node) .build(); Ok(node) - } + }, other => unimplemented!("unhandled builtin: {}", other), } } else { @@ -850,11 +862,7 @@ impl<'a> MirBuilder<'a> { // Get the known callee in the functions hashmap // Then, get the node index of the function definition let callee_node; - if let Some(callee) = self - .mir - .constraint_graph() - .get_function_root(&resolved_callee) - { + if let Some(callee) = self.mir.constraint_graph().get_function_root(&resolved_callee) { callee_node = callee.clone(); let mut errors = Vec::with_capacity(call.args.len()); arg_nodes = call @@ -902,10 +910,8 @@ impl<'a> MirBuilder<'a> { .emit(); return Err(CompileError::Failed); } - } else if let Some(callee) = self - .mir - .constraint_graph() - .get_evaluator_root(&resolved_callee) + } else if let Some(callee) = + self.mir.constraint_graph().get_evaluator_root(&resolved_callee) { // TRANSLATE TODO: // - For Evaluators, we need to: @@ -985,12 +991,7 @@ impl<'a> MirBuilder<'a> { }; let body_node = self.translate_scalar_expr(&list_comp.body)?; - for_node - .as_for_mut() - .unwrap() - .expr - .borrow_mut() - .clone_from(&body_node.borrow()); + for_node.as_for_mut().unwrap().expr.borrow_mut().clone_from(&body_node.borrow()); 
for_node .as_for_mut() .unwrap() @@ -999,6 +1000,7 @@ impl<'a> MirBuilder<'a> { .clone_from(&selector_node.borrow()); self.bindings.exit(); + Ok(for_node) } @@ -1040,7 +1042,7 @@ impl<'a> MirBuilder<'a> { fn translate_bounded_symbol_access( &mut self, - access: &ast::BoundedSymbolAccess, + access: &'a ast::BoundedSymbolAccess, ) -> Result, CompileError> { let access_node = self.translate_symbol_access(&access.column)?; let node = Boundary::builder() @@ -1051,6 +1053,76 @@ impl<'a> MirBuilder<'a> { Ok(node) } + // If an [ast::AccessType] is a slice, we need to + // translate it into a vector of MirAccessType::Index. + // If it is not a slice, return None. + // This is used to completely eliminate slice accesses in MIR + fn translate_potential_slice( + &mut self, + access_expr: &Link, + access: &'a ast::SymbolAccess, + ) -> Option> { + // If it's a slice access, we need to create a vector of MirAccessType::Index + if let AccessType::Slice(ast::RangeExpr { start, end, .. }) = &access.access_type { + let ( + ast::RangeBound::Const(Span { item: start, .. }), + ast::RangeBound::Const(Span { item: end, .. }), + ) = (start, end) + else { + unreachable!( + "Slice expressions must use constant integer bounds (as in arr[0..5]), found: {:#?}). 
Dynamic bounds such as variables or expressions are not supported.", + access.access_type + ); + }; + if start >= end { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("Slice is empty (start >= end)") + .with_primary_label( + access.span(), + format!("Slice start: {start}, Slice end: {end}"), + ) + .emit(); + return None; + } + let mut vector = Vector::builder().size(end - start).span(access.span()); + for i in *start..*end { + let mir_access_type = MirAccessType::Index(Link::::from(i)); + let inner_accessor = Accessor::builder() + .indexable(duplicate_node(access_expr.clone(), &mut Default::default())) + .access_type(mir_access_type) + .offset(access.offset) + .span(access.span()) + .build(); + vector = vector.elements(inner_accessor); + } + let vector = vector.build(); + return Some(vector); + }; + None + } + fn translate_access_type( + &mut self, + access_type: &'a ast::AccessType, + ) -> Result { + let mir_access_type = match access_type { + AccessType::Default => MirAccessType::Default, + AccessType::Index(index) => { + let index_node = self.translate_scalar_expr(index)?; + MirAccessType::Index(index_node) + }, + AccessType::Matrix(row, col) => { + let row_node = self.translate_scalar_expr(row)?; + let col_node = self.translate_scalar_expr(col)?; + MirAccessType::Matrix(row_node, col_node) + }, + AccessType::Slice(_range_expr) => unreachable!( + "Slices should have been transformed into vector operations during constant propagation" + ), + }; + Ok(mir_access_type) + } + fn translate_bus_operation( &mut self, ast_bus_op: &'a ast::BusOperation, @@ -1085,18 +1157,15 @@ impl<'a> MirBuilder<'a> { ast::BusOperator::Remove => BusOpKind::Remove, }; - let mut bus_op = BusOp::builder() - .span(ast_bus_op.span()) - .bus(bus) - .kind(bus_op_kind); + let mut bus_op = BusOp::builder().span(ast_bus_op.span()).bus(bus).kind(bus_op_kind); for arg in ast_bus_op.args.iter() { let mut arg_node = self.translate_expr(arg)?; let accessor_mut = arg_node.clone(); if 
let Some(accessor) = accessor_mut.as_accessor_mut() { match accessor.access_type { - AccessType::Default => { + MirAccessType::Default => { arg_node = accessor.indexable.clone(); - } + }, _ => { self.diagnostics .diagnostic(Severity::Error) @@ -1107,7 +1176,7 @@ impl<'a> MirBuilder<'a> { ) .emit(); return Err(CompileError::Failed); - } + }, } } bus_op = bus_op.args(arg_node); @@ -1159,27 +1228,18 @@ impl<'a> MirBuilder<'a> { fn translate_symbol_access_global_or_local( &mut self, ident: &ast::Identifier, - access: &ast::SymbolAccess, + access: &'a ast::SymbolAccess, ) -> Result, CompileError> { - // Special identifiers are those which are `$`-prefixed, and must refer to the names of trace segments (e.g. `$main`) + // Special identifiers are those which are `$`-prefixed, and must refer to the names of + // trace segments (e.g. `$main`) if ident.is_special() { // Must be a trace segment name - if let Some(trace_access) = self.trace_access(access) { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::TraceAccess(trace_access), - }) - .build()); + if let Some(trace_access) = self.trace_access(access)? 
{ + return Ok(trace_access); } if let Some(tab) = self.trace_access_binding(access) { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::TraceAccessBinding(tab), - }) - .build()); + return Ok(tab); } // It should never be possible to reach this point - semantic analysis @@ -1190,63 +1250,42 @@ impl<'a> MirBuilder<'a> { ); } - // // If we reach here, this must be a let-bound variable + // If we reach here, this must be a let-bound variable if let Some(let_bound_access_expr) = self.bindings.get(access.name.as_ref()).cloned() { // If the let-bound variable is a parameter, we probably already have the type // // In that case, replacing the default type (Felt) with the one from the access - if let Some(mut param) = let_bound_access_expr.as_parameter_mut() { - if let Some(access_ty) = &access.ty { - param.ty = self.translate_type(access_ty); - } + if let Some(mut param) = let_bound_access_expr.as_parameter_mut() + && let Some(access_ty) = &access.ty + { + param.ty = self.translate_type(access_ty); + } + // If it's a slice access, we need to return its translation. + // This eliminates the case of [ast::AccessType::Slice] in MIR + if let Some(slice) = self.translate_potential_slice(&let_bound_access_expr, access) { + return Ok(slice); } + let mir_access_type = self.translate_access_type(&access.access_type)?; let accessor: Link = Accessor::create( duplicate_node(let_bound_access_expr, &mut Default::default()), - access.access_type.clone(), + mir_access_type, access.offset, access.span(), ); - return Ok(accessor); } - if let Some(trace_access) = self.trace_access(access) { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::TraceAccess(trace_access), - }) - .build()); + if let Some(trace_access) = self.trace_access(access)? 
{ + return Ok(trace_access); } // Otherwise, we check bindings, trace bindings, and public inputs, in that order if let Some(tab) = self.trace_access_binding(access) { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::TraceAccessBinding(tab), - }) - .build()); + return Ok(tab); } - match self.public_input_access(access) { - (Some(public_input_access), None) => { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::PublicInput(public_input_access), - }) - .build()); - } - (None, Some(public_input_table_access)) => { - return Ok(Value::builder() - .value(SpannedMirValue { - span: access.span(), - value: MirValue::PublicInputTable(public_input_table_access), - }) - .build()); - } - _ => {} + if let Some(public_input_access) = self.public_input_access(access)? { + return Ok(public_input_access); } self.diagnostics @@ -1260,56 +1299,105 @@ impl<'a> MirBuilder<'a> { Err(CompileError::Failed) } - // Check assumptions, probably this assumed that the inlining pass did some work fn public_input_access( - &self, - access: &ast::SymbolAccess, - ) -> (Option, Option) { + &mut self, + access: &'a ast::SymbolAccess, + ) -> Result>, CompileError> { let Some(public_input) = self.mir.public_inputs.get(access.name.as_ref()) else { - return (None, None); + return Ok(None); }; - match access.access_type { - AccessType::Default => ( - None, - Some(PublicInputTableAccess::new( - public_input.name(), - public_input.size(), - )), - ), - AccessType::Index(index) => ( - Some(PublicInputAccess::new(public_input.name(), index)), - None, - ), + match access.access_type.clone() { + AccessType::Default => { + let public_input_table = + PublicInputTableAccess::new(public_input.name(), public_input.size()); + Ok(Some( + Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::PublicInputTable(public_input_table), + }) + .build(), + )) + }, + AccessType::Index(index) => { + // If 
the index is a constant, we construct the corresponding PublicInputAccess + if let ScalarExpr::Const(c) = *index { + let public_input_access = + PublicInputAccess::new(public_input.name(), c.item as usize); + let value = Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::PublicInput(public_input_access), + }) + .build(); + Ok(Some(value)) + } else { + // Otherwise, we need to wrap a PublicInputAccess with an accessor. In this + // case, the accessor is not used to index into a vector, + // but rather to offset the targeted column + let public_input_access = PublicInputAccess::new(public_input.name(), 0); + let value = Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::PublicInput(public_input_access), + }) + .build(); + // If it's a slice access, we need to return its translation. + // This eliminates the case of [ast::AccessType::Slice] in MIR + if let Some(slice) = self.translate_potential_slice(&value, access) { + return Ok(Some(slice)); + } + let mir_access_type = self.translate_access_type(&access.access_type)?; + let accessor = + Accessor::create(value, mir_access_type, access.offset, access.span()); + Ok(Some(accessor)) + } + }, _ => { // This should have been caught earlier during compilation unreachable!( "unexpected public input access type encountered during lowering: {:#?}", access ) - } + }, } } - // Check assumptions, probably this assumed that the inlining pass did some work - fn trace_access_binding(&self, access: &ast::SymbolAccess) -> Option { + fn trace_access_binding(&self, access: &ast::SymbolAccess) -> Option> { let id = access.name.as_ref(); for segment in self.trace_columns.iter() { - if let Some(binding) = segment - .bindings - .iter() - .find(|tb| tb.name.as_ref() == Some(id)) - { + if let Some(binding) = segment.bindings.iter().find(|tb| tb.name.as_ref() == Some(id)) { return match &access.access_type { - AccessType::Default => Some(TraceAccessBinding { - segment: 
binding.segment, - offset: binding.offset, - size: binding.size, - }), - AccessType::Slice(range_expr) => Some(TraceAccessBinding { - segment: binding.segment, - offset: binding.offset + range_expr.to_slice_range().start, - size: range_expr.to_slice_range().count(), - }), + AccessType::Default => { + let tab = TraceAccessBinding { + segment: binding.segment, + offset: binding.offset, + size: binding.size, + }; + Some( + Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::TraceAccessBinding(tab), + }) + .build(), + ) + }, + AccessType::Slice(range_expr) => { + let tab = TraceAccessBinding { + segment: binding.segment, + offset: binding.offset + range_expr.to_slice_range().start, + size: range_expr.to_slice_range().count(), + }; + Some( + Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::TraceAccessBinding(tab), + }) + .build(), + ) + }, _ => None, }; } @@ -1317,56 +1405,121 @@ impl<'a> MirBuilder<'a> { None } - // Check assumptions, probably this assumed that the inlining pass did some work - fn trace_access(&self, access: &ast::SymbolAccess) -> Option { + fn trace_access( + &mut self, + access: &'a ast::SymbolAccess, + ) -> Result>, CompileError> { + assert_eq!( + self.trace_columns.len(), + 1, + "In MIR, expected exactly one trace segment to be present" + ); let id = access.name.as_ref(); - for (i, segment) in self.trace_columns.iter().enumerate() { - if segment.name == id { - if let AccessType::Index(column) = access.access_type { - return Some(TraceAccess::new(i, column, access.offset)); - } else { - // This should have been caught earlier during compilation - unreachable!( - "unexpected trace access type encountered during lowering: {:#?}", - &access - ); - } + let segment = self.trace_columns.first().unwrap(); + + if segment.name == id { + // We access $main[i] + if let AccessType::Index(column) = access.access_type.clone() { + let node = self.translate_indexed_trace_access( + column, + 
TraceSegmentId::Main, + 0, + access.offset, + access, + )?; + Ok(Some(node)) + } else { + // This should have been caught earlier during compilation + unreachable!( + "unexpected trace access type encountered during lowering: {:#?}", + &access + ); } - - if let Some(binding) = segment - .bindings - .iter() - .find(|tb| tb.name.as_ref() == Some(id)) - { - return match access.access_type { - AccessType::Default if binding.size == 1 => Some(TraceAccess::new( + } else if let Some(binding) = + segment.bindings.iter().find(|tb| tb.name.as_ref() == Some(id)) + { + // We access a trace binding defined in the main trace. + match access.access_type.clone() { + AccessType::Default if binding.size == 1 => { + let ta = TraceAccess::new(binding.segment, binding.offset, access.offset); + let value = Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::TraceAccess(ta), + }) + .build(); + // If it's a slice access, we need to return its translation. + // This eliminates the case of [ast::AccessType::Slice] in MIR + if let Some(slice) = self.translate_potential_slice(&value, access) { + return Ok(Some(slice)); + } + let mir_binding_access = self.translate_access_type(&binding.access)?; + let accessor = Accessor::create(value, mir_binding_access, 0, access.span()); + Ok(Some(accessor)) + }, + AccessType::Index(extra_offset) if binding.size > 1 => { + let node = self.translate_indexed_trace_access( + extra_offset, binding.segment, binding.offset, access.offset, - )), - AccessType::Index(extra_offset) if binding.size > 1 => Some(TraceAccess::new( - binding.segment, - binding.offset + extra_offset, - access.offset, - )), - // This should have been caught earlier during compilation - /*_ => unreachable!( - "unexpected trace access type encountered during lowering: {:#?}", - access - ),*/ - _ => None, - }; + access, + )?; + Ok(Some(node)) + }, + _ => Ok(None), } + } else { + // We do not access a trace + Ok(None) + } + } + + /// Helper function to 
translate a trace_access based on an index. If the index is a constant, + /// we build the corresponding TraceAccess. Otherwise, we build an Accessor around the + /// TraceAccess, the index should then be treated as an offset and not a way to index into a + /// collection. + fn translate_indexed_trace_access( + &mut self, + index: Box, + segment: TraceSegmentId, + offset: usize, + row_offset: usize, + access: &'a ast::SymbolAccess, + ) -> Result, CompileError> { + // If the index is a constant, we construct the corresponding TraceAccess + if let ScalarExpr::Const(c) = *index { + let ta = TraceAccess::new(segment, offset + c.item as usize, row_offset); + Ok(Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::TraceAccess(ta), + }) + .build()) + } else { + // Otherwise, we need to wrap a TraceAccess with an accessor. In this case, the accessor + // is not used to index into a vector, but rather to offset the targeted column + let ta = TraceAccess::new(segment, offset, row_offset); + let value = Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::TraceAccess(ta), + }) + .build(); + // If it's a slice access, we need to return its translation. 
+ // This eliminates the case of [ast::AccessType::Slice] in MIR + if let Some(slice) = self.translate_potential_slice(&value, access) { + return Ok(slice); + } + let mir_access_type = self.translate_access_type(&access.access_type)?; + Ok(Accessor::create(value, mir_access_type, 0, access.span())) } - None } } fn set_all_ref_nodes(params: Vec>, ref_node: Link) { for param in params { - let Some(mut param) = param.as_parameter_mut() else { - unreachable!("expected parameter, got {:?}", param); - }; + let mut param = param.as_parameter_mut().expect("Tried to set ref node on non-parameter"); param.set_ref_node(ref_node.clone()); } } diff --git a/mir/src/passes/unrolling.rs b/mir/src/passes/unrolling.rs deleted file mode 100644 index f9e1662ef..000000000 --- a/mir/src/passes/unrolling.rs +++ /dev/null @@ -1,1083 +0,0 @@ -use std::{collections::HashMap, ops::Deref, rc::Rc}; - -use air_parser::ast::AccessType; -use air_pass::Pass; -use miden_diagnostics::{DiagnosticsHandler, Spanned}; - -use crate::{CompileError, ir::*}; - -use super::{duplicate_node_or_replace, visitor::Visitor}; - -/// This pass follows a similar approach as the Inlining pass. -/// It requires that this Inlining pass has already been done. -/// -/// * In the first step, we visit the graph, unrolling each node type except For nodes. -/// Instead, for these node types we gather the context to inline them in the second pass. -/// * In the second pass, we inline the bodies of For nodes. 
-/// -/// TODO: -/// - [ ] Implement diagnostics for better error handling -pub struct Unrolling<'a> { - diagnostics: &'a DiagnosticsHandler, -} - -impl<'a> Unrolling<'a> { - pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { - Self { diagnostics } - } -} - -/// This structure is used to keep track of what is needed to inline a For node -#[derive(Clone, Debug)] -pub struct ForInliningContext { - body: Link, - iterators: Vec>, - selector: Option>, - ref_node: Link, -} -impl ForInliningContext {} - -pub struct UnrollingFirstPass<'a> { - #[allow(unused)] - diagnostics: &'a DiagnosticsHandler, - - // general context - work_stack: Vec>, - - // For each child of a For node encountered, we store the context to inline it in the second pass - bodies_to_inline: Vec<(Link, ForInliningContext)>, - - // We keep track of all parameters referencing a given For node - params_for_ref_node: HashMap>>, - // We keep a reference to For nodes in order to avoid the backlinks stored in Parameters - // referencing them to be dropped - all_for_nodes: HashMap, Link)>, -} - -impl<'a> UnrollingFirstPass<'a> { - pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { - Self { - diagnostics, - work_stack: vec![], - bodies_to_inline: vec![], - params_for_ref_node: HashMap::new(), - all_for_nodes: HashMap::new(), - } - } -} - -pub struct UnrollingSecondPass<'a> { - #[allow(unused)] - diagnostics: &'a DiagnosticsHandler, - - // general context - work_stack: Vec>, - // A list of all the children of For nodes to inline - bodies_to_inline: Vec<(Link, ForInliningContext)>, - // The current context for inlining a For node, if any - for_inlining_context: Option, - // A map of nodes to replace with their inlined version - nodes_to_replace: HashMap, Link)>, - // We keep track of all parameters referencing a given For node - params_for_ref_node: HashMap>>, - // We keep a reference to For nodes in order to avoid the backlinks stored in Parameters - // referencing them to be dropped - all_for_nodes: 
HashMap, Link)>, -} -impl<'a> UnrollingSecondPass<'a> { - pub fn new( - diagnostics: &'a DiagnosticsHandler, - bodies_to_inline: Vec<(Link, ForInliningContext)>, - all_for_nodes: HashMap, Link)>, - ) -> Self { - Self { - diagnostics, - work_stack: vec![], - bodies_to_inline, - for_inlining_context: None, - nodes_to_replace: HashMap::new(), - params_for_ref_node: HashMap::new(), - all_for_nodes, - } - } -} - -impl Pass for Unrolling<'_> { - type Input<'a> = Mir; - type Output<'a> = Mir; - type Error = CompileError; - - fn run<'a>(&mut self, mut ir: Self::Input<'a>) -> Result, Self::Error> { - // The first pass unrolls all nodes fully, except for For nodes - let mut first_pass = UnrollingFirstPass::new(self.diagnostics); - Visitor::run(&mut first_pass, ir.constraint_graph_mut())?; - - // The second pass actually inlines the For nodes - let mut second_pass = UnrollingSecondPass::new( - self.diagnostics, - first_pass.bodies_to_inline.clone(), - first_pass.all_for_nodes.clone(), - ); - Visitor::run(&mut second_pass, ir.constraint_graph_mut())?; - Ok(ir) - } -} - -// For the first pass of Unrolling, we use a tweeked version of the Visitor trait, -// each visit_*_bis function returns an Option> instead of Result<(), CompileError>, -// to mutate the nodes (e.g. 
modifying a Operation to Vector) -impl UnrollingFirstPass<'_> { - fn visit_value_bis( - &mut self, - _graph: &mut Graph, - value: Link, - ) -> Result>, CompileError> { - // safe to unwrap because we just dispatched on it - let mut updated_value = None; - - { - let value_ref = value.as_value().unwrap(); - let mir_value = value_ref.value.value.clone(); - match mir_value { - MirValue::Constant(c) => match c { - ConstantValue::Felt(_) => {} - ConstantValue::Vector(v) => { - let mut vec = vec![]; - for val in v { - let val = Value::create(SpannedMirValue { - span: value_ref.value.span, - value: MirValue::Constant(ConstantValue::Felt(val)), - }); - vec.push(val); - } - updated_value = Some(Vector::create(vec, value_ref.span())); - } - ConstantValue::Matrix(m) => { - let mut res_m = vec![]; - for row in m { - let mut res_row = vec![]; - for val in row { - let val = Value::create(SpannedMirValue { - span: value_ref.value.span, - value: MirValue::Constant(ConstantValue::Felt(val)), - }); - res_row.push(val); - } - let res_row_vec = Vector::create(res_row, value_ref.span()); - res_m.push(res_row_vec); - } - updated_value = Some(Matrix::create(res_m, value_ref.span())); - } - }, - MirValue::TraceAccess(_) => {} - MirValue::PeriodicColumn(_) => {} - MirValue::PublicInput(_) => {} - MirValue::PublicInputTable(_) => {} - MirValue::RandomValue(_) => {} - MirValue::TraceAccessBinding(trace_access_binding) => { - // Create Trace Access based on this binding - if trace_access_binding.size == 1 { - let val = Value::create(SpannedMirValue { - span: value_ref.value.span, - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access_binding.segment, - column: trace_access_binding.offset, - row_offset: 0, // ??? 
- }), - }); - updated_value = Some(val); - } else { - let mut vec = vec![]; - for index in 0..trace_access_binding.size { - let val = Value::create(SpannedMirValue { - span: value_ref.span(), - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access_binding.segment, - column: trace_access_binding.offset + index, - row_offset: 0, // ??? - }), - }); - vec.push(val); - } - updated_value = Some(Vector::create(vec, value_ref.span())); - } - } - MirValue::BusAccess(_) => {} - MirValue::Null => {} - MirValue::Unconstrained => {} - } - } - - Ok(updated_value) - } - - fn visit_add_bis( - &mut self, - _graph: &mut Graph, - add: Link, - ) -> Result>, CompileError> { - // safe to un wrap because we just dispatched on it - - let mut updated_add = None; - - { - let add_ref = add.as_add().unwrap(); - let lhs = add_ref.lhs.clone(); - let rhs = add_ref.rhs.clone(); - - if let (Op::Vector(lhs_vector), Op::Vector(rhs_vector)) = - (lhs.borrow().deref(), rhs.borrow().deref()) - { - let lhs_vec = lhs_vector.children().borrow().deref().clone(); - let rhs_vec = rhs_vector.children().borrow().deref().clone(); - - if lhs_vec.len() != rhs_vec.len() { - // Raise diag - todo!(); - } else { - let mut new_vec = vec![]; - for (lhs, rhs) in lhs_vec.iter().zip(rhs_vec.iter()) { - let new_node = Add::create(lhs.clone(), rhs.clone(), add_ref.span()); - new_vec.push(new_node); - } - updated_add = Some(Vector::create(new_vec, add_ref.span())); - } - }; - } - - Ok(updated_add) - } - - fn visit_sub_bis( - &mut self, - _graph: &mut Graph, - sub: Link, - ) -> Result>, CompileError> { - // safe to unwrap because we just dispatched on it - - let mut updated_sub = None; - - let sub_ref = sub.as_sub().unwrap(); - let lhs = sub_ref.lhs.clone(); - let rhs = sub_ref.rhs.clone(); - - if let (Op::Vector(lhs_vector), Op::Vector(rhs_vector)) = - (lhs.borrow().deref(), rhs.borrow().deref()) - { - let lhs_vec = lhs_vector.children().borrow().deref().clone(); - let rhs_vec = 
rhs_vector.children().borrow().deref().clone(); - - if lhs_vec.len() != rhs_vec.len() { - // Raise diag - } else { - let mut new_vec = vec![]; - for (lhs, rhs) in lhs_vec.iter().zip(rhs_vec.iter()) { - let new_node = Sub::create(lhs.clone(), rhs.clone(), sub_ref.span()); - new_vec.push(new_node); - } - updated_sub = Some(Vector::create(new_vec, sub_ref.span())); - } - }; - - Ok(updated_sub) - } - - fn visit_mul_bis( - &mut self, - _graph: &mut Graph, - mul: Link, - ) -> Result>, CompileError> { - let mut updated_mul = None; - - { - let mul_ref = mul.as_mul().unwrap(); - let lhs = mul_ref.lhs.clone(); - let rhs = mul_ref.rhs.clone(); - - if let (Op::Vector(lhs_vector), Op::Vector(rhs_vector)) = - (lhs.borrow().deref(), rhs.borrow().deref()) - { - let lhs_vec = lhs_vector.children().borrow().deref().clone(); - let rhs_vec = rhs_vector.children().borrow().deref().clone(); - - if lhs_vec.len() != rhs_vec.len() { - // Raise diag - } else { - let mut new_vec = vec![]; - for (lhs, rhs) in lhs_vec.iter().zip(rhs_vec.iter()) { - let new_node = Mul::create(lhs.clone(), rhs.clone(), mul_ref.span()); - new_vec.push(new_node); - } - updated_mul = Some(Vector::create(new_vec, mul_ref.span())); - } - }; - } - - Ok(updated_mul) - } - - fn visit_exp_bis( - &mut self, - _graph: &mut Graph, - exp: Link, - ) -> Result>, CompileError> { - let mut updated_exp = None; - - { - let exp_ref = exp.as_exp().unwrap(); - let lhs = exp_ref.lhs.clone(); - let rhs = exp_ref.rhs.clone(); - - if let (Op::Vector(lhs_vector), Op::Vector(rhs_vector)) = - (lhs.borrow().deref(), rhs.borrow().deref()) - { - let lhs_vec = lhs_vector.children().borrow().deref().clone(); - let rhs_vec = rhs_vector.children().borrow().deref().clone(); - - if lhs_vec.len() != rhs_vec.len() { - // Raise diag - } else { - let mut new_vec = vec![]; - for (lhs, rhs) in lhs_vec.iter().zip(rhs_vec.iter()) { - let new_node = Exp::create(lhs.clone(), rhs.clone(), exp_ref.span()); - new_vec.push(new_node); - } - updated_exp = 
Some(Vector::create(new_vec, exp_ref.span())); - } - }; - } - - Ok(updated_exp) - } - - fn visit_enf_bis( - &mut self, - _graph: &mut Graph, - enf: Link, - ) -> Result>, CompileError> { - let mut updated_enf = None; - - { - let enf_ref = enf.as_enf().unwrap(); - let expr = enf_ref.expr.clone(); - if let Op::Vector(vec) = expr.borrow().deref() { - let ops = vec.children().borrow().deref().clone(); - let mut new_vec = vec![]; - for op in ops.iter() { - let new_node = Enf::create(op.clone(), enf_ref.span()); - new_vec.push(new_node); - } - updated_enf = Some(Vector::create(new_vec, enf_ref.span())); - }; - } - - Ok(updated_enf) - } - - fn visit_fold_bis( - &mut self, - _graph: &mut Graph, - fold: Link, - ) -> Result>, CompileError> { - let updated_fold; - - { - let fold_ref = fold.as_fold().unwrap(); - let iterator = fold_ref.iterator.clone(); - let operator = fold_ref.operator.clone(); - let initial_value = fold_ref.initial_value.clone(); - - let iterator_ref = iterator.borrow(); - let Op::Vector(iterator_vector) = iterator_ref.deref() else { - unreachable!(); - }; - let iterator_nodes = iterator_vector.children().borrow().deref().clone(); - - let mut acc_node = initial_value; - match operator { - FoldOperator::Add => { - for iterator_node in iterator_nodes { - let new_acc_node = Add::create(acc_node, iterator_node, fold_ref.span()); - acc_node = new_acc_node; - } - } - FoldOperator::Mul => { - for iterator_node in iterator_nodes { - let new_acc_node = Mul::create(acc_node, iterator_node, fold_ref.span()); - acc_node = new_acc_node; - } - } - FoldOperator::None => {} - } - updated_fold = Some(acc_node); - } - - Ok(updated_fold) - } - - fn visit_parameter_bis( - &mut self, - _graph: &mut Graph, - parameter: Link, - ) -> Result>, CompileError> { - // FIXME: Just check that the parameter is a scalar, raise diag otherwise - // List comprehension bodies should only be scalar expressions - - let owner_ref = parameter - .as_parameter() - .unwrap() - .ref_node - .to_link() - 
.unwrap_or_else(|| panic!("Ref node invalid")); - - self.params_for_ref_node - .entry(owner_ref.get_ptr()) - .or_default() - .push(parameter.clone()); - Ok(None) - } - - /*fn visit_if_old(&mut self, _graph: &mut Graph, if_node: Link) -> Result>, CompileError>{ - let if_ref = if_node.as_if().unwrap(); - let condition = if_ref.condition.clone(); - let then_branch = if_ref.then_branch.clone(); - let else_branch = if_ref.else_branch.clone(); - - if let ( - Op::Vector(condition_vector), - Op::Vector(then_branch_vector), - Op::Vector(else_branch_vector), - ) = ( - condition.borrow().deref(), - then_branch.borrow().deref(), - else_branch.borrow().deref(), - ) { - let condition_vec = condition_vector.children().borrow().deref().clone(); - let then_branch_vec = then_branch_vector.children().borrow().deref().clone(); - let else_branch_vec = else_branch_vector.children().borrow().deref().clone(); - - if condition_vec.len() != then_branch_vec.len() - || condition_vec.len() != else_branch_vec.len() - { - // Raise diag - } else { - let mut new_vec = vec![]; - for ((condition, then_branch), else_branch) in condition_vec - .iter() - .zip(then_branch_vec.iter()) - .zip(else_branch_vec.iter()) - { - let new_node = - If::create(condition.clone(), then_branch.clone(), else_branch.clone()); - new_vec.push(new_node); - } - if_node.set(&Vector::create(new_vec)); - } - }; - - Ok(()) - }*/ - - fn visit_if_bis( - &mut self, - _graph: &mut Graph, - if_node: Link, - ) -> Result>, CompileError> { - let updated_if; - - { - let if_ref = if_node.as_if().unwrap(); - let condition = if_ref.condition.clone(); - let then_branch = if_ref.then_branch.clone(); - let else_branch = if_ref.else_branch.clone(); - - let mut new_vec = vec![]; - - if let Op::Vector(then_branch_vector) = then_branch.clone().borrow().deref() { - let then_branch_vec = then_branch_vector.children().borrow().deref().clone(); - - for then_branch in then_branch_vec { - let new_node = Mul::create(condition.clone(), then_branch, 
if_ref.span()); - new_vec.push(new_node); - } - } else { - let new_node = Mul::create(condition.clone(), then_branch, if_ref.span()); - new_vec.push(new_node); - } - - let one_constant = SpannedMirValue { - span: Default::default(), - value: MirValue::Constant(ConstantValue::Felt(1)), - }; - - if let Op::Vector(else_branch_vector) = else_branch.clone().borrow().deref() { - let else_branch_vec = else_branch_vector.children().borrow().deref().clone(); - - for else_branch in else_branch_vec { - let span = else_branch.span(); - let new_node = Mul::create( - Sub::create(Value::create(one_constant.clone()), condition.clone(), span), - else_branch, - span, - ); - new_vec.push(new_node); - } - } else { - let span = else_branch.span(); - let new_node = Mul::create( - Sub::create(Value::create(one_constant.clone()), condition.clone(), span), - else_branch, - span, - ); - new_vec.push(new_node); - } - - updated_if = Some(Vector::create(new_vec, if_ref.span())); - } - - Ok(updated_if) - } - - fn visit_boundary_bis( - &mut self, - _graph: &mut Graph, - boundary: Link, - ) -> Result>, CompileError> { - let mut updated_boundary = None; - - { - // safe to unwrap because we just dispatched on it - let boundary_ref = boundary.as_boundary().unwrap(); - let expr = boundary_ref.expr.clone(); - let kind = boundary_ref.kind; - - if let Op::Vector(vec) = expr.borrow().deref() { - let expr_vec = vec.children().borrow().deref().clone(); - let mut new_vec = vec![]; - for expr in expr_vec.iter() { - let new_node = Boundary::create(expr.clone(), kind, boundary_ref.span()); - new_vec.push(new_node); - } - updated_boundary = Some(Vector::create(new_vec, boundary_ref.span())); - }; - } - - Ok(updated_boundary) - } - - fn visit_accessor_bis( - &mut self, - _graph: &mut Graph, - accessor: Link, - ) -> Result>, CompileError> { - let mut updated_accessor = None; - - { - let accessor_ref = accessor.as_accessor().unwrap(); - let indexable = accessor_ref.indexable.clone(); - let access_type = 
accessor_ref.access_type.clone(); - let offset = accessor_ref.offset; - - if indexable.clone().as_parameter().is_none() { - match access_type { - AccessType::Default => { - /*// Check that the child node is a scalar, raise diag otherwise - if indexable.clone().as_vector().is_some() { - unreachable!(); // raise diag - } - if indexable.clone().as_matrix().is_some() { - unreachable!(); // raise diag - }*/ - updated_accessor = Some(indexable.clone()); - - if let Some(value) = indexable.clone().as_value() { - let mir_value = value.value.value.clone(); - - match mir_value { - MirValue::TraceAccess(trace_access) => { - let new_node = Value::create(SpannedMirValue { - span: Default::default(), - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access.segment, - column: trace_access.column, - row_offset: trace_access.row_offset + offset, - }), - }); - updated_accessor = Some(new_node); - } - _ => unreachable!(), - } - } - } - AccessType::Index(index) => { - // Check that the child node is a vector, raise diag otherwise - // Replace the current node by the index-th element of the vector - // Raise diag if index is out of bounds - - if let Op::Vector(indexable_vector) = indexable.borrow().deref() { - let indexable_vec = - indexable_vector.children().borrow().deref().clone(); - let child_accessed = match indexable_vec.get(index) { - Some(child_accessed) => child_accessed, - None => unreachable!(), // raise diag - }; - if let Some(value) = child_accessed.clone().as_value() { - let mir_value = value.value.value.clone(); - match mir_value { - MirValue::TraceAccess(trace_access) => { - let new_node = Value::create(SpannedMirValue { - span: Default::default(), - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access.segment, - column: trace_access.column, - row_offset: trace_access.row_offset + offset, - }), - }); - updated_accessor = Some(new_node); - } - _ => { - updated_accessor = Some(child_accessed.clone()); - } - } - } else { - updated_accessor = 
Some(child_accessed.clone()); - } - } else { - unreachable!("indexable is {:?}", indexable); // raise diag - }; - } - AccessType::Matrix(row, col) => { - // Check that the child node is a matrix, raise diag otherwise - // Replace the current node by the index-th element of the vector - // Raise diag if index is out of bounds - - if let Op::Vector(indexable_vector) = indexable.borrow().deref() { - let indexable_vec = - indexable_vector.children().borrow().deref().clone(); - let row_accessed = match indexable_vec.get(row) { - Some(row_accessed) => row_accessed, - None => unreachable!(), // raise diag - }; - - if let Op::Vector(row_accessed_vector) = row_accessed.borrow().deref() { - let row_accessed_vec = - row_accessed_vector.children().borrow().deref().clone(); - let child_accessed = match row_accessed_vec.get(col) { - Some(child_accessed) => child_accessed, - None => unreachable!(), // raise diag - }; - updated_accessor = Some(child_accessed.clone()); - } else { - unreachable!(); // raise diag - }; - } else if let Op::Matrix(indexable_matrix) = indexable.borrow().deref() { - let indexable_vec = - indexable_matrix.children().borrow().deref().clone(); - let row_accessed = match indexable_vec.get(row) { - Some(row_accessed) => row_accessed, - None => unreachable!(), // raise diag - }; - - if let Op::Vector(row_accessed_vector) = row_accessed.borrow().deref() { - let row_accessed_vec = - row_accessed_vector.children().borrow().deref().clone(); - let child_accessed = match row_accessed_vec.get(col) { - Some(child_accessed) => child_accessed, - None => unreachable!(), // raise diag - }; - updated_accessor = Some(child_accessed.clone()); - } else { - unreachable!(); // raise diag - }; - }; - } - - AccessType::Slice(_range_expr) => { - unreachable!(); // Slices are not scalar, raise diag - } - } - } - } - - Ok(updated_accessor) - } - - fn compute_iterator_len(iterator: Link) -> usize { - match iterator.borrow().deref() { - Op::Vector(vector) => vector.size, - 
Op::Matrix(matrix) => matrix.size, - Op::Accessor(accessor) => match &accessor.access_type { - AccessType::Default => Self::compute_iterator_len(accessor.indexable.clone()), - AccessType::Slice(range_expr) => range_expr.to_slice_range().count(), - AccessType::Index(_) => match accessor.indexable.borrow().deref() { - Op::Vector(_) => 1, - Op::Matrix(matrix) => { - let children = matrix.children().borrow().deref().clone(); - match children.first() { - Some(first_row) => match first_row.as_vector() { - Some(row_vector) => row_vector.size, - _ => unreachable!(), // Raise diag - }, - None => { - unreachable!(); // Raise diag - } - } - } - _ => unreachable!(), // Raise diag - }, - AccessType::Matrix(_, _) => 1, - }, - Op::Parameter(parameter) => match parameter.ty { - MirType::Felt => 1, - MirType::Vector(l) => l, - MirType::Matrix(l, _) => l, - }, - _ => 1, - } - } - - fn visit_for_bis( - &mut self, - _graph: &mut Graph, - for_node: Link, - ) -> Result>, CompileError> { - let updated_for; - - { - // For each value produced by the iterators, we need to: - // - Duplicate the body - // - Visit the body and replace the Variables with the value (with the correct index depending on the binding) - // If there is a selector, we need to enforce the selector on the body through an if node ? 
- - let for_node_clone = for_node.clone(); - let for_ref = for_node_clone.as_for().unwrap(); - let iterators_ref = for_ref.iterators.borrow(); - let iterators = iterators_ref.deref(); - let expr = for_ref.expr.clone(); - let selector = for_ref.selector.clone(); - - // Check iterator lengths - if iterators.is_empty() { - unreachable!(); // Raise diag - } - let iterator_expected_len = Self::compute_iterator_len(iterators[0].clone()); - - for iterator in iterators.iter().skip(1) { - let iterator_len = Self::compute_iterator_len(iterator.clone()); - if iterator_len != iterator_expected_len { - unreachable!() - // Raise diag - } - } - - let mut new_vec = vec![]; - - for i in 0..iterator_expected_len { - let new_node = - Parameter::create(i, MirType::Felt, for_node.as_for().unwrap().deref().span()); - new_vec.push(new_node.clone()); - - let iterators_i = iterators - .iter() - .map(|op| { - match op.borrow().deref() { - Op::Vector(vector) => { - let children = vector.children().borrow().deref().clone(); - children[i].clone() - } - Op::Matrix(matrix) => { - let children = matrix.children().borrow().deref().clone(); - children[i].clone() - } - Op::Accessor(accessor) => { - match accessor.indexable.borrow().deref() { - // If we access an outer loop parameter in the body of an inner loop, - // we need to create an Accessor for the correct index in this parameter - Op::Parameter(_parameter) => Accessor::create( - accessor.indexable.clone(), - AccessType::Index(i), - 0, - accessor.span(), - ), - _ => op.clone(), - } - } - _ => op.clone(), - } - }) - .collect::>(); - let selector = if let Op::None(_) = selector.borrow().deref() { - None - } else { - Some(selector.clone()) - }; - - self.bodies_to_inline.push(( - new_node.clone(), - ForInliningContext { - body: expr.clone(), - iterators: iterators_i, - selector, - ref_node: for_node.clone(), - }, - )); - } - - let new_vec_op = Vector::create(new_vec.clone(), for_node.span()); - for param in new_vec { - param - .as_parameter_mut() 
- .unwrap() - .set_ref_node(new_vec_op.as_owner().unwrap()); - } - updated_for = Some(new_vec_op); - } - - Ok(updated_for) - } - - fn visit_call_bis( - &mut self, - _graph: &mut Graph, - _call: Link, - ) -> Result>, CompileError> { - unreachable!("Calls should have been inlined before this pass"); - } - - fn visit_vector_bis( - &mut self, - _graph: &mut Graph, - vector: Link, - ) -> Result>, CompileError> { - let mut updated_vector = None; - - { - // safe to unwrap because we just dispatched on it - let vector_ref = vector.as_vector().unwrap(); - let children = vector_ref.elements.borrow().deref().clone(); - let size = vector_ref.size; - - if size == 1 { - let child = children.first().unwrap(); - updated_vector = Some(child.clone()); - } - } - - Ok(updated_vector) - //Ok(None) - } - fn visit_matrix_bis( - &mut self, - _graph: &mut Graph, - _matrix: Link, - ) -> Result>, CompileError> { - Ok(None) - } -} - -impl Visitor for UnrollingFirstPass<'_> { - fn work_stack(&mut self) -> &mut Vec> { - &mut self.work_stack - } - // We visit all boundary constraints and all integrity constraints - // No need to visit the functions or evaluators, as they should have been inlined before this pass - fn root_nodes_to_visit(&self, graph: &Graph) -> Vec> { - let boundary_constraints_roots_ref = graph.boundary_constraints_roots.borrow(); - let integrity_constraints_roots_ref = graph.integrity_constraints_roots.borrow(); - let bus_roots: Vec<_> = graph - .buses - .values() - .flat_map(|b| b.borrow().clone().columns.into_iter().collect::>()) - .collect(); - let combined_roots = boundary_constraints_roots_ref - .clone() - .into_iter() - .map(|bc| bc.as_node()) - .chain( - integrity_constraints_roots_ref - .clone() - .into_iter() - .map(|ic| ic.as_node()), - ) - .chain(bus_roots.into_iter().map(|b| b.as_node())); - combined_roots.collect() - } - - fn visit_node(&mut self, graph: &mut Graph, node: Link) -> Result<(), CompileError> { - // We keep a reference to all For nodes to avoid 
dropping the backlinks stored in Parameters - if let Some(owner) = node.clone().as_owner() { - if let Some(op) = owner.clone().as_op() { - if let Some(_for_node) = op.as_for() { - self.all_for_nodes - .insert(op.get_ptr(), (op.clone(), owner.clone())); - } - } - } - - let updated_op: Result>, CompileError> = match node.borrow().deref() { - Node::Function(_f) => { - unreachable!("Functions should have been inlined before this pass") - } - Node::Evaluator(_e) => { - unreachable!("Evaluators should have been inlined before this pass") - } - Node::Enf(e) => to_link_and(e.clone(), graph, |g, el| self.visit_enf_bis(g, el)), - Node::Boundary(b) => { - to_link_and(b.clone(), graph, |g, el| self.visit_boundary_bis(g, el)) - } - Node::Add(a) => to_link_and(a.clone(), graph, |g, el| self.visit_add_bis(g, el)), - Node::Sub(s) => to_link_and(s.clone(), graph, |g, el| self.visit_sub_bis(g, el)), - Node::Mul(m) => to_link_and(m.clone(), graph, |g, el| self.visit_mul_bis(g, el)), - Node::Exp(e) => to_link_and(e.clone(), graph, |g, el| self.visit_exp_bis(g, el)), - Node::If(i) => to_link_and(i.clone(), graph, |g, el| self.visit_if_bis(g, el)), - Node::For(f) => to_link_and(f.clone(), graph, |g, el| self.visit_for_bis(g, el)), - Node::Call(c) => to_link_and(c.clone(), graph, |g, el| self.visit_call_bis(g, el)), - Node::Fold(f) => to_link_and(f.clone(), graph, |g, el| self.visit_fold_bis(g, el)), - Node::Vector(v) => to_link_and(v.clone(), graph, |g, el| self.visit_vector_bis(g, el)), - Node::Matrix(m) => to_link_and(m.clone(), graph, |g, el| self.visit_matrix_bis(g, el)), - Node::Accessor(a) => { - to_link_and(a.clone(), graph, |g, el| self.visit_accessor_bis(g, el)) - } - Node::BusOp(_b) => Ok(None), - Node::Parameter(p) => { - to_link_and(p.clone(), graph, |g, el| self.visit_parameter_bis(g, el)) - } - Node::Value(v) => to_link_and(v.clone(), graph, |g, el| self.visit_value_bis(g, el)), - Node::None(_) => Ok(None), - }; - - // We update the node if needed - if let Some(updated_op) 
= updated_op? { - node.as_op().unwrap().set(&updated_op); - } - - Ok(()) - } -} - -impl Visitor for UnrollingSecondPass<'_> { - fn work_stack(&mut self) -> &mut Vec> { - &mut self.work_stack - } - // The root nodes visited during the second pass are the children of For nodes to inline - fn root_nodes_to_visit(&self, _graph: &Graph) -> Vec> { - self.bodies_to_inline - .iter() - .map(|(k, _v)| k) - .cloned() - .map(|op| op.as_node()) - .collect::>() - } - fn run(&mut self, graph: &mut Graph) -> Result<(), CompileError> { - for root in self.root_nodes_to_visit(graph).iter() { - // Set context to inline the body for this index - let for_inlining_context = self.bodies_to_inline.iter().find_map(|(node, context)| { - if Rc::ptr_eq(&node.clone().as_node().link, &root.link) { - Some(context.clone()) - } else { - None - } - }); - - self.for_inlining_context = for_inlining_context; - // We inline a new body, so we clear the nodes to replace and the parameters for the ref node - self.nodes_to_replace.clear(); - self.params_for_ref_node.clear(); - - self.scan_node( - graph, - self.for_inlining_context.clone().unwrap().body.as_node(), - )?; - while let Some(node) = self.work_stack().pop() { - self.visit_node(graph, node.clone())?; - } - - // We have finished inlining the body, we can now replace the Root node with the body - let body = self.for_inlining_context.clone().unwrap().body; - let new_node = self - .nodes_to_replace - .get(&body.get_ptr()) - .unwrap() - .1 - .clone(); - - // If there is a selector, we need to enforce it on the body - let new_node_with_selector_if_needed = - if let Some(selector) = self.for_inlining_context.clone().unwrap().selector { - let zero_node = Value::create(SpannedMirValue { - span: Default::default(), - value: MirValue::Constant(ConstantValue::Felt(0)), - }); - // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> Enf(Sub(lhs, rhs) == 0), - // but it introduces an unnecessary zero node - Sub::create( - Mul::create(selector, 
new_node, root.span()), - zero_node, - root.span(), - ) - } else { - new_node - }; - - root.as_op().unwrap().set(&new_node_with_selector_if_needed); - - // Reset context to None - self.for_inlining_context = None; - } - - Ok(()) - } - fn visit_node(&mut self, _graph: &mut Graph, node: Link) -> Result<(), CompileError> { - if node.is_stale() { - return Ok(()); - } - if let Some(op) = node.clone().as_op() { - duplicate_node_or_replace( - &mut self.nodes_to_replace, - op, - self.for_inlining_context.clone().unwrap().iterators.clone(), - self.for_inlining_context - .clone() - .unwrap() - .ref_node - .as_node(), - Some( - self.all_for_nodes - .get( - &self - .for_inlining_context - .clone() - .unwrap() - .ref_node - .get_ptr(), - ) - .unwrap() - .1 - .clone(), - ), - &mut self.params_for_ref_node, - ); - } else { - unreachable!( - "UnrollingSecondPass::visit_node on a non-Op node: {:?}", - node - ); - } - Ok(()) - } -} - -fn to_link_and( - back: BackLink, - graph: &mut Graph, - f: F, -) -> Result>, CompileError> -where - F: FnOnce(&mut Graph, Link) -> Result>, CompileError>, -{ - if let Some(op) = back.to_link() { - f(graph, op) - } else { - Ok(None) - } -} diff --git a/mir/src/passes/unrolling/match_optimizer.rs b/mir/src/passes/unrolling/match_optimizer.rs new file mode 100644 index 000000000..5c2cafcfb --- /dev/null +++ b/mir/src/passes/unrolling/match_optimizer.rs @@ -0,0 +1,293 @@ +//! This module provides functionality for optimizing match expressions (`If` nodes in the MIR) +//! +//! A given `If` node contains match arms in the form of `(condition, expr)` pairs, where `expr` can +//! be one or more constraints that should be enforced on the associated condition. The goal of this +//! module is to optimize the node by: +//! - Removing duplicate constraints on the same condition, +//! - Combining `(condition, expr)` and `(condition2, expr2)` with disjoint conditions into a single +//! constraint `condition * expr + condition2 * expr2` +//! 
- Factorizing equivalent constraints (e.g. `(condition, expr)` and `(condition2, expr)`) in two +//! separate match arms into a single constraint `(condition1 + condition2) * expr1` +//! +//! The methodology to achieve these three goals is to rely on random evaluation of nodes to +//! identify equivalent constraints: +//! 1. We evaluate each constraint in the node at random points +//! 2. We group main constraints that evaluate to the same value +//! 3. We iterate over these groups (from biggest to smallest), and build a constraint by combining +//! groups that have disjoint conditions. For instance, if we have the evaluations `eval_1 = [ +//! (s0, A), (s1, B) ]` and `eval_2 = [ (s2, C) ]`, we can combine them into a single constraint +//! `(s0 + s1) * A + s2 * C` because the sets {s0, s1} and {s2} are disjoint. Note that `B` is +//! not included because it is equivalent to `A`, so we can remove the redundancy. +//! 4. We add bus-related constraints to the produced constraints, as they cannot be handled in the +//! same way as main constraints. + +use std::{ + collections::{BTreeMap, HashMap}, + ops::Deref, +}; + +use miden_diagnostics::{SourceSpan, Spanned}; + +use crate::{ + CompileError, + ir::{ + Add, ConstantValue, Enf, Link, MatchArm, MirValue, Mul, Op, Parent, QuadFelt, RandomInputs, + SpannedMirValue, Sub, Value, + }, + passes::duplicate_node, +}; + +/// A map used to keep track of the evaluation of each constraint: +/// The key is an index for the evaluation +/// The value is a vector of each pair `(condition, constraint)` where the constraint evaluates to +/// the given evaluation. 
+/// +/// For example, for an `If` node with two match arms `(s0, [A, B])` and `(s1, C)`, if A and C are +/// two constraints evaluating to the same value, we will construct the map { +/// 0: [(s0, A), (s1, C)], +/// 1: [(s0, B)] +/// } +type ConstraintEvaluationMap = BTreeMap, Link)>>; + +/// This struct provides methods used to combine and optimize constraints contained in match +/// statements (corresponding to `If` nodes in the MIR) +#[derive(Debug)] +pub struct MatchOptimizer<'a> { + // A reference to the random inputs used to evaluate the constraints + // (common for all matches) + random_inputs: &'a mut RandomInputs, + // Used to keep track of the evaluations values we've seen so far + // in the current `If` node. + node_evals: Vec, + // Used to keep track of the evaluation of each constraint + // We use indices to the node_evals vector as keys, to avoid non-determinism for + // iterating across random evaluations. + constraints_evaluation_indices: ConstraintEvaluationMap, +} + +impl<'a> MatchOptimizer<'a> { + /// Instantiates a new `MatchOptimizer` with the given `RandomInputs`. + pub fn new(random_inputs: &'a mut RandomInputs) -> Self { + Self { + random_inputs, + node_evals: Vec::new(), + constraints_evaluation_indices: ConstraintEvaluationMap::new(), + } + } + + /// For a given match arm: + /// - splits main and bus-related constraints + /// - returns bus-related constraints (without the wrapping `Op:Enf` for simplicity) + /// - evaluates the main constraints at random points, and groups constraints that evaluate to + /// the same value in `constraints_evaluation_indices` + pub fn evaluate_match_arm( + &mut self, + match_arm: &MatchArm, + ) -> Result>, CompileError> { + let condition = match_arm.condition.clone(); + let expr = match_arm.expr.clone(); + + // 1.1. 
Get all the individual constraints corresponding to this arm + let all_constraints = flatten_constraints(&expr); + + let mut constraints_to_eval = Vec::new(); + let mut bus_related_constraints_for_match_arm = Vec::new(); + + // 1.2. Filter out all BusOp nodes from the constraints, we will handle them separately + // Bus operations can only be found as an `Op::BusOp` wrapped in an `Op::Enf`, for + // simplicity we unwrap it here and re-wrap it in `gather_all_constraints`. + for constraint in all_constraints { + match constraint.borrow().deref() { + Op::Enf(enf) => match enf.expr.borrow().deref() { + Op::BusOp(_) => bus_related_constraints_for_match_arm.push(enf.expr.clone()), + _ => constraints_to_eval.push(constraint.clone()), + }, + Op::BusOp(_) => { + unreachable!("Error: A BusOp was found outside of an Enf node in a match arm") + }, + _ => constraints_to_eval.push(constraint.clone()), + } + } + + // 1.3. Evaluate all the other constraints at random points + for constraint in constraints_to_eval { + let eval = self.random_inputs.eval(constraint.clone())?; + // Check if we already have this eval in our list + match self.node_evals.iter().position(|e| e == &eval) { + Some(index) => { + // If the eval is already in the list, we use its index in the + // node_evals vec + let constraints_vec = + self.constraints_evaluation_indices.get_mut(&index).expect( + "Error: An evaluation index was not in constraints_evaluation_indices", + ); + + // We add the constraint to the corresponding vector only if the + // condition is not already present + // to remove duplicate constraints for the same selector + if !constraints_vec.iter().any(|(c, _)| c == &condition) { + constraints_vec.push((condition.clone(), constraint.clone())) + } + }, + None => { + // Otherwise, we add it to the list and use its index + self.node_evals.push(eval); + let index = self.node_evals.len() - 1; + if self + .constraints_evaluation_indices + .insert(index, vec![(condition.clone(), constraint.clone())]) 
+ .is_some() + { + unreachable!( + "Error: A new evaluation index was found in constraints_evaluation_indices" + ); + } + }, + } + } + Ok(bus_related_constraints_for_match_arm) + } + + /// For each evaluation, computes the number of constraints (with different selectors) + /// evaluating to it. + /// Returns eval_lens: BTreeMap> + fn compute_eval_lens(&self) -> BTreeMap> { + let mut eval_lens = BTreeMap::new(); + for (eval_index, constraints_vec) in self.constraints_evaluation_indices.iter() { + eval_lens + .entry(constraints_vec.len()) + .and_modify(|v: &mut Vec<_>| v.push(*eval_index)) + .or_insert(vec![*eval_index]); + } + eval_lens + } + + /// Combines the constraints in an optimized way, based on their evaluations. + pub fn reduce_main_constraints(&mut self, span: SourceSpan) -> Vec> { + let mut new_vec = vec![]; + + let mut eval_lens = self.compute_eval_lens(); + + // Note: this will always terminate as we always remove at least one constraint per + // iteration + while !self.constraints_evaluation_indices.is_empty() { + let mut new_constraint = vec![]; + let mut taken_eval_indices = vec![]; + + // Start picking constraints from the biggest sets that evaluate to the same values + for eval_index in eval_lens.values().cloned().rev().flatten() { + let constraints = self.constraints_evaluation_indices.get(&eval_index).unwrap(); + + if have_disjoint_conditions(new_constraint.clone(), constraints) { + // If the new constraints are disjoint from the current ones, we can add + // them + new_constraint.push(constraints.clone()); + taken_eval_indices.push(eval_index); + } + } + + // Remove the picked evaluation indices from all structures + for eval_index in taken_eval_indices { + self.constraints_evaluation_indices.remove(&eval_index); + eval_lens.iter_mut().for_each(|(_len, evals)| { + evals.retain(|&e_idx| e_idx != eval_index); + }); + eval_lens.retain(|_, constraints| !constraints.is_empty()); + } + + // Create combined constraint + let mut cur_node = None; + for 
equivalent_constraints in new_constraint { + let mut cur_condition = None; + for (condition, _) in equivalent_constraints.clone() { + if cur_condition.is_none() { + cur_condition = Some(condition.clone()); + } else { + cur_condition = + Some(Add::create(cur_condition.unwrap(), condition.clone(), span)); + } + } + let new_node = Mul::create( + cur_condition.unwrap(), + equivalent_constraints.first().unwrap().1.clone(), + span, + ); + + cur_node = match cur_node { + Some(existing) => Some(Add::create(existing, new_node, span)), + None => Some(new_node), + }; + } + + // Note: This will ensure the resulting constraint is in the form `Sub(x,y)` + // representing `enf x = y` + let zero_node = Value::create(SpannedMirValue { + span: Default::default(), + value: MirValue::Constant(ConstantValue::Felt(0)), + }); + // The following unwrap is safe as we always have at least one constraint above + let new_node_with_sub_zero = Sub::create(cur_node.unwrap(), zero_node, span); + + new_vec.push(new_node_with_sub_zero); + } + new_vec + } + + /// Adds bus-related constraints to the resulting constraints vector. + /// + /// Note: Given a slice of bus-related constraints in the form of `(condition, constraints)`, we + /// update the latch of each constraint to take into account the condition before adding the + /// resulting constraint into the final vector. 
+ pub fn gather_all_constraints( + bus_related_constraints: &mut [(Link, Vec>)], + combined_main_constraints: Vec>, + span: SourceSpan, + ) -> Vec> { + let mut all_constraints = combined_main_constraints; + for (condition, constraints) in bus_related_constraints.iter_mut() { + for constraint in constraints.iter_mut() { + let cur_latch = constraint.as_bus_op().unwrap().latch.clone(); + let new_latch = Mul::create( + condition.clone(), + duplicate_node(cur_latch, &mut HashMap::new()), + span, + ); + constraint + .as_bus_op_mut() + .unwrap() + .latch + .borrow_mut() + .clone_from(&new_latch.borrow()); + let enf_constraint = Enf::create(constraint.clone(), constraint.span()); + all_constraints.push(enf_constraint); + } + } + all_constraints + } +} + +/// Helper function used to check whether two sets of constraints have disjoint conditions. +/// This is used to combine constraints in the `visit_if_bis` method, as we can only combine +/// constraints that have disjoint selectors. +fn have_disjoint_conditions( + cur_constraints: Vec, Link)>>, + new_constraints: &[(Link, Link)], +) -> bool { + for (condition, _) in new_constraints { + if cur_constraints.iter().flatten().any(|(c, _)| c == condition) { + return false; // Duplicate condition found + } + } + true +} + +/// Flattens a constraint expression (potentially containing nested vectors) into a vector of +/// individual constraints. 
+fn flatten_constraints(op: &Link) -> Vec> { + match op.borrow().deref() { + Op::Vector(vec) => { + vec.children().borrow().deref().iter().flat_map(flatten_constraints).collect() + }, + _ => vec![op.clone()], + } +} diff --git a/mir/src/passes/unrolling/mod.rs b/mir/src/passes/unrolling/mod.rs new file mode 100644 index 000000000..183b1fe86 --- /dev/null +++ b/mir/src/passes/unrolling/mod.rs @@ -0,0 +1,219 @@ +use std::ops::Deref; + +use air_pass::Pass; +use miden_diagnostics::{DiagnosticsHandler, SourceSpan, Spanned}; + +use super::visitor::Visitor; +use crate::{CompileError, ir::*}; + +mod match_optimizer; +mod unrolling_first_pass; +mod unrolling_second_pass; +mod unrolling_third_pass; + +use unrolling_first_pass::UnrollingFirstPass; +use unrolling_second_pass::UnrollingSecondPass; +use unrolling_third_pass::UnrollingThirdPass; + +/// This pass follows a similar approach as the Inlining pass and requires that the latter has +/// already been done. +/// +/// * In the first step, we visit the graph, unrolling each node type except `For` nodes. Instead, +/// for these node types we gather the context to inline them in the second pass. In this first +/// pass, we also optimize constraints found in match statements. +/// * In the second pass, we inline the bodies of `For` nodes. 
+pub struct Unrolling<'a> { + diagnostics: &'a DiagnosticsHandler, +} + +impl<'a> Unrolling<'a> { + pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { + Self { diagnostics } + } +} + +/// This structure is used to keep track of what is needed to inline a For node +#[derive(Clone, Debug)] +pub struct ForInliningContext { + body: Link, + iterators: Vec>, + selector: Option>, + ref_node: Link, +} + +impl Pass for Unrolling<'_> { + type Input<'a> = Mir; + type Output<'a> = Mir; + type Error = CompileError; + + fn run<'a>(&mut self, mut ir: Self::Input<'a>) -> Result, Self::Error> { + // The first pass unrolls all nodes fully, except for: + // - `For` nodes + // - `If` nodes and their parents + let mut first_pass = UnrollingFirstPass::new(self.diagnostics); + Visitor::run(&mut first_pass, ir.constraint_graph_mut())?; + + // The second pass actually inlines the `For` nodes + let mut second_pass = UnrollingSecondPass::new( + self.diagnostics, + first_pass.bodies_to_inline.clone(), + first_pass.all_for_nodes.clone(), + ); + Visitor::run(&mut second_pass, ir.constraint_graph_mut())?; + + // The third pass unrolls all the remaining nodes (`If` nodes and their parents) + let mut third_pass = UnrollingThirdPass::new(self.diagnostics); + Visitor::run(&mut third_pass, ir.constraint_graph_mut())?; + Ok(ir) + } +} + +/// Unrolls an `Enf` on vectors into a `Vector`. +pub fn visit_enf_bis(enf: Link) -> Result>, CompileError> { + let enf_ref = enf.as_enf().unwrap(); + let expr = enf_ref.expr.clone(); + if let Op::Vector(vec) = expr.borrow().deref() { + let ops = vec.children().borrow().clone(); + let new_vec = ops.iter().map(|op| Enf::create(op.clone(), enf_ref.span())).collect(); + return Ok(Some(Vector::create(new_vec, enf_ref.span()))); + } + Ok(None) +} + +/// Unrolls a `Value` if it represents a constant vector or matrix, or a trace access binding. 
+pub fn visit_value_bis(value: Link) -> Result>, CompileError> { + // safe to unwrap because we just dispatched on it + let value_ref = value.as_value().unwrap(); + let mir_value = value_ref.value.value.clone(); + match &mir_value { + MirValue::Constant(c) => match c { + ConstantValue::Felt(_) => {}, + ConstantValue::Vector(v) => { + return Ok(Some(unroll_constant_vector(v, value_ref.span()))); + }, + ConstantValue::Matrix(m) => { + return Ok(Some(unroll_constant_matrix(m, value_ref.span()))); + }, + }, + MirValue::TraceAccessBinding(trace_access_binding) => { + return Ok(Some(unroll_trace_access_binding(trace_access_binding, value_ref.span()))); + }, + MirValue::TraceAccess(_) + | MirValue::PeriodicColumn(_) + | MirValue::PublicInput(_) + | MirValue::PublicInputTable(_) + | MirValue::RandomValue(_) + | MirValue::BusAccess(_) + | MirValue::Null + | MirValue::Unconstrained => {}, + } + Ok(None) +} + +/// Unrolls a `Fold`. We replace the `Fold` node by a series of binary operations (e.g. `Add` or +/// `Mul`) applied to the initial value and each element of the iterator, depending on the +/// `FoldOperator`. 
+pub fn visit_fold_bis(fold: Link) -> Result>, CompileError> { + let fold_ref = fold.as_fold().unwrap(); + let iterator = fold_ref.iterator.clone(); + let operator = fold_ref.operator.clone(); + let initial_value = fold_ref.initial_value.clone(); + let iterator_ref = iterator.borrow(); + let Op::Vector(iterator_vector) = iterator_ref.deref() else { + unreachable!("Expected vector iterator in fold, found: {:?}", iterator_ref); + }; + let iterator_nodes = iterator_vector.children().borrow().clone(); + let resulting_node = + iterator_nodes.iter().fold(initial_value, |acc_node, node| match operator { + FoldOperator::Add => Add::create(acc_node, node.clone(), fold_ref.span()), + FoldOperator::Mul => Mul::create(acc_node, node.clone(), fold_ref.span()), + FoldOperator::None => { + unreachable!("Unexpected unrolling of Fold with None FoldOperator") + }, + }); + Ok(Some(resulting_node)) +} + +/// Unrolls a `Vector`. As vectors are already unrolled, we only need to replace vectors of size 1 +/// by their only child (i.e. a scalar). +pub fn visit_vector_bis(vector: Link) -> Result>, CompileError> { + // safe to unwrap because we just dispatched on it + let vector_ref = vector.as_vector().unwrap(); + let children = vector_ref.elements.borrow().clone(); + let size = vector_ref.size; + + // If the vector is of size 1, it is a scalar and we replace it by its only child + if size == 1 { + let child = children.first().unwrap(); + return Ok(Some(child.clone())); + } + // Otherwise, it is already in its unrolled form, we do nothing + Ok(None) +} + +// PRIVATE HELPER FUNCTIONS +// ================================================================================================ + +/// Unrolls a trace access binding into either: +/// - a TraceAccess if it is of size 1 +/// - or a Vector otherwise. 
+fn unroll_trace_access_binding( + trace_access_binding: &TraceAccessBinding, + span: SourceSpan, +) -> Link { + if trace_access_binding.size == 1 { + Value::create(SpannedMirValue { + span, + value: MirValue::TraceAccess(TraceAccess { + segment: trace_access_binding.segment, + column: trace_access_binding.offset, + row_offset: 0, + }), + }) + } else { + let mut vec = vec![]; + for index in 0..trace_access_binding.size { + let val = Value::create(SpannedMirValue { + span, + value: MirValue::TraceAccess(TraceAccess { + segment: trace_access_binding.segment, + column: trace_access_binding.offset + index, + row_offset: 0, + }), + }); + vec.push(val); + } + Vector::create(vec, span) + } +} + +/// Unrolls a constant vector into a Vector` +fn unroll_constant_vector(constant_vector: &Vec, span: SourceSpan) -> Link { + let mut vec = vec![]; + for val in constant_vector { + let val = Value::create(SpannedMirValue { + span, + value: MirValue::Constant(ConstantValue::Felt(*val)), + }); + vec.push(val); + } + Vector::create(vec, span) +} + +/// Unrolls a constant matrix into a `Matrix>` +fn unroll_constant_matrix(constant_matrix: &Vec>, span: SourceSpan) -> Link { + let mut res_m = vec![]; + for row in constant_matrix { + let mut res_row = vec![]; + for val in row { + let val = Value::create(SpannedMirValue { + span, + value: MirValue::Constant(ConstantValue::Felt(*val)), + }); + res_row.push(val); + } + let res_row_vec = Vector::create(res_row, span); + res_m.push(res_row_vec); + } + Matrix::create(res_m, span) +} diff --git a/mir/src/passes/unrolling/unrolling_first_pass.rs b/mir/src/passes/unrolling/unrolling_first_pass.rs new file mode 100644 index 000000000..567f62d3d --- /dev/null +++ b/mir/src/passes/unrolling/unrolling_first_pass.rs @@ -0,0 +1,287 @@ +use std::{collections::HashMap, ops::Deref}; + +use miden_diagnostics::{DiagnosticsHandler, Spanned}; + +use crate::{ + CompileError, + ir::{ + Accessor, ConstantValue, Graph, Link, MirAccessType, MirType, MirValue, Node, 
Op, Owner, + Parameter, Parent, SpannedMirValue, Value, Vector, + }, + passes::{ + Visitor, handle_accessor_visit, + unrolling::{ + ForInliningContext, visit_enf_bis, visit_fold_bis, visit_value_bis, visit_vector_bis, + }, + }, +}; + +pub struct UnrollingFirstPass<'a> { + #[allow(unused)] + diagnostics: &'a DiagnosticsHandler, + + // general context + work_stack: Vec>, + // For each child of a For node encountered, we store the context to inline it in the second + // pass + pub bodies_to_inline: Vec<(Link, ForInliningContext)>, + // We keep track of all parameters referencing a given For node + params_for_ref_node: HashMap>>, + // We keep a reference to For nodes in order to avoid the backlinks stored in Parameters + // referencing them to be dropped + pub all_for_nodes: HashMap, Link)>, +} + +impl<'a> UnrollingFirstPass<'a> { + pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { + Self { + diagnostics, + work_stack: vec![], + bodies_to_inline: vec![], + params_for_ref_node: HashMap::new(), + all_for_nodes: HashMap::new(), + } + } +} + +// For the first pass of Unrolling, we use a tweaked version of the Visitor trait, +// each visit_*_bis function returns an `Option>` instead of `Result<(), CompileError>`, +// to mutate the nodes (e.g. 
modifying an `Operation` to `Vector`) +impl UnrollingFirstPass<'_> { + fn visit_parameter_bis( + &mut self, + parameter: Link, + ) -> Result>, CompileError> { + // FIXME: Just check that the parameter is a scalar, raise diag otherwise + // List comprehension bodies should only be scalar expressions + + let owner_ref = + parameter.as_parameter().unwrap().ref_node.to_link().expect("Invalid Ref node"); + + self.params_for_ref_node + .entry(owner_ref.get_ptr()) + .or_default() + .push(parameter.clone()); + Ok(None) + } + + fn visit_accessor_bis(&mut self, accessor: Link) -> Result>, CompileError> { + let accessor_ref = accessor.as_accessor().unwrap(); + let indexable = accessor_ref.indexable.clone(); + if indexable.clone().as_parameter().is_none() { + handle_accessor_visit(accessor.clone(), false, self.diagnostics) + } else { + // We keep accessors wrapping parameters to allow for nested list comprehensions. + Ok(None) + } + } + + fn visit_for_bis(&mut self, for_node: Link) -> Result>, CompileError> { + // For each value produced by the iterators, we need to: + // - Duplicate the body + // - Visit the body and replace the Variables with the value (with the correct index + // depending on the binding) + // If there is a selector, we need to enforce the selector on the body + + let for_node_clone = for_node.clone(); + let for_ref = for_node_clone.as_for().unwrap(); + let iterators_ref = for_ref.iterators.borrow(); + let iterators = iterators_ref.deref(); + let expr = for_ref.expr.clone(); + let selector = for_ref.selector.clone(); + + let iterator_expected_len = validate_iterators_and_get_expected_len(iterators); + + let mut new_vec = vec![]; + for i in 0..iterator_expected_len { + let new_node = + Parameter::create(i, MirType::Felt, for_node.as_for().unwrap().deref().span()); + new_vec.push(new_node.clone()); + + let iterators_i = iterators + .iter() + .map(|iterator| get_iterator_child(iterator.clone(), i)) + .collect::>(); + let selector = if let Op::None(_) = 
selector.borrow().deref() { + None + } else { + Some(selector.clone()) + }; + + self.bodies_to_inline.push(( + new_node.clone(), + ForInliningContext { + body: expr.clone(), + iterators: iterators_i, + selector, + ref_node: for_node.clone(), + }, + )); + } + + let new_vec_op = Vector::create(new_vec.clone(), for_node.span()); + for param in new_vec { + param.as_parameter_mut().unwrap().set_ref_node(new_vec_op.as_owner().unwrap()); + } + Ok(Some(new_vec_op)) + } +} + +impl Visitor for UnrollingFirstPass<'_> { + fn work_stack(&mut self) -> &mut Vec> { + &mut self.work_stack + } + + // We visit all boundary constraints and all integrity constraints + // No need to visit the functions or evaluators, as they should have been inlined before this + // pass + fn root_nodes_to_visit(&self, graph: &Graph) -> Vec> { + let boundary_constraints_roots_ref = graph.boundary_constraints_roots.borrow(); + let integrity_constraints_roots_ref = graph.integrity_constraints_roots.borrow(); + let bus_roots: Vec<_> = graph + .buses + .values() + .flat_map(|b| b.borrow().clone().columns.into_iter().collect::>()) + .collect(); + let combined_roots = boundary_constraints_roots_ref + .clone() + .into_iter() + .map(|bc| bc.as_node()) + .chain(integrity_constraints_roots_ref.clone().into_iter().map(|ic| ic.as_node())) + .chain(bus_roots.into_iter().map(|b| b.as_node())); + combined_roots.collect() + } + + fn visit_node(&mut self, _graph: &mut Graph, node: Link) -> Result<(), CompileError> { + // We keep a reference to all `For` nodes to avoid dropping the backlinks stored in + // `Parameters` + if let Some(owner) = node.clone().as_owner() + && let Some(op) = owner.clone().as_op() + && let Some(_for_node) = op.as_for() + { + self.all_for_nodes.insert(op.get_ptr(), (op.clone(), owner.clone())); + } + + // In this pass, we both need to dispatch the visitor depending on the node type, + // and also mutate the node if needed. 
We implement custom visit_*_bis methods + // that returns a Some(updated_node) if we need to update the node's value. + let updated_op: Option> = match node.borrow().deref() { + Node::Enf(e) => e.to_link().map_or(Ok(None), visit_enf_bis)?, + Node::Fold(f) => f.to_link().map_or(Ok(None), visit_fold_bis)?, + Node::Vector(v) => v.to_link().map_or(Ok(None), visit_vector_bis)?, + Node::Accessor(a) => a.to_link().map_or(Ok(None), |el| self.visit_accessor_bis(el))?, + Node::Value(v) => v.to_link().map_or(Ok(None), visit_value_bis)?, + Node::For(f) => f.to_link().map_or(Ok(None), |el| self.visit_for_bis(el))?, + Node::Parameter(p) => { + p.to_link().map_or(Ok(None), |el| self.visit_parameter_bis(el))? + }, + Node::Boundary(_) + | Node::Add(_) + | Node::Sub(_) + | Node::Mul(_) + | Node::Exp(_) + | Node::BusOp(_) + | Node::Matrix(_) + | Node::If(_) + | Node::None(_) => None, + _ => { + unreachable!( + "Unexpected node during Unrolling: Function, Evaluators and Calls should have been inlined before this pass. Found: {:?}", + node + ); + }, + }; + + // We update the node if needed + if let Some(updated_op) = updated_op { + node.as_op().unwrap().set(&updated_op); + } + + Ok(()) + } +} + +// HELPERS FUNCTIONS +// ================================================================================================ + +/// Sanity check that all iterators have the same length. +/// Note that semantic analysis should have already checked they are valid. 
+fn validate_iterators_and_get_expected_len(iterators: &[Link]) -> usize { + if iterators.is_empty() { + unreachable!("Semantic analysis should have caught empty iterators"); + } + let iterator_expected_len = compute_iterator_len(iterators[0].clone()); + for iterator in iterators.iter().skip(1) { + let iterator_len = compute_iterator_len(iterator.clone()); + if iterator_len != iterator_expected_len { + unreachable!("Semantic analysis should have caught iterator length mismatch"); + } + } + iterator_expected_len +} + +/// Computes the length of a node that is used as an iterator in a `For` node. +fn compute_iterator_len(iterator: Link) -> usize { + match iterator.borrow().deref() { + Op::Vector(vector) => vector.size, + Op::Matrix(matrix) => matrix.size, + Op::Accessor(accessor) => match &accessor.access_type { + MirAccessType::Default => compute_iterator_len(accessor.indexable.clone()), + MirAccessType::Index(_) => match accessor.indexable.borrow().deref() { + Op::Vector(_) => 1, + Op::Matrix(matrix) => { + let children = matrix.children().borrow().clone(); + children + .first() + .expect("Unexpected empty matrix") + .as_vector() + .expect("Expected vector for matrix row") + .size + }, + _ => unreachable!("Unexpected index into non indexable type"), + }, + MirAccessType::Matrix(..) => 1, + }, + Op::Parameter(parameter) => match parameter.ty { + MirType::Felt => 1, + MirType::Vector(l) => l, + MirType::Matrix(l, _) => l, + }, + _ => 1, + } +} + +/// Returns the i-th child of an iterator node. 
+fn get_iterator_child(op: Link, i: usize) -> Link { + match op.borrow().deref() { + Op::Vector(vector) => { + let children = vector.children().borrow().clone(); + children[i].clone() + }, + Op::Matrix(matrix) => { + let children = matrix.children().borrow().clone(); + children[i].clone() + }, + Op::Accessor(accessor) => { + match accessor.indexable.borrow().deref() { + // If we access an outer loop parameter in the body of an inner + // loop, we need to create + // an Accessor for the correct index in this parameter + Op::Parameter(_parameter) => { + let mir_access_type = MirAccessType::Index(Value::create(SpannedMirValue { + span: accessor.span(), + value: MirValue::Constant(ConstantValue::Felt(i as u64)), + })); + Accessor::create( + accessor.indexable.clone(), + mir_access_type, + 0, + accessor.span(), + ) + }, + _ => op.clone(), + } + }, + _ => op.clone(), + } +} diff --git a/mir/src/passes/unrolling/unrolling_second_pass.rs b/mir/src/passes/unrolling/unrolling_second_pass.rs new file mode 100644 index 000000000..1640eb66c --- /dev/null +++ b/mir/src/passes/unrolling/unrolling_second_pass.rs @@ -0,0 +1,157 @@ +use std::{collections::HashMap, ops::Deref, rc::Rc}; + +use miden_diagnostics::{DiagnosticsHandler, Spanned}; + +use crate::{ + CompileError, + ir::{Graph, Link, Mul, Node, Op, Owner, Parent, Vector}, + passes::{Visitor, duplicate_node, duplicate_node_or_replace, unrolling::ForInliningContext}, +}; + +pub struct UnrollingSecondPass<'a> { + #[allow(unused)] + diagnostics: &'a DiagnosticsHandler, + + // general context + work_stack: Vec>, + // A list of all the children of `For` nodes to inline + bodies_to_inline: Vec<(Link, ForInliningContext)>, + // The current context for inlining a `For` node, if any + for_inlining_context: Option, + // A map of nodes to replace with their inlined version + nodes_to_replace: HashMap, Link)>, + // We keep track of all parameters referencing a given `For` node + params_for_ref_node: HashMap>>, + // We keep a reference to 
`For` nodes in order to avoid the backlinks stored in Parameters + // referencing them to be dropped + all_for_nodes: HashMap, Link)>, +} + +impl<'a> UnrollingSecondPass<'a> { + pub fn new( + diagnostics: &'a DiagnosticsHandler, + bodies_to_inline: Vec<(Link, ForInliningContext)>, + all_for_nodes: HashMap, Link)>, + ) -> Self { + Self { + diagnostics, + work_stack: vec![], + bodies_to_inline, + for_inlining_context: None, + nodes_to_replace: HashMap::new(), + params_for_ref_node: HashMap::new(), + all_for_nodes, + } + } +} + +impl Visitor for UnrollingSecondPass<'_> { + fn work_stack(&mut self) -> &mut Vec> { + &mut self.work_stack + } + + // The root nodes visited during the second pass are the children of `For` nodes to inline + fn root_nodes_to_visit(&self, _graph: &Graph) -> Vec> { + self.bodies_to_inline + .iter() + .map(|(k, _v)| k) + .cloned() + .map(|op| op.as_node()) + .collect::>() + } + + fn run(&mut self, graph: &mut Graph) -> Result<(), CompileError> { + for root in self.root_nodes_to_visit(graph).iter() { + // Set the context corresponding to the `For` node we are inlining + self.set_context(root); + + // Recursively scan the body of the `For` node to inline + self.scan_node(graph, self.for_inlining_context.clone().unwrap().body.as_node())?; + while let Some(node) = self.work_stack().pop() { + self.visit_node(graph, node.clone())?; + } + + // We have finished inlining the body, we can now replace the Root node with the body + let body = self.for_inlining_context.clone().unwrap().body; + let new_node = self.nodes_to_replace.get(&body.get_ptr()).unwrap().1.clone(); + + // If there is a selector, we need to enforce it on the body + let new_node_with_selector_if_needed = + if let Some(selector) = self.for_inlining_context.clone().unwrap().selector { + if let Op::Vector(new_node_vector) = new_node.borrow().deref() { + let new_node_vec = new_node_vector.children().borrow().deref().clone(); + let mut new_vec = vec![]; + for new_node_child in 
new_node_vec.into_iter() { + let new_node_child_with_selector = Mul::create( + duplicate_node(selector.clone(), &mut HashMap::new()), + new_node_child, + root.span(), + ); + new_vec.push(new_node_child_with_selector); + } + Vector::create(new_vec, root.span()) + } else { + Mul::create(selector, new_node, root.span()) + } + } else { + new_node + }; + + // Update the root node with the new inlined body and reset the context to None + root.as_op().unwrap().set(&new_node_with_selector_if_needed); + self.for_inlining_context = None; + } + + Ok(()) + } + + fn visit_node(&mut self, _graph: &mut Graph, node: Link) -> Result<(), CompileError> { + // Skip stale nodes + if node.is_stale() { + return Ok(()); + } + + // visit_node is called on all the nodes in the body of a `For` node, they should never be + // Root nodes + let op = node.clone().as_op().expect("UnrollingSecondPass::visit_node on a non-Op node"); + + // Will duplicate the body of the `For` node, replacing the corresponding `For` node's + // Parameters by the values taken by iterators. Other Parameters will not be + // replaced (in case of nested `For` nodes) + duplicate_node_or_replace( + &mut self.nodes_to_replace, + op, + self.for_inlining_context.clone().unwrap().iterators.clone(), + self.for_inlining_context.clone().unwrap().ref_node.as_node(), + Some( + self.all_for_nodes + .get(&self.for_inlining_context.clone().unwrap().ref_node.get_ptr()) + .unwrap() + .1 + .clone(), + ), + &mut self.params_for_ref_node, + ); + Ok(()) + } +} + +impl<'a> UnrollingSecondPass<'a> { + /// Sets the context for inlining a `For` node based on the root node. 
+ fn set_context(&mut self, root: &Link) { + // Set context to inline the body for this index + let for_inlining_context = self.bodies_to_inline.iter().find_map(|(node, context)| { + if Rc::ptr_eq(&node.clone().as_node().link, &root.link) { + Some(context.clone()) + } else { + None + } + }); + + self.for_inlining_context = for_inlining_context; + // We inline a new body, so we clear the nodes to replace and the parameters for the ref + // node + self.nodes_to_replace.clear(); + self.params_for_ref_node.clear(); + } +} diff --git a/mir/src/passes/unrolling/unrolling_third_pass.rs b/mir/src/passes/unrolling/unrolling_third_pass.rs new file mode 100644 index 000000000..1ff803a09 --- /dev/null +++ b/mir/src/passes/unrolling/unrolling_third_pass.rs @@ -0,0 +1,141 @@ +use std::ops::Deref; + +use miden_diagnostics::{DiagnosticsHandler, Spanned}; + +use crate::{ + CompileError, + ir::{Graph, Link, Node, Op, RandomInputs, Vector}, + passes::{ + Visitor, + unrolling::{ + match_optimizer::MatchOptimizer, visit_enf_bis, visit_fold_bis, visit_value_bis, + visit_vector_bis, + }, + }, +}; + +pub struct UnrollingThirdPass<'a> { + #[allow(unused)] + diagnostics: &'a DiagnosticsHandler, + + // general context + work_stack: Vec>, + // current evaluations of nodes at random points + random_inputs: RandomInputs, +} + +impl<'a> UnrollingThirdPass<'a> { + pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { + Self { + diagnostics, + work_stack: vec![], + random_inputs: RandomInputs::default(), + } + } +} + +// For the third pass of Unrolling, we use a tweaked version of the Visitor trait, +// each `visit_*_bis` function returns an `Option>` instead of `Result<(), CompileError>`, +// to mutate the nodes (e.g. modifying an `If` node to `Vector`) +impl UnrollingThirdPass<'_> { + /// Visiting an `If` node consists of evaluating all the main trace constraints contained in + /// the match arms, and combining them to optimize the resulting vector of constraints if + /// possible. 
We handle bus related constraints separately, as they cannot be combined with + /// main trace constraints. + /// + /// The methods returns a `Vector` node containing all the optimized constraints equivalent to + /// the `If` node. This means it both applies the selectors to the constraints of each match + /// arm, and combines them in optimized constraints. The documentation of the match_optimizer + /// module contains additional details on the optimization methodology. + fn visit_if_bis(&mut self, if_node: Link) -> Result>, CompileError> { + let if_ref = if_node.as_if().unwrap(); + let match_arms = if_ref.match_arms.borrow(); + + // 1. Instantiate a new MatchOptimizer to handle the constraints of this node + let mut match_optimizer = MatchOptimizer::new(&mut self.random_inputs); + + let mut bus_related_constraints = Vec::new(); + + // 2. For each match arm, gather bus-related constraints + // to be handled separately and evaluate the main constraints + for match_arm in match_arms.iter() { + let bus_related_constraints_for_match_arm = + match_optimizer.evaluate_match_arm(match_arm)?; + bus_related_constraints + .push((match_arm.condition.clone(), bus_related_constraints_for_match_arm)); + } + + // 3. Construct the new vector of combined main constraints + let combined_main_constraints = match_optimizer.reduce_main_constraints(if_ref.span); + + // 4. 
Add all the constraints that are bus-related + let new_vec = MatchOptimizer::gather_all_constraints( + &mut bus_related_constraints, + combined_main_constraints, + if_ref.span, + ); + + Ok(Some(Vector::create(new_vec, if_ref.span()))) + } +} + +impl Visitor for UnrollingThirdPass<'_> { + fn work_stack(&mut self) -> &mut Vec> { + &mut self.work_stack + } + + // We visit all boundary constraints and all integrity constraints + // No need to visit the functions or evaluators, as they should have been inlined before this + // pass + fn root_nodes_to_visit(&self, graph: &Graph) -> Vec> { + let boundary_constraints_roots_ref = graph.boundary_constraints_roots.borrow(); + let integrity_constraints_roots_ref = graph.integrity_constraints_roots.borrow(); + let bus_roots: Vec<_> = graph + .buses + .values() + .flat_map(|b| b.borrow().clone().columns.into_iter().collect::>()) + .collect(); + let combined_roots = boundary_constraints_roots_ref + .clone() + .into_iter() + .map(|bc| bc.as_node()) + .chain(integrity_constraints_roots_ref.clone().into_iter().map(|ic| ic.as_node())) + .chain(bus_roots.into_iter().map(|b| b.as_node())); + combined_roots.collect() + } + + fn visit_node(&mut self, _graph: &mut Graph, node: Link) -> Result<(), CompileError> { + // In this pass, we both need to dispatch the visitor depending on the node type, + // and also mutate the node if needed. We implement custom visit_*_bis methods + // that returns a Some(updated_node) if we need to update the node's value. 
+ let updated_op: Option> = match node.borrow().deref() { + Node::Enf(e) => e.to_link().map_or(Ok(None), visit_enf_bis)?, + Node::Fold(f) => f.to_link().map_or(Ok(None), visit_fold_bis)?, + Node::Vector(v) => v.to_link().map_or(Ok(None), visit_vector_bis)?, + Node::Value(v) => v.to_link().map_or(Ok(None), visit_value_bis)?, + Node::If(i) => i.to_link().map_or(Ok(None), |el| self.visit_if_bis(el))?, + Node::Boundary(_) + | Node::Add(_) + | Node::Sub(_) + | Node::Mul(_) + | Node::Exp(_) + | Node::BusOp(_) + | Node::Matrix(_) + | Node::Accessor(_) + | Node::None(_) => None, + _ => { + unreachable!( + "Unexpected node during Unrolling: Function, Evaluators, Calls, For nodes and Parameters should have been inlined before this pass. Found: {:?}", + node + ); + }, + }; + + // We update the node if needed + if let Some(updated_op) = updated_op { + node.as_op().unwrap().set(&updated_op); + } + + Ok(()) + } +} diff --git a/mir/src/passes/value_numbering.rs b/mir/src/passes/value_numbering.rs deleted file mode 100644 index 7fcb4838d..000000000 --- a/mir/src/passes/value_numbering.rs +++ /dev/null @@ -1,63 +0,0 @@ -use air_pass::Pass; -use miden_diagnostics::DiagnosticsHandler; - -use super::visitor::Visitor; -use crate::{ - ir::{Link, Mir, Node}, - CompileError, -}; - -/// TODO MIR: -/// If needed, implement value numbering pass on MIR -/// See https://en.wikipedia.org/wiki/Value_numbering -/// -pub struct ValueNumbering<'a> { - #[allow(unused)] - diagnostics: &'a DiagnosticsHandler, - work_stack: Vec>, -} - -impl Pass for ValueNumbering<'_> { - type Input<'a> = Mir; - type Output<'a> = Mir; - type Error = CompileError; - - fn run<'a>(&mut self, mut ir: Self::Input<'a>) -> Result, Self::Error> { - Visitor::run(self, ir.constraint_graph_mut())?; - Ok(ir) - } -} - -impl<'a> ValueNumbering<'a> { - #[allow(unused)] - pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { - Self { - diagnostics, - work_stack: vec![], - } - } -} - -impl Visitor for ValueNumbering<'_> { - fn 
work_stack(&mut self) -> &mut Vec> { - &mut self.work_stack - } - fn root_nodes_to_visit( - &self, - graph: &crate::ir::Graph, - ) -> Vec> { - let boundary_constraints_roots_ref = graph.boundary_constraints_roots.borrow(); - let integrity_constraints_roots_ref = graph.integrity_constraints_roots.borrow(); - let combined_roots = boundary_constraints_roots_ref - .clone() - .into_iter() - .map(|bc| bc.as_node()) - .chain( - integrity_constraints_roots_ref - .clone() - .into_iter() - .map(|ic| ic.as_node()), - ); - combined_roots.collect() - } -} diff --git a/mir/src/passes/visitor.rs b/mir/src/passes/visitor.rs index 171be86ad..77fabc1e2 100644 --- a/mir/src/passes/visitor.rs +++ b/mir/src/passes/visitor.rs @@ -1,17 +1,18 @@ +use std::ops::Deref; + use crate::{ CompileError, ir::{Graph, Link, Node, Op, Parent, Root}, }; -use std::ops::Deref; /// A trait for visiting nodes in a MIR graph, from the leafs to the root. /// /// The process is as follows: /// - Starting from the root_nodes_to_visit, we scan a node -/// - We recursively scan all the children of the node, depth-first, -/// and store them on a stack in the same order. -/// - Once we have scanned all the children, -/// we visit the nodes in the stack starting from the last one +/// - We recursively scan all the children of the node, depth-first, and store them on a stack in +/// the same order. 
+/// - Once we have scanned all the children, we visit the nodes in the stack starting from the last +/// one /// - We then dispatch to the relevant `visit_*` method based on the variant of the node pub trait Visitor { fn work_stack(&mut self) -> &mut Vec>; @@ -65,7 +66,7 @@ pub trait Visitor { Node::None(_) => Ok(()), } } - /// Visit a Function node + /// Visit a `Function` node fn visit_function( &mut self, _graph: &mut Graph, @@ -73,7 +74,7 @@ pub trait Visitor { ) -> Result<(), CompileError> { Ok(()) } - /// Visit an Evaluator node + /// Visit an `Evaluator` node fn visit_evaluator( &mut self, _graph: &mut Graph, @@ -81,11 +82,11 @@ pub trait Visitor { ) -> Result<(), CompileError> { Ok(()) } - /// Visit an Enf node + /// Visit an `Enf` node fn visit_enf(&mut self, _graph: &mut Graph, _enf: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit a Boundary node + /// Visit a `Boundary` node fn visit_boundary( &mut self, _graph: &mut Graph, @@ -93,47 +94,47 @@ pub trait Visitor { ) -> Result<(), CompileError> { Ok(()) } - /// Visit an Add node + /// Visit an `Add` node fn visit_add(&mut self, _graph: &mut Graph, _add: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit a Sub node + /// Visit a `Sub` node fn visit_sub(&mut self, _graph: &mut Graph, _sub: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit a Mul node + /// Visit a `Mul` node fn visit_mul(&mut self, _graph: &mut Graph, _mul: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit an Exp node + /// Visit an `Exp` node fn visit_exp(&mut self, _graph: &mut Graph, _exp: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit an If node + /// Visit an `If` node fn visit_if(&mut self, _graph: &mut Graph, _if_node: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit a For node + /// Visit a `For` node fn visit_for(&mut self, _graph: &mut Graph, _for_node: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit a Call node + /// Visit a `Call` node fn visit_call(&mut self, _graph: &mut 
Graph, _call: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit a Fold node + /// Visit a `Fold` node fn visit_fold(&mut self, _graph: &mut Graph, _fold: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit a Vector node + /// Visit a `Vector` node fn visit_vector(&mut self, _graph: &mut Graph, _vector: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit a Matrix node + /// Visit a `Matrix` node fn visit_matrix(&mut self, _graph: &mut Graph, _matrix: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit an Accessor node + /// Visit an `Accessor` node fn visit_accessor( &mut self, _graph: &mut Graph, @@ -141,11 +142,11 @@ pub trait Visitor { ) -> Result<(), CompileError> { Ok(()) } - /// Visit a BusOp node + /// Visit a `BusOp` node fn visit_bus_op(&mut self, _graph: &mut Graph, _bus_op: Link) -> Result<(), CompileError> { Ok(()) } - /// Visit a Parameter node + /// Visit a `Parameter` node fn visit_parameter( &mut self, _graph: &mut Graph, @@ -153,7 +154,7 @@ pub trait Visitor { ) -> Result<(), CompileError> { Ok(()) } - /// Visit a Value node + /// Visit a `Value` node fn visit_value(&mut self, _graph: &mut Graph, _value: Link) -> Result<(), CompileError> { Ok(()) } diff --git a/mir/src/tests/access.rs b/mir/src/tests/access.rs index 5e4571fc3..9221d5500 100644 --- a/mir/src/tests/access.rs +++ b/mir/src/tests/access.rs @@ -21,10 +21,7 @@ fn invalid_vector_access_in_boundary_constraint() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -48,10 +45,7 @@ fn invalid_matrix_row_access_in_boundary_constraint() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -75,10 +69,7 @@ fn 
invalid_matrix_column_access_in_boundary_constraint() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -102,10 +93,7 @@ fn invalid_vector_access_in_integrity_constraint() { enf clk' = clk + A + B[3] - C[1][2]; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -129,10 +117,7 @@ fn invalid_matrix_row_access_in_integrity_constraint() { enf clk' = clk + A + B[1] - C[3][2]; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -156,8 +141,5 @@ fn invalid_matrix_column_access_in_integrity_constraint() { enf clk' = clk + A + B[1] - C[1][3]; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } diff --git a/mir/src/tests/buses.rs b/mir/src/tests/buses.rs index 73e672bf3..cf5ae146e 100644 --- a/mir/src/tests/buses.rs +++ b/mir/src/tests/buses.rs @@ -1,15 +1,7 @@ -use crate::{ - ir::{ - Add, Builder, Bus, Fold, FoldOperator, Link, Mir, MirValue, Op, PublicInputTableAccess, - Vector, assert_bus_eq, - }, - tests::translate, -}; -use air_parser::{Symbol, ast}; -use miden_diagnostics::{SourceSpan, Spanned}; +use air_parser::ast; use super::{compile, expect_diagnostic}; -use core::slice; +use crate::ir::{Link, MirValue, Op, PublicInputTableAccess}; #[test] fn buses_in_boundary_constraints() { @@ -78,74 +70,6 @@ fn buses_in_integrity_constraints() { assert!(compile(source).is_ok()); } -#[test] -fn buses_args_expr_in_integrity_expr() { - let source = " - def test - - trace_columns { - main: [a], - } - - 
public_inputs { - inputs: [2], - } - - buses { - multiset p, - } - - boundary_constraints { - enf p.first = null; - } - - integrity_constraints { - let vec = [x for x in 0..3]; - let b = 41; - let x = sum(vec) + b; - p.insert(x) when 1; - p.remove(x) when 0; - }"; - assert!(compile(source).is_ok()); - let mut result_mir = translate(source).unwrap(); - let bus = Bus::create( - ast::Identifier::new(SourceSpan::default(), Symbol::new(0)), - ast::BusType::Multiset, - SourceSpan::default(), - ); - let vec_op = Vector::builder() - .size(3) - .elements(From::from(0)) - .elements(From::from(1)) - .elements(From::from(2)) - .span(SourceSpan::default()) - .build(); - let b: Link = From::from(41); - let vec_sum = Fold::builder() - .iterator(vec_op) - .operator(FoldOperator::Add) - .initial_value(From::from(0)) - .span(SourceSpan::default()) - .build(); - let x: Link = Add::builder() - .lhs(vec_sum) - .rhs(b.clone()) - .span(SourceSpan::default()) - .build(); - let sel: Link = From::from(1); - let _p_add = bus.insert(slice::from_ref(&x), sel.clone(), SourceSpan::default()); - let not_sel: Link = From::from(0); - let _p_rem = bus.remove(slice::from_ref(&x), not_sel.clone(), SourceSpan::default()); - let bus_ident = result_mir.constraint_graph().buses.keys().next().unwrap(); - let bus_name = ast::Identifier::new(bus_ident.span(), bus_ident.name()); - bus.borrow_mut().set_name_unchecked(bus_name); - let mut expected_mir = Mir::new(result_mir.name); - let _ = expected_mir - .constraint_graph_mut() - .insert_bus(*bus_ident, bus.clone()); - assert_bus_eq(&mut expected_mir, &mut result_mir); -} - #[test] fn buses_table_in_boundary_constraints() { let source = " @@ -176,11 +100,8 @@ fn buses_table_in_boundary_constraints() { assert!(result.is_ok()); let get_name = |op: &Link| -> (ast::Identifier, usize) { - let MirValue::PublicInputTable(PublicInputTableAccess { - table_name, - num_cols, - .. 
- }) = op.as_value().unwrap().value.value + let MirValue::PublicInputTable(PublicInputTableAccess { table_name, num_cols, .. }) = + op.as_value().unwrap().value.value else { panic!("Expected a public input, got {op:#?}"); }; diff --git a/mir/src/tests/computed_indices.rs b/mir/src/tests/computed_indices.rs new file mode 100644 index 000000000..880963e5e --- /dev/null +++ b/mir/src/tests/computed_indices.rs @@ -0,0 +1,109 @@ +use super::{compile, expect_diagnostic}; + +#[test] +fn basic_computed_indices() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + + enf a = x[1 + 1]; + }"; + + assert!(compile(source).is_ok()); +} + +#[test] +fn basic_computed_indices_in_lc() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + let y = [i * x[1 + 1] for i in 0..5]; + + enf a = y[1 + 1]; + }"; + + assert!(compile(source).is_ok()); +} + +#[test] +fn computed_indices_in_lc() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + let y = [i * x[i + 1] for i in 0..4]; + + enf a = y[1 + 1]; + }"; + + assert!(compile(source).is_ok()); +} + +// Tests that should return errors +#[test] +fn err_computed_indices_in_lc() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + let y = [i * x[a + 1] for i in 0..4]; + + enf a = y[1 + 1]; + }"; + + expect_diagnostic(source, "error: the index is not constant during constant 
propagation"); +} diff --git a/mir/src/tests/integrity_constraints/comprehension/list_comprehension.rs b/mir/src/tests/integrity_constraints/comprehension/list_comprehension.rs index 8bf1347e9..171a5ca40 100644 --- a/mir/src/tests/integrity_constraints/comprehension/list_comprehension.rs +++ b/mir/src/tests/integrity_constraints/comprehension/list_comprehension.rs @@ -125,10 +125,7 @@ fn err_index_out_of_range_lc_ident() { enf clk = x[2]; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -151,10 +148,7 @@ fn err_index_out_of_range_lc_slice() { enf clk = x[3]; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] diff --git a/mir/src/tests/ir/inlining2.rs b/mir/src/tests/ir/inlining2.rs index 7ddfdde3e..38b737abe 100644 --- a/mir/src/tests/ir/inlining2.rs +++ b/mir/src/tests/ir/inlining2.rs @@ -1,6 +1,8 @@ use crate::tests::compile; #[cfg(test)] mod tests { + use ntest::timeout; + use super::*; //use crate::graph::pretty; @@ -54,4 +56,89 @@ mod tests { "; let _mir = compile(code).unwrap(); } + + #[test] + #[timeout(5000)] + fn inline_deeply_nested_calls() { + let code = " + def DeeplyNestedInliningLoopBug + + trace_columns { + main: [a, b, c], + } + + public_inputs { + x: [1], + } + + # Utility functions (like utils module) + fn binary_and(a: felt, b: felt) -> felt { + return a * b; + } + + fn binary_or(a: felt, b: felt) -> felt { + return a + b - a * b; + } + + fn binary_not(a: felt) -> felt { + return 1 - a; + } + + # Helper functions that use utilities (like flag functions) + fn flag_current_and_next(s_next: felt) -> felt { + return binary_not(s_next); + } + + fn flag_last(s_next: felt) -> felt { + return s_next; + } + + # Function that returns array and calls multiple helper functions 
(like section_flags) + fn section_flags(s_next: felt, s_start: felt, s_start_next: felt) -> felt[3] { + let f_next_flag = flag_current_and_next(s_next); + let f_last_flag = flag_last(s_next); + + let f_start = s_start; + let f_next = binary_not(s_start_next); + let f_end = binary_or(binary_and(f_next_flag, s_start_next), f_last_flag); + + return [f_start, f_next, f_end]; + } + + # Function that returns array (like block_flags) + fn block_flags(s_block: felt) -> felt[2] { + let f_read = binary_not(s_block); + let f_eval = s_block; + + return [f_read, f_eval]; + } + + # Top-level evaluator that uses these functions (like section_block_flags_constraints) + fn complex_constraint(s: felt, sstart: felt, sstart_next: felt, sblock: felt) -> felt { + let flags = section_flags(s, sstart, sstart_next); + let f_start = flags[0]; + let f_next = flags[1]; + let f_end = flags[2]; + + let blocks = block_flags(sblock); + let blocks_next = block_flags(sblock); # Simulate calling with next state + let f_read = blocks[0]; + let f_eval = blocks[1]; + let f_read_next = blocks_next[0]; + let f_eval_next = blocks_next[1]; + + return f_start + f_next + f_end + f_read + f_eval + f_read_next + f_eval_next; + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + # Call complex constraint with deeply nested function calls + enf a = complex_constraint(b, c, b, c); + } + "; + let _mir = compile(code).unwrap(); + } } diff --git a/mir/src/tests/list_comprehension.rs b/mir/src/tests/list_comprehension.rs index b71ef4cb8..6a5b4f82a 100644 --- a/mir/src/tests/list_comprehension.rs +++ b/mir/src/tests/list_comprehension.rs @@ -1,6 +1,5 @@ -use crate::ir::assert_integrity_eq; - use super::compile; +use crate::ir::assert_integrity_eq; #[test] fn list_comprehension_nested_nobind() { @@ -29,7 +28,7 @@ fn list_comprehension_nested_nobind() { let state = [6, 5, 4]; let expected = [13, 58, 103, 148]; let result = [inner_loop(state, row) for row in TABLE]; - enf expected = result; 
+ enf expected = result for (expected, result) in (expected, result); } fn inner_loop(st: felt[3], ro: felt[3]) -> felt { @@ -62,7 +61,7 @@ fn list_comprehension_nested_nobind() { let state = [6, 5, 4]; let expected = [13, 58, 103, 148]; let result = [sum([s * m for (s, m) in (state, row)]) for row in TABLE]; - enf expected = result; + enf expected = result for (expected, result) in (expected, result); }"; let Ok(mut nested) = compile(source_nested) else { diff --git a/mir/src/tests/mod.rs b/mir/src/tests/mod.rs index 88afd1092..49f7b6c5f 100644 --- a/mir/src/tests/mod.rs +++ b/mir/src/tests/mod.rs @@ -1,6 +1,7 @@ mod access; mod boundary_constraints; mod buses; +mod computed_indices; mod constant; mod evaluators; mod functions; @@ -14,19 +15,20 @@ mod source_sections; mod trace; mod variables; +use std::sync::Arc; + +use air_pass::Pass; +use miden_diagnostics::{CodeMap, DiagnosticsConfig, DiagnosticsHandler, Verbosity}; + /// Note: Tests on this module are currently redundant with the tests in the `air-ir` crate. /// -/// Indeed, these tests ensure that we can compile and translate AirScript code into AIR, with both pipelines (with and without MIR), -/// so if these tests pass, we can produce a Mir. +/// Indeed, these tests ensure that we can compile and translate AirScript code into AIR, with +/// both pipelines (with and without MIR), so if these tests pass, we can produce a Mir. /// -/// However, instead of removing the following tests, we should ensure the resulting Mir graph is consistent with what is expected, as well as test each pass. +/// However, instead of removing the following tests, we should ensure the resulting Mir graph +/// is consistent with what is expected, as well as test each pass. 
pub use crate::CompileError; - -use std::sync::Arc; - use crate::ir::Mir; -use air_pass::Pass; -use miden_diagnostics::{CodeMap, DiagnosticsConfig, DiagnosticsHandler, Verbosity}; pub fn compile(source: &str) -> Result { let compiler = Compiler::default(); @@ -36,7 +38,7 @@ pub fn compile(source: &str) -> Result { compiler.diagnostics.emit(err); compiler.emitter.print_captured_to_stderr(); Err(()) - } + }, } } @@ -48,7 +50,7 @@ pub fn translate(source: &str) -> Result { compiler.diagnostics.emit(err); compiler.emitter.print_captured_to_stderr(); Err(()) - } + }, } } @@ -61,7 +63,7 @@ pub fn parse(source: &str) -> Result { compiler.diagnostics.emit(err); compiler.emitter.print_captured_to_stderr(); Err(()) - } + }, } } @@ -71,7 +73,7 @@ pub fn expect_diagnostic(source: &str, expected: &str) { let err = match compiler.compile(source) { Ok(ref ast) => { panic!("expected compilation to fail, got {ast:#?}"); - } + }, Err(err) => err, }; compiler.diagnostics.emit(err); @@ -79,10 +81,7 @@ pub fn expect_diagnostic(source: &str, expected: &str) { if !found { compiler.emitter.print_captured_to_stderr(); } - assert!( - found, - "expected diagnostic output to contain the string: '{expected}'", - ); + assert!(found, "expected diagnostic output to contain the string: '{expected}'",); } struct Compiler { @@ -104,28 +103,18 @@ impl Compiler { pub fn new(config: DiagnosticsConfig) -> Self { let codemap = Arc::new(CodeMap::new()); let emitter = Arc::new(SplitEmitter::new()); - let diagnostics = Arc::new(DiagnosticsHandler::new( - config, - codemap.clone(), - emitter.clone(), - )); + let diagnostics = + Arc::new(DiagnosticsHandler::new(config, codemap.clone(), emitter.clone())); - Self { - codemap, - emitter, - diagnostics, - } + Self { codemap, emitter, diagnostics } } pub fn compile(&self, source: &str) -> Result { air_parser::parse(&self.diagnostics, self.codemap.clone(), source) .map_err(CompileError::Parse) .and_then(|ast| { - let mut pipeline = - 
air_parser::transforms::ConstantPropagation::new(&self.diagnostics) - .chain(crate::passes::AstToMir::new(&self.diagnostics)) - .chain(crate::passes::Inlining::new(&self.diagnostics)) - .chain(crate::passes::Unrolling::new(&self.diagnostics)); + let mut pipeline = air_parser::AstPasses::new(&self.diagnostics) + .chain(crate::MirPasses::new(&self.diagnostics)); pipeline.run(ast) }) } @@ -166,9 +155,10 @@ impl SplitEmitter { } pub fn print_captured_to_stderr(&self) { - use miden_diagnostics::Emitter; use std::io::Write; + use miden_diagnostics::Emitter; + let mut copy = self.default.buffer(); let captured = self.capture.captured(); copy.write_all(captured.as_bytes()).unwrap(); diff --git a/mir/src/tests/trace.rs b/mir/src/tests/trace.rs index 71d519cbd..7c8a33a25 100644 --- a/mir/src/tests/trace.rs +++ b/mir/src/tests/trace.rs @@ -106,10 +106,7 @@ fn err_bc_trace_cols_access_out_of_bounds() { enf a[0]' = a[0] - 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -134,10 +131,7 @@ fn err_ic_trace_cols_access_out_of_bounds() { enf a[4]' = a[4] - 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -159,3 +153,23 @@ fn err_ic_trace_cols_group_used_as_scalar() { expect_diagnostic(source, "type mismatch"); } + +#[test] +fn err_binop_on_non_scalar() { + let source = " + def test + trace_columns { + main: [clk, a[4], b[4]], + } + public_inputs { + stack_inputs: [16], + } + boundary_constraints { + enf a[1].first = 0; + } + integrity_constraints { + enf a = b; + }"; + + expect_diagnostic(source, "binary operations are only allowed on scalar values"); +} diff --git a/mir/src/tests/variables.rs b/mir/src/tests/variables.rs index 457b07571..5866ad5d8 100644 --- a/mir/src/tests/variables.rs +++ 
b/mir/src/tests/variables.rs @@ -303,10 +303,7 @@ fn invalid_vector_variable_access_out_of_bounds() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -328,10 +325,7 @@ fn invalid_matrix_column_variable_access_out_of_bounds() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] @@ -353,10 +347,7 @@ fn invalid_matrix_row_variable_access_out_of_bounds() { enf clk' = clk + 1; }"; - expect_diagnostic( - source, - "attempted to access an index which is out of bounds", - ); + expect_diagnostic(source, "attempted to access an index which is out of bounds"); } #[test] diff --git a/parser/Cargo.toml b/parser/Cargo.toml index cf424090f..c50738c40 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "air-parser" -version = "0.4.0" +version = "0.5.0" description = "Parser for the AirScript language" authors.workspace = true readme = "README.md" @@ -15,7 +15,7 @@ edition.workspace = true lalrpop = { version = "0.20", default-features = false } [dependencies] -air-pass = { package = "air-pass", path = "../pass", version = "0.4" } +air-pass = { package = "air-pass", path = "../pass", version = "0.5" } either = "1.12" lalrpop-util = "0.20" lazy_static = "1.4" diff --git a/parser/src/ast/declarations.rs b/parser/src/ast/declarations.rs index 2dc7723bd..671a0849b 100644 --- a/parser/src/ast/declarations.rs +++ b/parser/src/ast/declarations.rs @@ -14,8 +14,8 @@ //! * `boundary_constraints` //! * `integrity_constraints` //! -//! All other declarations are module-scoped, and must be explicitly imported by a module which wishes -//! to reference them. Not all items are importable however, only the following: +//! 
All other declarations are module-scoped, and must be explicitly imported by a module which +//! wishes to reference them. Not all items are importable however, only the following: //! //! * constants //! * evaluators @@ -83,14 +83,10 @@ pub struct Bus { impl Bus { /// Creates a new bus declaration pub const fn new(span: SourceSpan, name: Identifier, bus_type: BusType) -> Self { - Self { - span, - name, - bus_type, - } + Self { span, name, bus_type } } } -#[derive(Default, Copy, Hash, Debug, Clone, PartialEq, Eq)] +#[derive(Default, Copy, Hash, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub enum BusType { /// A multiset bus #[default] @@ -99,6 +95,15 @@ pub enum BusType { Logup, } +impl fmt::Display for BusType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Multiset => write!(f, "multiset"), + Self::Logup => write!(f, "logup"), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum BusOperator { /// Insert a tuple to the bus @@ -173,7 +178,7 @@ impl ConstantExpr { let num_rows = rows.len(); let num_cols = rows.first().unwrap().len(); Type::Matrix(num_rows, num_cols) - } + }, } } @@ -188,7 +193,7 @@ impl fmt::Display for ConstantExpr { Self::Scalar(value) => write!(f, "{value}"), Self::Vector(values) => { write!(f, "{}", DisplayList(values.as_slice())) - } + }, Self::Matrix(values) => write!( f, "{}", @@ -200,6 +205,12 @@ impl fmt::Display for ConstantExpr { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ImportLimb { + Star, + Ident(Identifier), +} + /// An import declaration /// /// There can be multiple of these in a given module @@ -216,7 +227,7 @@ pub enum Import { impl Import { pub fn module(&self) -> ModuleId { match self { - Self::All { module } | Self::Partial { module, .. } => *module, + Self::All { module } | Self::Partial { module, .. 
} => module.clone(), } } } @@ -225,16 +236,9 @@ impl PartialEq for Import { fn eq(&self, other: &Self) -> bool { match (self, other) { (Self::All { module: l }, Self::All { module: r }) => l == r, - ( - Self::Partial { - module: l, - items: ls, - }, - Self::Partial { - module: r, - items: rs, - }, - ) if l == r => ls.difference(rs).next().is_none(), + (Self::Partial { module: l, items: ls }, Self::Partial { module: r, items: rs }) => { + l == r && ls == rs + }, _ => false, } } @@ -247,12 +251,14 @@ impl PartialEq for Import { pub enum Export<'a> { Constant(&'a crate::ast::Constant), Evaluator(&'a EvaluatorFunction), + Function(&'a Function), } impl Export<'_> { pub fn name(&self) -> Identifier { match self { Self::Constant(item) => item.name, Self::Evaluator(item) => item.name, + Self::Function(item) => item.name, } } @@ -264,6 +270,7 @@ impl Export<'_> { match self { Self::Constant(item) => Some(item.ty()), Self::Evaluator(_) => None, + Self::Function(item) => Some(item.return_type), } } } @@ -352,22 +359,12 @@ impl Eq for PublicInput {} impl PartialEq for PublicInput { fn eq(&self, other: &Self) -> bool { match (self, other) { - ( - Self::Vector { - name: l, size: ls, .. - }, - Self::Vector { - name: r, size: rs, .. - }, - ) => l == r && ls == rs, - ( - Self::Table { - name: l, size: lc, .. - }, - Self::Table { - name: r, size: rc, .. - }, - ) => l == r && lc == rc, + (Self::Vector { name: l, size: ls, .. }, Self::Vector { name: r, size: rs, .. }) => { + l == r && ls == rs + }, + (Self::Table { name: l, size: lc, .. }, Self::Table { name: r, size: rc, .. 
}) => { + l == r && lc == rc + }, _ => false, } } @@ -392,12 +389,7 @@ impl EvaluatorFunction { params: Vec, body: Vec, ) -> Self { - Self { - span, - name, - params, - body, - } + Self { span, name, params, body } } } impl Eq for EvaluatorFunction {} @@ -430,13 +422,7 @@ impl Function { return_type: Type, body: Vec, ) -> Self { - Self { - span, - name, - params, - return_type, - body, - } + Self { span, name, params, return_type, body } } pub fn param_types(&self) -> Vec { @@ -453,3 +439,52 @@ impl PartialEq for Function { && self.body == other.body } } + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use miden_diagnostics::SourceSpan; + + use super::*; + use crate::symbols::Symbol; + + fn ident(name: &str) -> Identifier { + Identifier::new(SourceSpan::UNKNOWN, Symbol::intern(name)) + } + + fn module_id(parts: &[&str]) -> ModuleId { + let ids = parts.iter().map(|p| ident(p)).collect::>(); + ModuleId::new(ids, SourceSpan::UNKNOWN) + } + + #[test] + fn import_partial_subset_is_not_equal_either_direction() { + let mut set_a: HashSet = HashSet::default(); + set_a.insert(ident("a")); + let import_a = Import::Partial { module: module_id(&["m"]), items: set_a }; + + let mut set_ab: HashSet = HashSet::default(); + set_ab.insert(ident("a")); + set_ab.insert(ident("b")); + let import_ab = Import::Partial { module: module_id(&["m"]), items: set_ab }; + + assert!(import_a != import_ab); + assert!(import_ab != import_a); + } + + #[test] + fn import_partial_identical_sets_are_equal() { + let mut set1: HashSet = HashSet::default(); + set1.insert(ident("a")); + set1.insert(ident("b")); + let mut set2: HashSet = HashSet::default(); + set2.insert(ident("b")); + set2.insert(ident("a")); + + let import1 = Import::Partial { module: module_id(&["m"]), items: set1 }; + let import2 = Import::Partial { module: module_id(&["m"]), items: set2 }; + + assert_eq!(import1, import2); + } +} diff --git a/parser/src/ast/display.rs b/parser/src/ast/display.rs index 
0edf7b924..6c37f7e26 100644 --- a/parser/src/ast/display.rs +++ b/parser/src/ast/display.rs @@ -40,11 +40,7 @@ impl fmt::Display for DisplayTuple<'_, T> { pub struct DisplayTypedTuple<'a, V, T>(pub &'a [(V, T)]); impl fmt::Display for DisplayTypedTuple<'_, V, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "({})", - DisplayCsv::new(self.0.iter().map(|(v, t)| format!("{v}: {t}"))) - ) + write!(f, "({})", DisplayCsv::new(self.0.iter().map(|(v, t)| format!("{v}: {t}")))) } } @@ -101,16 +97,16 @@ impl fmt::Display for DisplayStatement<'_> { in_expr_position: false, }; write!(f, "{display}") - } + }, Statement::Enforce(expr) => { write!(f, "enf {expr}") - } - Statement::EnforceIf(expr, selector) => { - write!(f, "enf {expr} when {selector}") - } + }, + Statement::EnforceIf(match_expr) => { + write!(f, "enf {match_expr}") + }, Statement::EnforceAll(expr) => { write!(f, "enf {expr}") - } + }, Statement::Expr(expr) => write!(f, "return {expr}"), Statement::BusEnforce(expr) => write!(f, "enf {expr}"), } @@ -152,7 +148,7 @@ impl fmt::Display for DisplayLet<'_> { } else { f.write_str("}\n")?; } - } + }, value => { write!(f, "let {} = {}", self.let_expr.name, value)?; if self.in_expr_position { @@ -160,7 +156,7 @@ impl fmt::Display for DisplayLet<'_> { } else { f.write_char('\n')?; } - } + }, } for stmt in self.let_expr.body.iter() { writeln!(f, "{}", stmt.display(self.indent + 1))?; diff --git a/parser/src/ast/errors.rs b/parser/src/ast/errors.rs index 555692270..1afff53f5 100644 --- a/parser/src/ast/errors.rs +++ b/parser/src/ast/errors.rs @@ -32,18 +32,14 @@ impl ToDiagnostic for InvalidExprError { match self { Self::NonConstantExponent(span) => Diagnostic::error() .with_message("invalid expression") - .with_labels(vec![ - Label::primary(span.source_id(), span).with_message(message), - ]) + .with_labels(vec![Label::primary(span.source_id(), span).with_message(message)]) .with_notes(vec![ "Only constant powers are supported with the 
exponentiation operator currently" .to_string(), ]), Self::NonConstantRangeExpr(span) => Diagnostic::error() .with_message("invalid expression") - .with_labels(vec![ - Label::primary(span.source_id(), span).with_message(message), - ]) + .with_labels(vec![Label::primary(span.source_id(), span).with_message(message)]) .with_notes(vec![ "Range expression must be a constant to do this operation".to_string(), ]), @@ -53,9 +49,7 @@ impl ToDiagnostic for InvalidExprError { | Self::InvalidLetExpr(span) | Self::NotAnExpr(span) => Diagnostic::error() .with_message("invalid expression") - .with_labels(vec![ - Label::primary(span.source_id(), span).with_message(message), - ]), + .with_labels(vec![Label::primary(span.source_id(), span).with_message(message)]), } } } @@ -78,9 +72,7 @@ impl ToDiagnostic for InvalidTypeError { match self { Self::NonVectorIterable(span) => Diagnostic::error() .with_message("invalid type") - .with_labels(vec![ - Label::primary(span.source_id(), span).with_message(message), - ]) + .with_labels(vec![Label::primary(span.source_id(), span).with_message(message)]) .with_notes(vec!["Only vectors can be used as iterables".to_string()]), } } diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index d8566d4b3..97f26fa94 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -13,9 +13,8 @@ use std::{convert::AsRef, fmt}; use miden_diagnostics::{SourceSpan, Span, Spanned}; -use crate::symbols::Symbol; - use super::*; +use crate::symbols::Symbol; /// A range literal, equivalent to the interval `[start, end)`. 
pub type Range = std::ops::Range; @@ -67,9 +66,7 @@ impl PartialEq<&Identifier> for Identifier { } impl fmt::Debug for Identifier { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Identifier") - .field(&format!("{}", &self.0.item)) - .finish() + f.debug_tuple("Identifier").field(&format!("{}", &self.0.item)).finish() } } impl fmt::Display for Identifier { @@ -91,8 +88,8 @@ impl From for Identifier { /// Represents an identifier qualified with its namespace. /// /// Identifiers in AirScript are separated into two namespaces: one for functions, -/// and one for buses and bindings. This is because functions cannot be bound, added to or remove from, -/// while buses and bindings cannot be called. +/// and one for buses and bindings. This is because functions cannot be bound, added to or remove +/// from, while buses and bindings cannot be called. /// So we can always disambiguate identifiers based on its usage. /// /// It is still probably best practice to avoid having name conflicts between functions, @@ -135,7 +132,7 @@ impl fmt::Display for NamespacedIdentifier { /// Represents an identifier qualified with both its parent module and namespace. /// /// This represents a globally-unique identity for a declaration -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Spanned)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Spanned)] pub struct QualifiedIdentifier { pub module: ModuleId, #[span] @@ -160,11 +157,11 @@ impl QualifiedIdentifier { pub fn is_builtin(&self) -> bool { use crate::symbols; - if self.module.name() == "$builtin" { + if self.module.len() == 1 && self.module[0].name() == "$builtin" { match self.item { NamespacedIdentifier::Function(id) => { matches!(id.name(), symbols::Sum | symbols::Prod) - } + }, _ => false, } } else { @@ -185,13 +182,14 @@ impl fmt::Display for QualifiedIdentifier { } /// Represents an identifier which requires name resolution at some stage during lowering. 
-#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Spanned)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Spanned)] pub enum ResolvableIdentifier { /// This identifier is resolved to a local binding (i.e. function parameter or let-bound var) Local(#[span] Identifier), /// This identifier is resolved to a global binding Global(#[span] Identifier), - /// This identifier is resolved to a non-local item (i.e. module-level declaration or imported item) + /// This identifier is resolved to a non-local item (i.e. module-level declaration or imported + /// item) Resolved(#[span] QualifiedIdentifier), /// This identifier is not yet resolved or is undefined in the current scope Unresolved(#[span] NamespacedIdentifier), @@ -228,7 +226,7 @@ impl ResolvableIdentifier { /// resolved/unresolved states pub fn module(&self) -> Option { match self { - Self::Resolved(qid) => Some(*qid.as_ref()), + Self::Resolved(qid) => Some(qid.module.clone()), _ => None, } } @@ -236,14 +234,14 @@ impl ResolvableIdentifier { /// Obtains a [NamespacedIdentifier] from this identifier #[inline] pub fn namespaced(&self) -> NamespacedIdentifier { - (*self).into() + self.clone().into() } /// Gets the [QualifiedIdentifier] if this identifier is of type `Resolved` #[inline] pub fn resolved(&self) -> Option { match self { - Self::Resolved(qid) => Some(*qid), + Self::Resolved(qid) => Some(qid.clone()), _ => None, } } @@ -336,7 +334,7 @@ impl Expr { let rows = matrix.len(); let cols = matrix[0].len(); Some(Type::Matrix(rows, cols)) - } + }, Self::SymbolAccess(access) => access.ty, Self::Binary(_) => Some(Type::Felt), Self::Call(call) => call.ty, @@ -358,7 +356,7 @@ impl fmt::Debug for Expr { Self::Call(expr) => f.debug_tuple("Call").field(expr).finish(), Self::ListComprehension(expr) => { f.debug_tuple("ListComprehension").field(expr).finish() - } + }, Self::Let(let_expr) => write!(f, "{let_expr:#?}"), Self::BusOperation(expr) => f.debug_tuple("BusOp").field(expr).finish(), Self::Null(expr) => 
f.debug_tuple("Null").field(expr).finish(), @@ -381,7 +379,7 @@ impl fmt::Display for Expr { write!(f, "{}", DisplayList(col.as_slice()))?; } f.write_str("]") - } + }, Self::SymbolAccess(expr) => write!(f, "{expr}"), Self::Binary(expr) => write!(f, "{expr}"), Self::Call(expr) => write!(f, "{expr}"), @@ -393,7 +391,7 @@ impl fmt::Display for Expr { in_expr_position: true, }; write!(f, "{display}") - } + }, Self::BusOperation(expr) => write!(f, "{expr}"), Self::Null(_expr) => write!(f, "null"), Self::Unconstrained(_expr) => write!(f, "unconstrained"), @@ -447,16 +445,15 @@ impl TryFrom for Expr { #[inline] fn try_from(expr: ScalarExpr) -> Result { match expr { - ScalarExpr::Const(spanned) => Ok(Self::Const(Span::new( - spanned.span(), - ConstantExpr::Scalar(spanned.item), - ))), + ScalarExpr::Const(spanned) => { + Ok(Self::Const(Span::new(spanned.span(), ConstantExpr::Scalar(spanned.item)))) + }, ScalarExpr::SymbolAccess(access) => Ok(Self::SymbolAccess(access)), ScalarExpr::Binary(expr) => Ok(Self::Binary(expr)), ScalarExpr::Call(expr) => Ok(Self::Call(expr)), ScalarExpr::BoundedSymbolAccess(_) => { Err(InvalidExprError::BoundedSymbolAccess(expr.span())) - } + }, ScalarExpr::Let(expr) => Ok(Self::Let(expr)), ScalarExpr::BusOperation(expr) => Ok(Self::BusOperation(expr)), ScalarExpr::Null(spanned) => Ok(Self::Null(spanned)), @@ -487,7 +484,8 @@ pub enum ScalarExpr { /// /// NOTE: Symbol accesses in a `ScalarExpr` context must produce scalar values. SymbolAccess(SymbolAccess), - /// A reference to a trace column on a particular boundary of the trace, which must produce a scalar + /// A reference to a trace column on a particular boundary of the trace, which must produce a + /// scalar /// /// NOTE: This is only a valid expression in boundary constraints BoundedSymbolAccess(BoundedSymbolAccess), @@ -498,8 +496,8 @@ pub enum ScalarExpr { /// NOTE: This is only a valid scalar expression when one of the following hold: /// /// 1. 
The call is the top-level expression of a constraint, and is to an evaluator function - /// 2. The call is not the top-level expression of a constraint, and is to a pure function - /// that produces a scalar value type. + /// 2. The call is not the top-level expression of a constraint, and is to a pure function that + /// produces a scalar value type. /// /// If neither of the above are true, the call is invalid in a `ScalarExpr` context Call(Call), @@ -522,7 +520,8 @@ impl ScalarExpr { matches!(self, Self::Const(_)) } - /// Returns true if this scalar expression could expand to a block, e.g. due to a function call being inlined. + /// Returns true if this scalar expression could expand to a block, e.g. due to a function call + /// being inlined. pub fn has_block_like_expansion(&self) -> bool { match self { Self::Binary(expr) => expr.has_block_like_expansion(), @@ -551,7 +550,7 @@ impl ScalarExpr { Self::Let(expr) => Ok(expr.ty()), Self::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => { Ok(Some(Type::Felt)) - } + }, } } } @@ -566,7 +565,7 @@ impl TryFrom for ScalarExpr { ConstantExpr::Scalar(v) => Ok(Self::Const(Span::new(span, v))), _ => Err(InvalidExprError::InvalidScalarExpr(span)), } - } + }, Expr::SymbolAccess(sym) => Ok(Self::SymbolAccess(sym)), Expr::Binary(bin) => Ok(Self::Binary(bin)), Expr::Call(call) => Ok(Self::Call(call)), @@ -576,7 +575,7 @@ impl TryFrom for ScalarExpr { } else { Ok(Self::Let(let_expr)) } - } + }, invalid => Err(InvalidExprError::InvalidScalarExpr(invalid.span())), } } @@ -604,7 +603,7 @@ impl fmt::Debug for ScalarExpr { Self::SymbolAccess(expr) => f.debug_tuple("SymbolAccess").field(expr).finish(), Self::BoundedSymbolAccess(expr) => { f.debug_tuple("BoundedSymbolAccess").field(expr).finish() - } + }, Self::Binary(expr) => f.debug_tuple("Binary").field(expr).finish(), Self::Call(expr) => f.debug_tuple("Call").field(expr).finish(), Self::Let(expr) => write!(f, "{expr:#?}"), @@ -629,7 +628,7 @@ impl fmt::Display 
for ScalarExpr { in_expr_position: true, }; write!(f, "{display}") - } + }, Self::BusOperation(expr) => write!(f, "{expr}"), Self::Null(_value) => write!(f, "null"), Self::Unconstrained(_value) => write!(f, "unconstrained"), @@ -708,7 +707,7 @@ impl RangeExpr { match (&self.start, &self.end) { (RangeBound::Const(start), RangeBound::Const(end)) => { Some(Type::Vector(end.item.abs_diff(start.item))) - } + }, _ => None, } } @@ -788,7 +787,8 @@ impl BinaryExpr { } } - /// Returns true if this binary expression could expand to a block, e.g. due to a function call being inlined. + /// Returns true if this binary expression could expand to a block, e.g. due to a function call + /// being inlined. #[inline] pub fn has_block_like_expansion(&self) -> bool { self.lhs.has_block_like_expansion() || self.rhs.has_block_like_expansion() @@ -859,7 +859,7 @@ impl fmt::Display for Boundary { } /// Represents the way an identifier is accessed/referenced in the source. -#[derive(Hash, Debug, Clone, Eq, PartialEq, Default)] +#[derive(Debug, Clone, Eq, PartialEq, Default)] pub enum AccessType { /// Access refers to the entire bound value #[default] @@ -869,19 +869,17 @@ pub enum AccessType { /// Access binds the value at a specific index of an aggregate value (i.e. 
vector or matrix) /// /// The result type may be either a scalar or a vector, depending on the type of the aggregate - Index(usize), + Index(Box), /// Access binds the value at a specific row and column of a matrix value - Matrix(usize, usize), + Matrix(Box, Box), } impl fmt::Display for AccessType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::Default => write!(f, "direct reference by name"), - Self::Slice(range) => write!( - f, - "slice of elements at indices {}..{}", - range.start, range.end - ), + Self::Slice(range) => { + write!(f, "slice of elements at indices {}..{}", range.start, range.end) + }, Self::Index(idx) => write!(f, "reference to element at index {idx}"), Self::Matrix(row, col) => write!(f, "reference to value in matrix at [{row}][{col}]"), } @@ -961,9 +959,9 @@ impl SymbolAccess { AccessType::Default => self.access_default(access_type), AccessType::Slice(base_range) => { self.access_slice(base_range.to_slice_range(), access_type) - } - AccessType::Index(base_idx) => self.access_index(*base_idx, access_type), - AccessType::Matrix(_, _) => match access_type { + }, + AccessType::Index(base_idx) => self.access_index(base_idx.clone(), access_type), + AccessType::Matrix(..) 
=> match access_type { AccessType::Default => Ok(self.clone()), _ => Err(InvalidAccessError::IndexIntoScalar), }, @@ -976,13 +974,11 @@ impl SymbolAccess { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), Type::Vector(_) => Ok(Self { access_type: AccessType::Index(idx), ty: Some(Type::Felt), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), Type::Matrix(_, cols) => Ok(Self { access_type: AccessType::Index(idx), ty: Some(Type::Vector(cols)), @@ -996,7 +992,7 @@ impl SymbolAccess { Type::Felt => Err(InvalidAccessError::IndexIntoScalar), Type::Vector(len) if slice_range.end > len => { Err(InvalidAccessError::IndexOutOfBounds) - } + }, Type::Vector(_) => Ok(Self { access_type: AccessType::Slice(range), ty: Some(Type::Vector(rlen)), @@ -1004,20 +1000,17 @@ impl SymbolAccess { }), Type::Matrix(rows, _) if slice_range.end > rows => { Err(InvalidAccessError::IndexOutOfBounds) - } + }, Type::Matrix(_, cols) => Ok(Self { access_type: AccessType::Slice(range), ty: Some(Type::Matrix(rlen, cols)), ..self.clone() }), } - } + }, AccessType::Matrix(row, col) => match ty { Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar), - Type::Matrix(rows, cols) if row >= rows || col >= cols => { - Err(InvalidAccessError::IndexOutOfBounds) - } - Type::Matrix(_, _) => Ok(Self { + Type::Matrix(..) 
=> Ok(Self { access_type: AccessType::Matrix(row, col), ty: Some(Type::Felt), ..self.clone() @@ -1036,15 +1029,29 @@ impl SymbolAccess { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { Type::Felt => unreachable!(), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), Type::Vector(_) => Ok(Self { - access_type: AccessType::Index(base_range.start + idx), + access_type: AccessType::Index(Box::new(ScalarExpr::Binary(BinaryExpr { + span: self.span(), + op: BinaryOp::Add, + lhs: Box::new(ScalarExpr::Const(Span::new( + self.span(), + base_range.start as u64, + ))), + rhs: idx.clone(), + }))), ty: Some(Type::Felt), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), Type::Matrix(_, cols) => Ok(Self { - access_type: AccessType::Index(base_range.start + idx), + access_type: AccessType::Index(Box::new(ScalarExpr::Binary(BinaryExpr { + span: self.span(), + op: BinaryOp::Add, + lhs: Box::new(ScalarExpr::Const(Span::new( + self.span(), + base_range.start as u64, + ))), + rhs: idx.clone(), + }))), ty: Some(Type::Vector(cols)), ..self.clone() }), @@ -1064,7 +1071,7 @@ impl SymbolAccess { Type::Felt => unreachable!(), Type::Vector(_) if slice_range.end > blen => { Err(InvalidAccessError::IndexOutOfBounds) - } + }, Type::Vector(_) => Ok(Self { access_type: AccessType::Slice(shifted), ty: Some(Type::Vector(rlen)), @@ -1072,20 +1079,17 @@ impl SymbolAccess { }), Type::Matrix(rows, _) if slice_range.end > rows => { Err(InvalidAccessError::IndexOutOfBounds) - } + }, Type::Matrix(_, cols) => Ok(Self { access_type: AccessType::Slice(shifted), ty: Some(Type::Matrix(rlen, cols)), ..self.clone() }), } - } + }, AccessType::Matrix(row, col) => match ty { Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar), - Type::Matrix(rows, cols) if row >= rows || col >= cols => { - Err(InvalidAccessError::IndexOutOfBounds) - } - Type::Matrix(_, _) => Ok(Self { + Type::Matrix(..) 
=> Ok(Self { access_type: AccessType::Matrix(row, col), ty: Some(Type::Felt), ..self.clone() @@ -1096,7 +1100,7 @@ impl SymbolAccess { fn access_index( &self, - base_idx: usize, + base_idx: Box, access_type: AccessType, ) -> Result { let ty = self.ty.unwrap(); @@ -1104,13 +1108,11 @@ impl SymbolAccess { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), Type::Vector(_) => Ok(Self { access_type: AccessType::Matrix(base_idx, idx), ty: Some(Type::Felt), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), Type::Matrix(_, cols) => Ok(Self { access_type: AccessType::Matrix(base_idx, idx), ty: Some(Type::Vector(cols)), @@ -1118,7 +1120,7 @@ impl SymbolAccess { }), }, AccessType::Slice(_) => Err(InvalidAccessError::SliceOfMatrix), - AccessType::Matrix(_, _) => Err(InvalidAccessError::IndexIntoScalar), + AccessType::Matrix(..) 
=> Err(InvalidAccessError::IndexIntoScalar), } } } @@ -1172,11 +1174,7 @@ pub struct BoundedSymbolAccess { } impl BoundedSymbolAccess { pub const fn new(span: SourceSpan, column: SymbolAccess, boundary: Boundary) -> Self { - Self { - span, - boundary, - column, - } + Self { span, boundary, column } } } impl Eq for BoundedSymbolAccess {} @@ -1270,11 +1268,7 @@ impl fmt::Debug for ListComprehension { impl fmt::Display for ListComprehension { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if self.bindings.len() == 1 { - write!( - f, - "{} for {} in {}", - &self.body, &self.bindings[0], &self.iterables[0] - )?; + write!(f, "{} for {} in {}", &self.body, &self.bindings[0], &self.iterables[0])?; } else { write!( f, @@ -1330,13 +1324,7 @@ impl fmt::Debug for BusOperation { } impl fmt::Display for BusOperation { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{}{}{}", - self.bus, - self.op, - DisplayTuple(self.args.as_slice()) - ) + write!(f, "{}{}{}", self.bus, self.op, DisplayTuple(self.args.as_slice())) } } @@ -1369,9 +1357,9 @@ pub struct Call { /// The reason this field is an `Option` is two-fold: /// /// * Calls to evaluators produce no value, and thus have no type - /// * When parsed, the callee has not yet been resolved, so we don't know the - /// type of the function being called. During semantic analysis, the callee is - /// resolved and this field is set to the result type of that function. + /// * When parsed, the callee has not yet been resolved, so we don't know the type of the + /// function being called. During semantic analysis, the callee is resolved and this field is + /// set to the result type of that function. 
pub ty: Option, } impl Call { @@ -1409,7 +1397,10 @@ impl Call { } fn new_builtin(span: SourceSpan, name: &str, args: Vec, ty: Type) -> Self { - let builtin_module = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("$builtin")); + let builtin_module = ModuleId::new( + vec![Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("$builtin"))], + SourceSpan::UNKNOWN, + ); let name = Identifier::new(span, Symbol::intern(name)); let id = QualifiedIdentifier::new(builtin_module, NamespacedIdentifier::Function(name)); Self { diff --git a/parser/src/ast/mod.rs b/parser/src/ast/mod.rs index b5e17976e..1686455e7 100644 --- a/parser/src/ast/mod.rs +++ b/parser/src/ast/mod.rs @@ -8,15 +8,6 @@ mod trace; mod types; pub mod visit; -pub use self::declarations::*; -pub(crate) use self::display::*; -pub use self::errors::*; -pub use self::expression::*; -pub use self::module::*; -pub use self::statement::*; -pub use self::trace::*; -pub use self::types::*; - use std::{ collections::{BTreeMap, HashMap, HashSet, VecDeque}, fmt, mem, @@ -24,10 +15,13 @@ use std::{ sync::Arc, }; -use miden_diagnostics::{ - CodeMap, DiagnosticsHandler, FileName, Severity, SourceSpan, Span, Spanned, -}; +use miden_diagnostics::{CodeMap, DiagnosticsHandler, FileName, Severity, SourceSpan, Span}; +use petgraph::visit::EdgeRef; +pub(crate) use self::display::*; +pub use self::{ + declarations::*, errors::*, expression::*, module::*, statement::*, trace::*, types::*, +}; use crate::{ Symbol, parser::ParseError, @@ -132,10 +126,20 @@ impl Program { root: ModuleId, mut library: Library, ) -> Result { - use crate::sema::DependencyType; use petgraph::visit::DfsPostOrder; - let mut program = Program::new(root); + use crate::sema::DependencyType; + + if root.len() != 1 { + diagnostics + .diagnostic(Severity::Error) + .with_message("root module must be a single identifier") + .emit(); + return Err(SemanticAnalysisError::MissingRoot); + } + let root_name = root[0]; + + let mut program = Program::new(root_name); // 
Validate that the root module is contained in the library if !library.contains(&root) { @@ -147,28 +151,39 @@ impl Program { let root_module = library.get_mut(&root).unwrap(); mem::swap(&mut program.public_inputs, &mut root_module.public_inputs); mem::swap(&mut program.trace_columns, &mut root_module.trace_columns); + program.buses = BTreeMap::from_iter(root_module.buses.iter().map(|(k, v)| { + ( + QualifiedIdentifier::new(root.clone(), NamespacedIdentifier::Binding(*k)), + v.clone(), + ) + })); } // Build the module graph starting from the root module let mut modgraph = sema::ModuleGraph::new(); let mut visited = HashSet::::default(); let mut worklist = VecDeque::new(); - worklist.push_back(root); + let mut nodes = BTreeMap::::new(); + worklist.push_back(root.clone()); while let Some(module_name) = worklist.pop_front() { // If we haven't visited the imported module yet, add it's imports to the graph - if visited.insert(module_name) { - modgraph.add_node(module_name); - + if visited.insert(module_name.clone()) { + let module_node_index = + get_node_index_or_add(&mut modgraph, &mut nodes, &module_name); if let Some(module) = library.get(&module_name) { for import in module.imports.values() { - let import_module = modgraph.add_node(import.module()); + let import_module_node_index = + get_node_index_or_add(&mut modgraph, &mut nodes, &import.module()); // If an attempt is made to import the root module, raise an error - if import_module == root { - return Err(SemanticAnalysisError::RootImport(import.module().span())); + if import.module() == root { + return Err(SemanticAnalysisError::RootImport(root.span())); } - assert_eq!(modgraph.add_edge(module_name, import_module, ()), None); - worklist.push_back(import_module); + assert!( + !modgraph.contains_edge(module_node_index, import_module_node_index) + ); + modgraph.add_edge(module_node_index, import_module_node_index, ()); + worklist.push_back(import.module()); } } else { return 
Err(SemanticAnalysisError::MissingModule(module_name)); @@ -182,28 +197,47 @@ impl Program { // In each dependency module, we resolve all identifiers in that module to // their fully-qualified form, and add edges in the dependency graph which // represent what items are referenced from the functions/constraints in that module. - let mut deps = sema::DependencyGraph::new(); - let mut visitor = DfsPostOrder::new(&modgraph, root); - while let Some(module_name) = visitor.next(&modgraph) { + let mut deps_graph = sema::DependencyGraph::new(); + let mut deps_nodes = BTreeMap::::new(); + let root_node_index = get_node_index_or_add(&mut modgraph, &mut nodes, &root); + let mut visitor = DfsPostOrder::new(&modgraph, root_node_index); + while let Some(module_name_node_index) = visitor.next(&modgraph) { + let module_name = modgraph + .node_weight(module_name_node_index) + .expect("Did not find module in graph"); + // Remove the module from the library temporarily, so that we // can look up other modules in the library while we modify it // // NOTE: This will always succeed, or we would have raised an error // during semantic analysis - let mut module = library.modules.remove(&module_name).unwrap(); + let mut module = library.modules.remove(module_name).unwrap(); // Resolve imports let resolver = sema::ImportResolver::new(diagnostics, &library); let imported = resolver.run(&mut module)?; - // Perform semantic analysis on the module, updating the - // dependency graph with information gathered from this module - let analysis = - sema::SemanticAnalysis::new(diagnostics, &program, &library, &mut deps, imported); + // Perform semantic analysis on the module, updating the dependency graph with + // information gathered from this module. The dependency graph is built up + // incrementally as we analyze each module, each node is a fully-qualified identifier + // representing an item in the program, and edges represent dependencies between those + // items (e.g. 
a constant is used in a function). + // + // NOTE: nodes are stored in `deps_nodes` and accessed/added to the graph as needed + // through `NodeIndex` type to reference them (as QualifiedIdentifier does not implement + // Copy). + let analysis = sema::SemanticAnalysis::new( + diagnostics, + &program, + &library, + &mut deps_graph, + &mut deps_nodes, + imported, + ); analysis.run(&mut module)?; // Put the module back - library.modules.insert(module.name, module); + library.modules.insert(module.path.clone(), module); } // Now that we have a dependency graph for each function/constraint in the root module, @@ -212,7 +246,7 @@ impl Program { // from the boundary_constraints and integrity_constraints sections, or any of the functions // in the root module. let root_node = QualifiedIdentifier::new( - program.name, + ModuleId::new(vec![program.name], SourceSpan::UNKNOWN), NamespacedIdentifier::Binding(Identifier::new( SourceSpan::UNKNOWN, Symbol::intern("$$root"), @@ -229,18 +263,9 @@ impl Program { if let Some(ic) = root_module.integrity_constraints.as_ref() { program.integrity_constraints = ic.to_vec(); } - // Make sure we move the buses into the program - if !root_module.buses.is_empty() { - program.buses = BTreeMap::from_iter(root_module.buses.iter().map(|(k, v)| { - ( - QualifiedIdentifier::new(root, NamespacedIdentifier::Binding(*k)), - v.clone(), - ) - })); - } for evaluator in root_module.evaluators.values() { root_nodes.push_back(QualifiedIdentifier::new( - root, + root.clone(), NamespacedIdentifier::Function(evaluator.name), )); } @@ -248,42 +273,51 @@ impl Program { let mut visited = HashSet::::default(); while let Some(node) = root_nodes.pop_front() { - for (_, referenced, dep_type) in - deps.edges_directed(node, petgraph::Direction::Outgoing) - { + let node_index = deps_graph + .node_indices() + .find(|i| deps_graph.node_weight(*i).unwrap() == &node) + .expect("Did not find node in graph"); + for edges in deps_graph.edges_directed(node_index, 
petgraph::Direction::Outgoing) { + let dep_type = edges.weight(); + let referenced_node_index = edges.target(); + let referenced = deps_graph + .node_weight(referenced_node_index) + .expect("Did not find node in graph") + .clone(); + // Avoid spinning infinitely in dependency cycles - if !visited.insert(referenced) { + if !visited.insert(referenced.clone()) { continue; } // Add dependency to program let referenced_module = library.get(&referenced.module).unwrap(); - let id = referenced.item.id(); + let id = referenced.clone().item.id(); match dep_type { DependencyType::Constant => { program .constants - .entry(referenced) + .entry(referenced.clone()) .or_insert_with(|| referenced_module.constants[&id].clone()); - } + }, DependencyType::Evaluator => { program .evaluators - .entry(referenced) + .entry(referenced.clone()) .or_insert_with(|| referenced_module.evaluators[&id].clone()); - } + }, DependencyType::Function => { program .functions - .entry(referenced) + .entry(referenced.clone()) .or_insert_with(|| referenced_module.functions[&id].clone()); - } + }, DependencyType::PeriodicColumn => { program .periodic_columns - .entry(referenced) + .entry(referenced.clone()) .or_insert_with(|| referenced_module.periodic_columns[&id].clone()); - } + }, } // Make sure we visit all of the dependencies of this dependency @@ -329,13 +363,8 @@ impl fmt::Display for Program { if !self.periodic_columns.is_empty() { writeln!(f, "periodic_columns {{")?; for (qid, column) in self.periodic_columns.iter() { - if qid.module == self.name { - writeln!( - f, - " {}: {}", - &qid.item, - DisplayList(column.values.as_slice()) - )?; + if qid.module.0.item == vec![self.name] { + writeln!(f, " {}: {}", &qid.item, DisplayList(column.values.as_slice()))?; } else { writeln!(f, " {}: {}", qid, DisplayList(column.values.as_slice()))?; } @@ -346,7 +375,7 @@ impl fmt::Display for Program { if !self.constants.is_empty() { for (qid, constant) in self.constants.iter() { - if qid.module == self.name { + if 
qid.module.0.item == vec![self.name] { writeln!(f, "const {} = {}", &qid.item, &constant.value)?; } else { writeln!(f, "const {} = {}", qid, &constant.value)?; @@ -371,13 +400,8 @@ impl fmt::Display for Program { for (qid, evaluator) in self.evaluators.iter() { f.write_str("ev ")?; - if qid.module == self.name { - writeln!( - f, - "{}{}", - &qid.item, - DisplayTuple(evaluator.params.as_slice()) - )?; + if qid.module.0.item == vec![self.name] { + writeln!(f, "{}{}", &qid.item, DisplayTuple(evaluator.params.as_slice()))?; } else { writeln!(f, "{}{}", qid, DisplayTuple(evaluator.params.as_slice()))?; } @@ -391,20 +415,10 @@ impl fmt::Display for Program { for (qid, function) in self.functions.iter() { f.write_str("fn ")?; - if qid.module == self.name { - writeln!( - f, - "{}{}", - &qid.item, - DisplayTypedTuple(function.params.as_slice()) - )?; + if qid.module.0.item == vec![self.name] { + writeln!(f, "{}{}", &qid.item, DisplayTypedTuple(function.params.as_slice()))?; } else { - writeln!( - f, - "{}{}", - qid, - DisplayTypedTuple(function.params.as_slice()) - )?; + writeln!(f, "{}{}", qid, DisplayTypedTuple(function.params.as_slice()))?; } for statement in function.body.iter() { @@ -433,7 +447,7 @@ impl Library { ) -> Result { use std::collections::hash_map::Entry; - let mut lib = Library::default(); + let mut lib: Library = Library::default(); if modules.is_empty() { return Ok(lib); @@ -442,23 +456,22 @@ impl Library { // Register all parsed modules first let mut found_duplicate = None; for module in modules.drain(..) 
{ - match lib.modules.entry(module.name) { + match lib.modules.entry(module.path.clone()) { Entry::Occupied(entry) => { - let prev_span = entry.key().span(); - found_duplicate = Some(prev_span); + found_duplicate = Some(entry.key().span()); diagnostics .diagnostic(Severity::Error) .with_message("conflicting module definitions") .with_primary_label( - module.name.span(), + module.path.span(), "this module name is already in use", ) - .with_secondary_label(prev_span, "originally defined here") + .with_secondary_label(entry.key().span(), "originally defined here") .emit(); - } + }, Entry::Vacant(entry) => { entry.insert(module); - } + }, } } @@ -476,16 +489,11 @@ impl Library { if module.imports.is_empty() { None } else { - let imports = module - .imports - .values() - .map(|i| i.module()) - .collect::>(); - Some((*name, imports)) + let imports = module.imports.values().map(|i| i.module()).collect::>(); + Some((name.clone(), imports)) } }) .collect::>(); - // Cache the current working directory for use in constructing file paths in case // we need to parse referenced modules from disk, and do not have a file path associated // with the importing module with which to derive the import path. @@ -495,74 +503,108 @@ impl Library { // to modules in the library. If the module is already in the library, we proceed, // if it isn't, then we must parse the desired module from disk, and add it to the // library, visiting any of its imports as well. - while let Some((module, mut imports)) = worklist.pop_front() { + while let Some((_module, mut imports)) = worklist.pop_front() { // We attempt to resolve imports on disk relative to the file path of the // importing module, if it was parsed from disk. If no path is available, // we default to the current working directory. 
- let source_dir = match codemap.name(module.span().source_id()) { + + let (real_path, source_dir) = match codemap + .name(imports.first().unwrap().span().source_id()) + { // If we have no source span, default to the current working directory - Err(_) => cwd.clone(), + Err(_) => (false, cwd.clone()), // If the file is virtual, then we've either already parsed imports for this module, // or we have to fall back to the current working directory, but we have no relative // path from which to base our search. - Ok(FileName::Virtual(_)) => cwd.clone(), - Ok(FileName::Real(path)) => path - .parent() - .unwrap_or_else(|| Path::new(".")) - .to_path_buf(), + Ok(FileName::Virtual(_)) => (false, cwd.clone()), + Ok(FileName::Real(path)) => { + (true, path.parent().unwrap_or_else(|| Path::new(".")).to_path_buf()) + }, }; - // For each module imported, try to load the module from the library, if it is unavailable - // we must do extra work to load it into the library, as described above. + // For each module imported, try to load the module from the library, if it is + // unavailable we must do extra work to load it into the library, as + // described above. for import in imports.drain(..) 
{ - if let Entry::Vacant(entry) = lib.modules.entry(import) { - let filename = source_dir.join(format!("{}.air", import.as_str())); - // Check if the module exists in the codemap first, so that we can add files directly - // to the codemap during testing for convenience + if !lib.modules.contains_key(&import.clone()) { + let mut filename = source_dir.clone(); + let mut use_default_mod = false; + let items = &import.0.item; + if let Some((last, parents)) = items.split_last() { + for part in parents { + filename = filename.join(part.as_str()); + } + let path_exist = { + let mut check_path = filename.clone(); + check_path = check_path.join(format!("{}.air", last.as_str())); + check_path.exists() + }; + if path_exist || !real_path { + filename = filename.join(format!("{}.air", last.as_str())); + } else { + filename = filename.join(last.as_str()).join("mod.air"); + use_default_mod = true; + } + } + + // Check if the module exists in the codemap first, so that we can add files + // directly to the codemap during testing for convenience let result = match codemap.get_by_name(&FileName::Real(filename.clone())) { Some(file) => crate::parse_module(diagnostics, codemap.clone(), file), None => { crate::parse_module_from_file(diagnostics, codemap.clone(), &filename) - } + }, }; + match result { Ok(imported_module) => { + let mut imported_module = imported_module; + // We must check if the file we parsed actually contains a module with // the same name as our import, if not, that's an error - if imported_module.name != import { - diagnostics.diagnostic(Severity::Error) - .with_message("invalid module declaration") - .with_primary_label(imported_module.name.span(), "module names must be the same as the name of the file they are defined in") - .emit(); - return Err(SemanticAnalysisError::ImportFailed(import.span())); - } else { - // We parsed the module successfully, so add it to the library - if !imported_module.imports.is_empty() { - let imports = imported_module - .imports - 
.values() - .map(|i| i.module()) - .collect::>(); - worklist.push_back((imported_module.name, imports)); + + if !use_default_mod { + let last_import_part = import.0.item.last().unwrap(); + let module_name_parts = imported_module.path.0.item.last().unwrap(); + if module_name_parts != last_import_part { + diagnostics.diagnostic(Severity::Error) + .with_message("invalid module declaration") + .with_primary_label(imported_module.path.span(), "module names must be the same as the name of the file they are defined in") + .emit(); + return Err(SemanticAnalysisError::ImportFailed(import.span())); } - entry.insert(imported_module); } - } + + imported_module.path = import.clone(); + + // We parsed the module successfully, so add it to the library + if !imported_module.imports.is_empty() { + let imports = imported_module + .imports + .values() + .map(|i| i.module()) + .collect::>(); + worklist.push_back((imported_module.path.clone(), imports)); + } + lib.modules.insert(imported_module.path.clone(), imported_module); + }, Err(ParseError::Failed) => { - // Nothing interesting to emit as a diagnostic here, so just return an error + // Nothing interesting to emit as a diagnostic here, so just return an + // error return Err(SemanticAnalysisError::ImportFailed(import.span())); - } + }, Err(err) => { // Emit the error as a diagnostic and return an ImportError instead diagnostics.emit(err); return Err(SemanticAnalysisError::ImportFailed(import.span())); - } + }, } } } } - // All imports have been resolved, but additional processing is required to merge modules together in a program + // All imports have been resolved, but additional processing is required to merge modules + // together in a program Ok(lib) } @@ -585,4 +627,22 @@ impl Library { pub fn get_mut(&mut self, module: &ModuleId) -> Option<&mut Module> { self.modules.get_mut(module) } + + pub fn get_submodules_of(&self, module: &ModuleId) -> Vec { + self.modules.keys().filter(|m| m.is_submodule_of(module)).cloned().collect() 
+ } +} + +/// Adds the given module to the module graph if it does not already exist, +/// returning the corresponding node index. +fn get_node_index_or_add( + modgraph: &mut sema::ModuleGraph, + nodes: &mut BTreeMap, + module_name: &ModuleId, +) -> petgraph::graph::NodeIndex { + nodes.get(module_name).cloned().unwrap_or_else(|| { + let index = modgraph.add_node(module_name.clone()); + nodes.insert(module_name.clone(), index); + index + }) } diff --git a/parser/src/ast/module.rs b/parser/src/ast/module.rs index 62177fcbd..0b7a0eacc 100644 --- a/parser/src/ast/module.rs +++ b/parser/src/ast/module.rs @@ -1,18 +1,68 @@ -use std::collections::{BTreeMap, HashSet}; +use std::{ + collections::{BTreeMap, HashSet}, + ops::Index, +}; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; use crate::{ast::*, sema::SemanticAnalysisError}; -/// This is a type alias used to clarify that an identifier refers to a module -pub type ModuleId = Identifier; +/// This is a type alias used to clarify that a module is referenced by a sequence of identifiers +/// representing its path in the module hierarchy (e.g., `foo::bar::baz`). +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Spanned)] +pub struct ModuleId(pub Span>); + +impl ModuleId { + pub fn new(identifiers: Vec, span: SourceSpan) -> Self { + Self(Span::new(span, identifiers)) + } + + /// Returns the span of this module identifier + pub fn span(&self) -> SourceSpan { + self.0.span() + } + + pub fn len(&self) -> usize { + self.0.item.len() + } + + pub fn is_empty(&self) -> bool { + self.0.item.is_empty() + } + + /// Returns true if this module identifier is a submodule of another module identifier. + /// For example, `foo::bar` is a submodule of `foo`. Note that a module is considered a + /// submodule of itself. 
+ pub fn is_submodule_of(&self, parent_module: &ModuleId) -> bool { + if self.len() < parent_module.len() { + return false; + } + self.0.item.iter().zip(parent_module.0.item.iter()).all(|(a, b)| a == b) + } +} + +impl Index for ModuleId { + type Output = Identifier; + + fn index(&self, index: usize) -> &Self::Output { + &self.0.item[index] + } +} + +impl std::fmt::Display for ModuleId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let names: Vec<&str> = self.0.item.iter().map(|id| id.as_str()).collect(); + write!(f, "{}", names.join("::")) + } +} #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum ModuleType { /// Only one root module may be defined in an AirScript program, using `def`. /// /// The root module has no restrictions on what sections it can contain, and in a - /// sense "provides" restricted sections to other modules in the program, e.g. the trace columns. + /// sense "provides" restricted sections to other modules in the program, e.g. the trace + /// columns. Root, /// Any number of library modules are permitted in an AirScript program, using `module`. /// @@ -47,7 +97,7 @@ pub enum ModuleType { pub struct Module { #[span] pub span: SourceSpan, - pub name: ModuleId, + pub path: ModuleId, pub ty: ModuleType, pub imports: BTreeMap, pub constants: BTreeMap, @@ -70,10 +120,10 @@ impl Module { /// the caller to guarantee that they construct a valid module that upholds those /// guarantees, otherwise it is expected that compilation will panic at some point down /// the line. 
- pub fn new(ty: ModuleType, span: SourceSpan, name: ModuleId) -> Self { + pub fn new(ty: ModuleType, span: SourceSpan, path: ModuleId) -> Self { Self { span, - name, + path, ty, imports: Default::default(), constants: Default::default(), @@ -98,10 +148,10 @@ impl Module { diagnostics: &DiagnosticsHandler, ty: ModuleType, span: SourceSpan, - name: Identifier, + path: ModuleId, mut declarations: Vec, ) -> Result { - let mut module = Self::new(ty, span, name); + let mut module = Self::new(ty, span, path); // Keep track of named items in this module while building it from // the set of declarations we received. We want to produce modules @@ -114,21 +164,21 @@ impl Module { match declaration { Declaration::Import(import) => { module.declare_import(diagnostics, &mut names, import)?; - } + }, Declaration::Constant(constant) => { module.declare_constant(diagnostics, &mut names, constant)?; - } + }, Declaration::EvaluatorFunction(evaluator) => { module.declare_evaluator(diagnostics, &mut names, evaluator)?; - } + }, Declaration::Function(function) => { module.declare_function(diagnostics, &mut names, function)?; - } + }, Declaration::PeriodicColumns(mut columns) => { for column in columns.drain(..) { module.declare_periodic_column(diagnostics, &mut names, column)?; } - } + }, Declaration::PublicInputs(mut inputs) => { if module.is_library() { invalid_section_in_library(diagnostics, "public_inputs", span); @@ -137,21 +187,25 @@ impl Module { for input in inputs.item.drain(..) 
{ module.declare_public_input(diagnostics, &mut names, input)?; } - } + }, Declaration::Trace(segments) => { module.declare_trace_segments(diagnostics, &mut names, segments)?; - } + }, Declaration::BoundaryConstraints(statements) => { module.declare_boundary_constraints(diagnostics, statements)?; - } + }, Declaration::IntegrityConstraints(statements) => { module.declare_integrity_constraints(diagnostics, statements)?; - } + }, Declaration::Buses(mut buses) => { + if module.is_library() { + invalid_section_in_library(diagnostics, "buses", span); + return Err(SemanticAnalysisError::RootSectionInLibrary(span)); + } for bus in buses.drain(..) { module.declare_bus(diagnostics, &mut names, bus)?; } - } + }, } } @@ -193,12 +247,13 @@ impl Module { use std::collections::btree_map::Entry; let span = import.span(); - match import.item { - Import::All { module: name } => { - if name == self.name { - return Err(SemanticAnalysisError::ImportSelf(name.span())); + match import.item.clone() { + Import::All { module: path } => { + if path == self.path { + return Err(SemanticAnalysisError::ImportSelf(path.span())); } - match self.imports.entry(name) { + + match self.imports.entry(path.clone()) { Entry::Occupied(mut entry) => { let first = entry.key().span(); match entry.get_mut() { @@ -209,7 +264,7 @@ impl Module { .with_primary_label(span, "duplicate import occurs here") .with_secondary_label(first, "original import was here") .emit(); - } + }, Import::Partial { items, .. 
} => { for item in items.iter() { diagnostics @@ -217,45 +272,40 @@ impl Module { .with_message("redundant item import") .with_primary_label(item.span(), "this import is redundant") .with_secondary_label( - name.span(), + path.span(), "because this import imports all items already", ) .emit(); } entry.insert(import.item); - } + }, } - } + }, Entry::Vacant(entry) => { entry.insert(import.item); - } + }, } Ok(()) - } - Import::Partial { - module: name, - mut items, - } => { - if name == self.name { - return Err(SemanticAnalysisError::ImportSelf(name.span())); + }, + Import::Partial { module: path, mut items } => { + if path == self.path { + return Err(SemanticAnalysisError::ImportSelf(path.span())); } - match self.imports.entry(name) { + match self.imports.entry(path.clone()) { Entry::Occupied(mut entry) => match entry.get_mut() { Import::All { module: prev } => { diagnostics .diagnostic(Severity::Warning) .with_message("redundant module import") - .with_primary_label(name.span(), "this import is redundant") + .with_primary_label(path.span(), "this import is redundant") .with_secondary_label( prev.span(), "because this import includes all items already", ) .emit(); - } - Import::Partial { - items: prev_items, .. - } => { + }, + Import::Partial { items: prev_items, .. } => { for item in items.drain() { if let Some(prev) = prev_items.get(&item) { diagnostics @@ -285,7 +335,7 @@ impl Module { return Err(SemanticAnalysisError::NameConflict(item.span())); } } - } + }, }, Entry::Vacant(entry) => { for item in items.iter().copied() { @@ -304,15 +354,12 @@ impl Module { return Err(SemanticAnalysisError::NameConflict(item.span())); } } - entry.insert(Import::Partial { - module: name, - items, - }); - } + entry.insert(Import::Partial { module: path, items }); + }, } Ok(()) - } + }, } } @@ -326,10 +373,7 @@ impl Module { diagnostics .diagnostic(Severity::Error) .with_message("constant identifiers must be uppercase ASCII characters, e.g. 
FOO") - .with_primary_label( - constant.name.span(), - "this is an invalid constant identifier", - ) + .with_primary_label(constant.name.span(), "this is an invalid constant identifier") .emit(); return Err(SemanticAnalysisError::Invalid); } @@ -341,10 +385,8 @@ impl Module { // Validate constant expression if let ConstantExpr::Matrix(matrix) = &constant.value { - let expected_len = matrix - .first() - .expect("expected matrix to have at least one row") - .len(); + let expected_len = + matrix.first().expect("expected matrix to have at least one row").len(); for vector in matrix.iter().skip(1) { if expected_len != vector.len() { diagnostics @@ -394,6 +436,8 @@ impl Module { return Err(SemanticAnalysisError::NameConflict(function.name.span())); } + println!("Declared function: {:?}", function.name); + self.functions.insert(function.name, function); Ok(()) @@ -436,14 +480,14 @@ impl Module { assert_eq!(self.periodic_columns.insert(column.name, column), None); Ok(()) - } + }, _ => { diagnostics.diagnostic(Severity::Error) .with_message("invalid periodic column declaration") .with_primary_label(column.span(), "periodic columns must have a non-zero cycle length which is a power of two") .emit(); Err(SemanticAnalysisError::Invalid) - } + }, } } @@ -458,12 +502,7 @@ impl Module { } if let Some(prev) = names.replace(NamespacedIdentifier::Binding(input.name())) { - conflicting_declaration( - diagnostics, - "public input", - prev.span(), - input.name().span(), - ); + conflicting_declaration(diagnostics, "public input", prev.span(), input.name().span()); Err(SemanticAnalysisError::NameConflict(input.name().span())) } else { assert_eq!(self.public_inputs.insert(input.name(), input), None); @@ -588,6 +627,7 @@ impl Module { .values() .map(Export::Constant) .chain(self.evaluators.values().map(Export::Evaluator)) + .chain(self.functions.values().map(Export::Function)) } /// Get the export with the given identifier, if it can be found @@ -595,14 +635,17 @@ impl Module { if 
id.is_uppercase() { self.constants.get(id).map(Export::Constant) } else { - self.evaluators.get(id).map(Export::Evaluator) + self.evaluators + .get(id) + .map(Export::Evaluator) + .or_else(|| self.functions.get(id).map(Export::Function)) } } } impl Eq for Module {} impl PartialEq for Module { fn eq(&self, other: &Self) -> bool { - self.name == other.name + self.path == other.path && self.ty == other.ty && self.imports == other.imports && self.constants == other.constants diff --git a/parser/src/ast/statement.rs b/parser/src/ast/statement.rs index c7f94089c..795dbd8d2 100644 --- a/parser/src/ast/statement.rs +++ b/parser/src/ast/statement.rs @@ -49,12 +49,13 @@ pub enum Statement { /// /// This variant is only present in the AST after inlining is performed, even though the parser /// could produce it directly from the parse tree. This is because this variant is equivalent to - /// a comprehension constraint with a single element, so we transform all syntax corresponding to - /// `EnforceIf` into `EnforceAll` form so we can reuse all of the analyses/optimizations/transformations - /// for both. However, when lowering to the IR, we perform inlining/unrolling of comprehensions, and - /// at that time we need `EnforceIf` in order to represent unrolled constraints which have a selector - /// that is only resolvable at runtime. - EnforceIf(#[span] ScalarExpr, ScalarExpr), + /// a comprehension constraint with a single element, so we transform all syntax corresponding + /// to `EnforceIf` into `EnforceAll` form so we can reuse all of the + /// analyses/optimizations/transformations for both. However, when lowering to the IR, we + /// perform inlining/unrolling of comprehensions, and at that time we need `EnforceIf` in + /// order to represent unrolled constraints which have a selector that is only resolvable at + /// runtime. + EnforceIf(Match), /// Declares a constraint to be enforced over a vector of values produced by a comprehension. 
/// /// Just like `Enforce`, except the constraint is contained in the body of a list comprehension, @@ -63,6 +64,65 @@ pub enum Statement { /// Declares a bus related constraint BusEnforce(ListComprehension), } + +#[derive(Clone, Spanned, Debug, Eq)] +pub struct Match { + #[span] + pub span: SourceSpan, + pub match_arms: Vec, +} + +impl Match { + pub fn new(span: SourceSpan, match_arms: Vec) -> Self { + Self { span, match_arms } + } +} + +impl PartialEq for Match { + fn eq(&self, other: &Self) -> bool { + self.match_arms == other.match_arms + } +} + +impl fmt::Display for Match { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "match: ")?; + for arm in self.match_arms.iter() { + write!(f, "{arm}")?; + } + Ok(()) + } +} + +#[derive(Clone, Spanned, Debug, Eq)] +pub struct MatchArm { + #[span] + pub span: SourceSpan, + /// The condition to be matched + pub condition: ScalarExpr, + /// The expression to be enforced if the condition is matched + pub expr: ScalarExpr, +} + +impl MatchArm { + pub fn new(span: SourceSpan, expr: ScalarExpr, condition: ScalarExpr) -> Self { + Self { span, expr, condition } + } +} + +impl PartialEq for MatchArm { + fn eq(&self, other: &Self) -> bool { + self.condition == other.condition && self.expr == other.expr + } +} + +impl fmt::Display for MatchArm { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} when {}", self.expr, self.condition)?; + Ok(()) + } +} + impl Statement { /// Checks this statement to see if it contains any constraints /// @@ -71,20 +131,16 @@ impl Statement { /// one or more constraints in its body. pub fn has_constraints(&self) -> bool { match self { - Self::Enforce(_) - | Self::EnforceIf(_, _) - | Self::EnforceAll(_) - | Self::BusEnforce(_) => true, + Self::Enforce(_) | Self::EnforceIf(..) | Self::EnforceAll(_) | Self::BusEnforce(_) => { + true + }, Self::Let(Let { body, .. 
}) => body.iter().any(|s| s.has_constraints()), Self::Expr(_) => false, } } pub fn display(&self, indent: usize) -> DisplayStatement<'_> { - DisplayStatement { - statement: self, - indent, - } + DisplayStatement { statement: self, indent } } } impl From for Statement { @@ -142,32 +198,28 @@ pub struct Let { } impl Let { pub fn new(span: SourceSpan, name: Identifier, value: Expr, body: Vec) -> Self { - Self { - span, - name, - value, - body, - } + Self { span, name, value, body } } /// Return the type of the overall `let` expression. /// /// A `let` with an empty body, or with a body that terminates with a non-expression statement - /// has no type (or rather, one could consider the type it returns to be of "void" or "unit" type). + /// has no type (or rather, one could consider the type it returns to be of "void" or "unit" + /// type). /// /// For `let` statements with a non-empty body that terminates with an expression, the `let` can - /// be used in expression position, producing the value of the terminating expression in its body, - /// and having the same type as that value. + /// be used in expression position, producing the value of the terminating expression in its + /// body, and having the same type as that value. pub fn ty(&self) -> Option { let mut last = self.body.last(); while let Some(stmt) = last.take() { match stmt { Statement::Let(let_expr) => { last = let_expr.body.last(); - } + }, Statement::Expr(expr) => return expr.ty(), Statement::Enforce(_) - | Statement::EnforceIf(_, _) + | Statement::EnforceIf(..) 
| Statement::EnforceAll(_) | Statement::BusEnforce(_) => break, } diff --git a/parser/src/ast/trace.rs b/parser/src/ast/trace.rs index aac904bad..5b26a05d0 100644 --- a/parser/src/ast/trace.rs +++ b/parser/src/ast/trace.rs @@ -5,7 +5,26 @@ use miden_diagnostics::{SourceSpan, Spanned}; use super::*; /// The id of a trace segment is its index in the trace_columns declaration -pub type TraceSegmentId = usize; +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum TraceSegmentId { + Main = 0, + Aux = 1, +} + +impl fmt::Display for TraceSegmentId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TraceSegmentId::Main => write!(f, "main"), + TraceSegmentId::Aux => write!(f, "aux"), + } + } +} + +impl From for usize { + fn from(value: TraceSegmentId) -> Self { + value as usize + } +} /// The index of a column in a particular trace segment pub type TraceColumnIndex = usize; @@ -29,7 +48,8 @@ pub struct TraceSegment { pub boundary_constrained: Vec>, } impl TraceSegment { - /// Constructs a new [TraceSegment] given a span, segment id, name, and a vector of (Identifier, size) pairs. + /// Constructs a new [TraceSegment] given a span, segment id, name, and a vector of (Identifier, + /// size) pairs. 
pub fn new( span: SourceSpan, id: TraceSegmentId, @@ -44,14 +64,7 @@ impl TraceSegment { 1 => Type::Felt, n => Type::Vector(n), }; - bindings.push(TraceBinding::new( - binding.span(), - name, - id, - offset, - size, - ty, - )); + bindings.push(TraceBinding::new(binding.span(), name, id, offset, size, ty)); offset += size; } @@ -108,10 +121,7 @@ impl fmt::Debug for TraceSegment { .field("name", &self.name) .field("size", &self.size) .field("bindings", &self.bindings) - .field( - "boundary_constrained", - &FormatConstrainedFlags(&self.boundary_constrained), - ) + .field("boundary_constrained", &FormatConstrainedFlags(&self.boundary_constrained)) .finish() } } @@ -195,16 +205,14 @@ impl std::ops::BitAnd for ColumnBoundaryFlags { struct FormatConstrainedFlags<'a>(&'a [Span]); impl fmt::Debug for FormatConstrainedFlags<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_list() - .entries(self.0.iter().map(|c| c.item)) - .finish() + f.debug_list().entries(self.0.iter().map(|c| c.item)).finish() } } /// [TraceBinding] is used to represent one or more columns in the execution trace that are bound to /// a name. For single columns, the size is 1. For groups, the size is the number of columns in the /// group. The offset is the column index in the trace where the first column of the binding starts. -#[derive(Copy, Clone, Spanned)] +#[derive(Clone, Spanned)] pub struct TraceBinding { #[span] pub span: SourceSpan, @@ -218,6 +226,8 @@ pub struct TraceBinding { pub size: usize, /// The effective type of this binding pub ty: Type, + /// The access type associated to this TraceBinding + pub access: AccessType, } impl TraceBinding { /// Creates a new trace binding. 
@@ -236,53 +246,77 @@ impl TraceBinding { offset, size, ty, + access: AccessType::Default, } } /// Returns a [Type] that describes what type of value this binding represents #[inline] pub fn ty(&self) -> Type { - self.ty + match self.access.clone() { + AccessType::Default => self.ty, + AccessType::Slice(range_expr) => Type::Vector(range_expr.to_slice_range().len()), + AccessType::Index(_) => Type::Felt, + AccessType::Matrix(..) => { + unreachable!("matrix access not supported on trace bindings") + }, + } } #[inline] pub fn is_scalar(&self) -> bool { - self.ty.is_scalar() + self.ty().is_scalar() + } + + /// Returns the size of the trace binding, taking into account how it is accessed + pub fn tb_size(&self) -> usize { + match self.ty() { + Type::Vector(len) => len, + Type::Felt => 1, + _ => self.size, + } } /// Derive a new [TraceBinding] derived from the current one given an [AccessType] pub fn access(&self, access_type: AccessType) -> Result { - match access_type { - AccessType::Default => Ok(*self), - AccessType::Slice(_) if self.is_scalar() => Err(InvalidAccessError::SliceOfScalar), - AccessType::Slice(range) => { - let slice_range = range.to_slice_range(); - if slice_range.end > self.size { - Err(InvalidAccessError::IndexOutOfBounds) - } else { - let offset = self.offset + slice_range.start; - let size = slice_range.len(); - Ok(Self { - offset, - size, - ty: Type::Vector(size), - ..*self - }) - } - } - AccessType::Index(_) if self.is_scalar() => Err(InvalidAccessError::IndexIntoScalar), - AccessType::Index(idx) if idx >= self.size => Err(InvalidAccessError::IndexOutOfBounds), - AccessType::Index(idx) => { - let offset = self.offset + idx; - Ok(Self { - offset, - size: 1, - ty: Type::Felt, - ..*self - }) - } - AccessType::Matrix(_, _) => Err(InvalidAccessError::IndexIntoScalar), + let combined_access = match (self.access.clone(), access_type.clone()) { + (AccessType::Default, _) => access_type, + (_, AccessType::Default) => self.access.clone(), + 
(AccessType::Slice(range_expr), AccessType::Slice(range_expr1)) => { + let range_expr = range_expr.to_slice_range(); + let range_expr1 = range_expr1.to_slice_range(); + let combined_range = + (range_expr.start + range_expr1.start)..(range_expr.end + range_expr1.end); + AccessType::Slice(combined_range.into()) + }, + (AccessType::Slice(range_expr), AccessType::Index(index_expr)) => { + let range_expr_usize = range_expr.to_slice_range(); + let new_expr = ScalarExpr::Binary(BinaryExpr::new( + self.span(), + BinaryOp::Add, + ScalarExpr::Const(Span::new(range_expr.span(), range_expr_usize.start as u64)), + *index_expr, + )); + AccessType::Index(Box::new(new_expr)) + }, + (AccessType::Index(_), AccessType::Index(_)) => { + return Err(InvalidAccessError::IndexIntoScalar); + }, + (AccessType::Matrix(..), _) | (_, AccessType::Matrix(..)) => { + return Err(InvalidAccessError::IndexIntoScalar); + }, + (expression::AccessType::Index(_), expression::AccessType::Slice(_)) => { + return Err(InvalidAccessError::SliceOfScalar); + }, + }; + + if let AccessType::Index(idx) = combined_access.clone() + && let ScalarExpr::Const(value) = *idx + && value.item as usize >= self.size + { + return Err(InvalidAccessError::IndexOutOfBounds); } + Ok(Self { access: combined_access, ..*self }) } } impl Eq for TraceBinding {} @@ -309,18 +343,9 @@ impl fmt::Debug for TraceBinding { impl fmt::Display for TraceBinding { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if self.size == 1 { - write!( - f, - "{}", - self.name.as_ref().map(|n| n.as_str()).unwrap_or("?") - ) + write!(f, "{}", self.name.as_ref().map(|n| n.as_str()).unwrap_or("?")) } else { - write!( - f, - "{}[{}]", - self.name.as_ref().map(|n| n.as_str()).unwrap_or("?"), - self.size - ) + write!(f, "{}[{}]", self.name.as_ref().map(|n| n.as_str()).unwrap_or("?"), self.size) } } } diff --git a/parser/src/ast/types.rs b/parser/src/ast/types.rs index da3d0d976..cc2f26931 100644 --- a/parser/src/ast/types.rs +++ b/parser/src/ast/types.rs 
@@ -16,7 +16,7 @@ impl Type { pub fn is_aggregate(&self) -> bool { match self { Self::Felt => false, - Self::Vector(_) | Self::Matrix(_, _) => true, + Self::Vector(_) | Self::Matrix(..) => true, } } @@ -51,10 +51,9 @@ impl Type { } else { Ok(Self::Vector(slice_range.len())) } - } - AccessType::Index(idx) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), + }, AccessType::Index(_) => Ok(Self::Felt), - AccessType::Matrix(_, _) => Err(InvalidAccessError::IndexIntoScalar), + AccessType::Matrix(..) => Err(InvalidAccessError::IndexIntoScalar), _ => unreachable!(), }, Self::Matrix(rows, cols) => match access_type { @@ -65,13 +64,9 @@ impl Type { } else { Ok(Self::Matrix(slice_range.len(), cols)) } - } - AccessType::Index(idx) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), + }, AccessType::Index(_) => Ok(Self::Vector(cols)), - AccessType::Matrix(row, col) if row >= rows || col >= cols => { - Err(InvalidAccessError::IndexOutOfBounds) - } - AccessType::Matrix(_, _) => Ok(Self::Felt), + AccessType::Matrix(..) => Ok(Self::Felt), _ => unreachable!(), }, } diff --git a/parser/src/ast/visit.rs b/parser/src/ast/visit.rs index 1f799cb75..9537c9159 100644 --- a/parser/src/ast/visit.rs +++ b/parser/src/ast/visit.rs @@ -26,7 +26,7 @@ use crate::ast; /// /// use miden_diagnostics::{Span, Spanned}; /// -/// use air_parser::ast::{self, visit}; +/// use air_parser::ast::{self, visit, ScalarExpr}; /// /// /// A simple visitor which replaces accesses to constant values with the values themselves, /// /// evaluates constant expressions (i.e. 
expressions whose operands are constant), and propagates @@ -59,19 +59,13 @@ use crate::ast; /// core::mem::replace(expr, ast::ScalarExpr::Const(Span::new(span, value))); /// } /// Some((span, ast::ConstantExpr::Vector(value))) => { -/// match sym.access_type { -/// ast::AccessType::Index(idx) => { -/// core::mem::replace(expr, ast::ScalarExpr::Const(Span::new(span, value[idx]))); -/// } -/// _ => panic!("invalid constant reference, expected scalar access"), +/// if let ast::AccessType::Index(idx) = sym.access_type.clone() && let ScalarExpr::Const(idx) = *idx { +/// core::mem::replace(expr, ast::ScalarExpr::Const(Span::new(span, value[idx.item as usize]))); /// } /// } /// Some((span, ast::ConstantExpr::Matrix(value))) => { -/// match sym.access_type { -/// ast::AccessType::Matrix(row, col) => { -/// core::mem::replace(expr, ast::ScalarExpr::Const(Span::new(span, value[row][col]))); -/// } -/// _ => panic!("invalid constant reference, expected scalar access"), +/// if let ast::AccessType::Matrix(row, col) = sym.access_type.clone() && let ScalarExpr::Const(row) = *row && let ScalarExpr::Const(col) = *col { +/// core::mem::replace(expr, ast::ScalarExpr::Const(Span::new(span, value[row.item as usize][col.item as usize]))); /// } /// } /// } @@ -105,7 +99,6 @@ use crate::ast; /// } /// } /// ``` -/// pub trait VisitMut { fn visit_mut_module(&mut self, module: &mut ast::Module) -> ControlFlow { visit_mut_module(self, module) @@ -170,13 +163,12 @@ pub trait VisitMut { fn visit_mut_enforce(&mut self, expr: &mut ast::ScalarExpr) -> ControlFlow { visit_mut_scalar_expr(self, expr) } - fn visit_mut_enforce_if( - &mut self, - expr: &mut ast::ScalarExpr, - selector: &mut ast::ScalarExpr, - ) -> ControlFlow { - self.visit_mut_enforce(expr)?; - self.visit_mut_scalar_expr(selector) + fn visit_mut_enforce_if(&mut self, match_expr: &mut ast::Match) -> ControlFlow { + for arm in match_expr.match_arms.iter_mut() { + self.visit_mut_scalar_expr(&mut arm.condition)?; + 
self.visit_mut_scalar_expr(&mut arm.expr)?; + } + ControlFlow::Continue(()) } fn visit_mut_enforce_all(&mut self, expr: &mut ast::ListComprehension) -> ControlFlow { self.visit_mut_list_comprehension(expr) @@ -322,12 +314,8 @@ where fn visit_mut_enforce(&mut self, expr: &mut ast::ScalarExpr) -> ControlFlow { (**self).visit_mut_enforce(expr) } - fn visit_mut_enforce_if( - &mut self, - expr: &mut ast::ScalarExpr, - selector: &mut ast::ScalarExpr, - ) -> ControlFlow { - (**self).visit_mut_enforce_if(expr, selector) + fn visit_mut_enforce_if(&mut self, match_expr: &mut ast::Match) -> ControlFlow { + (**self).visit_mut_enforce_if(match_expr) } fn visit_mut_enforce_all(&mut self, expr: &mut ast::ListComprehension) -> ControlFlow { (**self).visit_mut_enforce_all(expr) @@ -422,15 +410,15 @@ where for segment in module.trace_columns.iter_mut() { visitor.visit_mut_trace_segment(segment)?; } - if let Some(bc) = module.boundary_constraints.as_mut() { - if !bc.is_empty() { - visitor.visit_mut_boundary_constraints(bc)?; - } + if let Some(bc) = module.boundary_constraints.as_mut() + && !bc.is_empty() + { + visitor.visit_mut_boundary_constraints(bc)?; } - if let Some(ic) = module.integrity_constraints.as_mut() { - if !ic.is_empty() { - visitor.visit_mut_integrity_constraints(ic)?; - } + if let Some(ic) = module.integrity_constraints.as_mut() + && !ic.is_empty() + { + visitor.visit_mut_integrity_constraints(ic)?; } ControlFlow::Continue(()) @@ -571,7 +559,7 @@ where match expr { ast::Statement::Let(expr) => visitor.visit_mut_let(expr), ast::Statement::Enforce(expr) => visitor.visit_mut_enforce(expr), - ast::Statement::EnforceIf(expr, selector) => visitor.visit_mut_enforce_if(expr, selector), + ast::Statement::EnforceIf(match_expr) => visitor.visit_mut_enforce_if(match_expr), ast::Statement::EnforceAll(expr) => visitor.visit_mut_enforce_all(expr), ast::Statement::Expr(expr) => visitor.visit_mut_expr(expr), ast::Statement::BusEnforce(expr) => visitor.visit_mut_bus_enforce(expr), @@ 
-600,13 +588,13 @@ where visitor.visit_mut_range_bound(&mut range.start)?; visitor.visit_mut_range_bound(&mut range.end)?; ControlFlow::Continue(()) - } + }, ast::Expr::Vector(exprs) => { for expr in exprs.iter_mut() { visitor.visit_mut_expr(expr)?; } ControlFlow::Continue(()) - } + }, ast::Expr::Matrix(matrix) => { for exprs in matrix.iter_mut() { for expr in exprs.iter_mut() { @@ -614,7 +602,7 @@ where } } ControlFlow::Continue(()) - } + }, ast::Expr::SymbolAccess(expr) => visitor.visit_mut_symbol_access(expr), ast::Expr::Binary(expr) => visitor.visit_mut_binary_expr(expr), ast::Expr::Call(expr) => visitor.visit_mut_call(expr), @@ -709,13 +697,16 @@ where V: ?Sized + VisitMut, { match expr { - ast::AccessType::Default | ast::AccessType::Index(_) | ast::AccessType::Matrix(_, _) => { - ControlFlow::Continue(()) - } + ast::AccessType::Default => ControlFlow::Continue(()), + ast::AccessType::Index(index) => visitor.visit_mut_scalar_expr(index), + ast::AccessType::Matrix(row, col) => { + visitor.visit_mut_scalar_expr(row)?; + visitor.visit_mut_scalar_expr(col) + }, ast::AccessType::Slice(range) => { visitor.visit_mut_range_bound(&mut range.start)?; visitor.visit_mut_range_bound(&mut range.end) - } + }, } } @@ -746,6 +737,7 @@ pub fn visit_mut_symbol_access( where V: ?Sized + VisitMut, { + visitor.visit_mut_access_type(&mut expr.access_type)?; visitor.visit_mut_resolvable_identifier(&mut expr.name) } diff --git a/parser/src/lexer/mod.rs b/parser/src/lexer/mod.rs index 5bc959e2a..f354cf6f4 100644 --- a/parser/src/lexer/mod.rs +++ b/parser/src/lexer/mod.rs @@ -15,10 +15,7 @@ pub type Lexed = Result<(SourceIndex, Token, SourceIndex), ParseError>; #[derive(Clone, Debug, thiserror::Error)] pub enum LexicalError { #[error("invalid integer value: {}", DisplayIntErrorKind(reason))] - InvalidInt { - span: SourceSpan, - reason: IntErrorKind, - }, + InvalidInt { span: SourceSpan, reason: IntErrorKind }, #[error("encountered unexpected character '{found}'")] UnexpectedCharacter { 
start: SourceIndex, found: char }, } @@ -27,7 +24,7 @@ impl PartialEq for LexicalError { match (self, other) { (Self::InvalidInt { reason: lhs, .. }, Self::InvalidInt { reason: rhs, .. }) => { lhs == rhs - } + }, ( Self::UnexpectedCharacter { found: lhs, .. }, Self::UnexpectedCharacter { found: rhs, .. }, @@ -41,18 +38,17 @@ impl ToDiagnostic for LexicalError { use miden_diagnostics::Label; match self { - Self::InvalidInt { span, ref reason } => Diagnostic::error() - .with_message("invalid integer literal") - .with_labels(vec![ + Self::InvalidInt { span, ref reason } => { + Diagnostic::error().with_message("invalid integer literal").with_labels(vec![ Label::primary(span.source_id(), span) .with_message(format!("{}", DisplayIntErrorKind(reason))), - ]), - Self::UnexpectedCharacter { start, .. } => Diagnostic::error() - .with_message("unexpected character") - .with_labels(vec![Label::primary( - start.source_id(), - SourceSpan::new(start, start), - )]), + ]) + }, + Self::UnexpectedCharacter { start, .. } => { + Diagnostic::error().with_message("unexpected character").with_labels(vec![ + Label::primary(start.source_id(), SourceSpan::new(start, start)), + ]) + }, } } } @@ -231,27 +227,27 @@ impl PartialEq for Token { if let Self::Num(i2) = other { return *i == *i2; } - } + }, Self::Error(_) => { if let Self::Error(_) = other { return true; } - } + }, Self::Ident(i) => { if let Self::Ident(i2) = other { return i == i2; } - } + }, Self::DeclIdentRef(i) => { if let Self::DeclIdentRef(i2) = other { return i == i2; } - } + }, Self::FunctionIdent(i) => { if let Self::FunctionIdent(i2) = other { return i == i2; } - } + }, _ => return mem::discriminant(self) == mem::discriminant(other), } false @@ -346,8 +342,9 @@ macro_rules! pop2 { }}; } -/// The lexer that is used to perform lexical analysis on the AirScript grammar. The lexer implements -/// the `Iterator` trait, so in order to retrieve the tokens, you simply have to iterate over it. 
+/// The lexer that is used to perform lexical analysis on the AirScript grammar. The lexer +/// implements the `Iterator` trait, so in order to retrieve the tokens, you simply have to iterate +/// over it. /// /// # Errors /// @@ -587,7 +584,7 @@ where start: self.span().start(), found: c, }); - } + }, } self.skip_ident(); diff --git a/parser/src/lexer/tests/identifiers.rs b/parser/src/lexer/tests/identifiers.rs index a4baefddb..0ada9a373 100644 --- a/parser/src/lexer/tests/identifiers.rs +++ b/parser/src/lexer/tests/identifiers.rs @@ -107,10 +107,7 @@ fn valid_tokenization_indexed_trace_access() { fn error_identifier_with_invalid_characters() { let source = "enf clk@' = clk + 1"; // "@" is not in the allowed characters. - let expected = LexicalError::UnexpectedCharacter { - start: SourceIndex::UNKNOWN, - found: '@', - }; + let expected = LexicalError::UnexpectedCharacter { start: SourceIndex::UNKNOWN, found: '@' }; expect_error_at_location(source, expected, 0, 7); } @@ -125,7 +122,7 @@ fn return_first_invalid_character_error() { LexicalError::UnexpectedCharacter { start, found: '@' } => { let expected = SourceIndex::new(start.source_id(), ByteIndex(7)); assert_eq!(start, expected); - } + }, err => panic!("unexpected lexical error in source: {err:#?}"), } } diff --git a/parser/src/lexer/tests/mod.rs b/parser/src/lexer/tests/mod.rs index 8b1532fc5..875979192 100644 --- a/parser/src/lexer/tests/mod.rs +++ b/parser/src/lexer/tests/mod.rs @@ -1,6 +1,8 @@ -use crate::Symbol; -use crate::lexer::{Lexer, LexicalError, Token}; -use crate::parser::ParseError; +use crate::{ + Symbol, + lexer::{Lexer, LexicalError, Token}, + parser::ParseError, +}; mod arithmetic_ops; mod boundary_constraints; @@ -54,7 +56,7 @@ fn expect_error_at_location(source: &str, expected: LexicalError, line: u32, col LexicalError::UnexpectedCharacter { start, .. 
} => { let span = miden_diagnostics::SourceSpan::new(*start, *start); codemap.location(&span).unwrap() - } + }, }; assert_eq!(err, expected); assert_eq!(loc.line, LineIndex(line)); diff --git a/parser/src/lib.rs b/parser/src/lib.rs index 6988cd452..9c8a1e2cd 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -8,15 +8,40 @@ mod sema; pub mod symbols; pub mod transforms; -pub use self::parser::{ParseError, Parser}; -pub use self::sema::{LexicalScope, SemanticAnalysisError}; -pub use self::symbols::Symbol; - -use std::path::Path; -use std::sync::Arc; +use std::{path::Path, sync::Arc}; +use air_pass::Pass; use miden_diagnostics::{CodeMap, DiagnosticsHandler}; +pub use self::{ + parser::{ParseError, Parser}, + sema::{LexicalScope, SemanticAnalysisError}, + symbols::Symbol, +}; +use crate::ast::Program; + +/// Abstracts the various passes done on the AST representation of the program. +pub struct AstPasses<'a> { + diagnostics: &'a DiagnosticsHandler, +} + +impl<'a> AstPasses<'a> { + pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { + Self { diagnostics } + } +} + +impl Pass for AstPasses<'_> { + type Input<'a> = Program; + type Output<'a> = Program; + type Error = SemanticAnalysisError; + + fn run<'a>(&mut self, input: Self::Input<'a>) -> Result, Self::Error> { + let mut passes = transforms::ConstantPropagation::new(self.diagnostics); + passes.run(input) + } +} + /// Parses the provided source and returns the AST. 
pub fn parse( diagnostics: &DiagnosticsHandler, @@ -29,7 +54,7 @@ pub fn parse( Err(ParseError::Lexer(err)) => { diagnostics.emit(err); Err(ParseError::Failed) - } + }, Err(err) => Err(err), } } @@ -46,7 +71,7 @@ pub fn parse_file>( Err(ParseError::Lexer(err)) => { diagnostics.emit(err); Err(ParseError::Failed) - } + }, Err(err) => Err(err), } } @@ -85,7 +110,7 @@ pub(crate) fn parse_module_from_file>( Err(ParseError::Lexer(err)) => { diagnostics.emit(err); Err(ParseError::Failed) - } + }, err @ Err(_) => err, } } @@ -104,7 +129,7 @@ pub(crate) fn parse_module( Err(ParseError::Lexer(err)) => { diagnostics.emit(err); Err(ParseError::Failed) - } + }, err @ Err(_) => err, } } diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index 391379fbb..b06578f68 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -38,7 +38,7 @@ pub Source: Source = { pub Program: Program = { =>? { - let root_name = root.name; + let root_name = root.path.clone(); let mut modules = modules; modules.push(root); let library = match Library::new(diagnostics, codemap.clone(), modules) { @@ -57,14 +57,14 @@ pub AnyModule: Module = { Root: Module = { "def" =>? { - Module::from_declarations(diagnostics, ModuleType::Root, span!(l, r), name, decls) + Module::from_declarations(diagnostics, ModuleType::Root, span!(l, r), ModuleId::new(vec![name], name.span()), decls) .map_err(|err| ParseError::Analysis(err).into()) } } Module: Module = { "mod" =>? 
{ - Module::from_declarations(diagnostics, ModuleType::Library, span!(l, r), name, decls) + Module::from_declarations(diagnostics, ModuleType::Library, span!(l, r), ModuleId::new(vec![name], name.span()), decls) .map_err(|err| ParseError::Analysis(err).into()) } } @@ -82,15 +82,53 @@ Declaration: Declaration = { => Declaration::IntegrityConstraints(<>), } +// Import parses a Vec separated by '::', where ImportLimb is Identifier or '*' (only as last limb) Import: Span = { - "use" "::" "*" ";" => Span::new(span!(l, r), Import::All { module: Identifier::new(span!(l, r), module.name()) }), - "use" "::" ";" => { - let mut items: HashSet = HashSet::default(); - items.insert(item); - Span::new(span!(l, r), Import::Partial { module, items }) + "use" ";" => { + let (last_limb, parents) = limbs.split_last().expect("Imports need at least one path limb"); + match last_limb { + ImportLimb::Star => { + let mut limbs_ident = vec![]; + for parent in parents { + match parent { + ImportLimb::Star => unreachable!(), + ImportLimb::Ident(ident) => limbs_ident.push(*ident) + } + } + Span::new(span!(l, r), Import::All { module: ModuleId::new(limbs_ident, span!(l, r)) }) + } + ImportLimb::Ident(item) => { + let mut limbs_ident = vec![]; + for parent in parents { + match parent { + ImportLimb::Star => unreachable!(), + ImportLimb::Ident(ident) => limbs_ident.push(*ident) + } + } + let mut items: HashSet = HashSet::default(); + items.insert(*item); + Span::new(span!(l, r), Import::Partial { module: ModuleId::new(limbs_ident, span!(l, r)), items }) + } + } } } +// ImportPath: Vec separated by '::' (single right-recursive rule) +ImportPath: Vec = { + => vec![head], + "::" => { + let mut v = vec![head]; + v.extend(tail); + v + } +} + +// ImportLimb: either an identifier or '*', but '*' only allowed as last limb +ImportLimb: ImportLimb = { + "*" => ImportLimb::Star, + => ImportLimb::Ident(id), +} + // TRACE COLUMNS // 
================================================================================================ @@ -114,7 +152,7 @@ MainSegmentId: Identifier = { MainTraceBindings: TraceSegment = { ":" > "," => - TraceSegment::new(span!(l, r), 0, name, bindings), + TraceSegment::new(span!(l, r), TraceSegmentId::Main, name, bindings), } TraceBinding: Span<(Identifier, usize)> = { @@ -195,14 +233,22 @@ EvaluatorFunction: EvaluatorFunction = { EvaluatorBindings: Vec = { > =>? { - let mut segments = Vec::with_capacity(trace.len()); + if trace.len() != 1 { + diagnostics.diagnostic(Severity::Error) + .with_message("invalid evaluator function definition") + .with_primary_label(span!(l, r), "evaluators must have exactly one trace segment") + .emit(); + return Err(ParseError::Failed.into()); + } - for (segment, (span, bindings)) in trace.into_iter().enumerate() { + let mut segments = Vec::with_capacity(trace.len()); + + if let Some((span, bindings)) = trace.into_iter().next() { // We generate a name for these segments to distinguish them from direct references // to the actual main columns. This is useful during the inlining phase let segment_name = Identifier::new(span, Symbol::intern(format!("%{}", *next_var))); *next_var += 1; - segments.push(TraceSegment::new(span, segment, segment_name, bindings)); + segments.push(TraceSegment::new(span, TraceSegmentId::Main, segment_name, bindings)); } // the last segment of trace columns cannot be empty. 
@@ -305,7 +351,10 @@ ConstraintStatements: Vec = { } ConstraintStatement: Vec = { - "enf" "match" "{" "}" ";" => <>, + "enf" "match" "{" "}" ";" => { + let match_expr = Match::new(span!(l, r), arms); + vec![Statement::EnforceIf(match_expr)] + }, "enf" ";" => vec![<>], ";" => vec![<>], } @@ -314,13 +363,9 @@ ReturnStatement: Expr = { "return" ";" => expr, } -MatchArm: Statement = { +MatchArm: MatchArm = { "case" ":" "," => { - let generated_name = format!("%{}", *next_var); - *next_var += 1; - let generated_binding = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern(generated_name)); - let context = vec![(generated_binding, Expr::Range(RangeExpr::from(0..1)))]; - Statement::EnforceAll(ListComprehension::new(span!(l, r), constraint, context, Some(selector))) + MatchArm::new(span!(l, r), constraint, selector) } } @@ -360,12 +405,8 @@ ConstraintExpr: Statement = { } else { // If we didn't parse this as a comprehension, but a selector is present, the constraint is in form 3, // so transform it into form 1. Otherwise, if no selector is present, this is form 4, i.e. simple. 
- if selector.is_some() { - let generated_name = format!("%{}", *next_var); - *next_var += 1; - let generated_binding = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern(generated_name)); - let context = vec![(generated_binding, Expr::Range(RangeExpr::from(0..1)))]; - Statement::EnforceAll(ListComprehension::new(span!(l, r), expr, context, selector)) + if let Some(selector) = selector { + Statement::EnforceIf(Match::new(span!(l,r), vec![MatchArm::new(span!(l,r), expr, selector)])) } else { Statement::Enforce(expr) } @@ -611,8 +652,8 @@ Size: u64 = { "[" "]" => <> } -Index: usize = { - "[" "]" => idx as usize +Index: Box = { + "[" "]" => Box::new(idx) } TableSize: u64 = { diff --git a/parser/src/parser/mod.rs b/parser/src/parser/mod.rs index f588b7aaa..5d4455060 100644 --- a/parser/src/parser/mod.rs +++ b/parser/src/parser/mod.rs @@ -43,10 +43,7 @@ pub enum ParseError { #[error("invalid token")] InvalidToken(SourceIndex), #[error("unexpected end of file")] - UnexpectedEof { - at: SourceIndex, - expected: Vec, - }, + UnexpectedEof { at: SourceIndex, expected: Vec }, #[error("unrecognized token '{token}'")] UnrecognizedToken { span: SourceSpan, @@ -69,18 +66,10 @@ impl PartialEq for ParseError { (Self::InvalidToken(_), Self::InvalidToken(_)) => true, (Self::UnexpectedEof { expected: l, .. }, Self::UnexpectedEof { expected: r, .. }) => { l == r - } + }, ( - Self::UnrecognizedToken { - token: lt, - expected: l, - .. - }, - Self::UnrecognizedToken { - token: rt, - expected: r, - .. - }, + Self::UnrecognizedToken { token: lt, expected: l, .. }, + Self::UnrecognizedToken { token: rt, expected: r, .. }, ) => lt == rt && l == r, (Self::ExtraToken { token: l, .. }, Self::ExtraToken { token: r, .. 
}) => l == r, (Self::Failed, Self::Failed) => true, @@ -94,23 +83,18 @@ impl From> for ParseErr match err { LError::InvalidToken { location } => Self::InvalidToken(location), - LError::UnrecognizedEof { - location: at, - expected, - } => Self::UnexpectedEof { at, expected }, - LError::UnrecognizedToken { - token: (l, token, r), - expected, - } => Self::UnrecognizedToken { - span: SourceSpan::new(l, r), - token, - expected, + LError::UnrecognizedEof { location: at, expected } => { + Self::UnexpectedEof { at, expected } + }, + LError::UnrecognizedToken { token: (l, token, r), expected } => { + Self::UnrecognizedToken { + span: SourceSpan::new(l, r), + token, + expected, + } }, - LError::ExtraToken { - token: (l, token, r), - } => Self::ExtraToken { - span: SourceSpan::new(l, r), - token, + LError::ExtraToken { token: (l, token, r) } => { + Self::ExtraToken { span: SourceSpan::new(l, r), token } }, LError::User { error } => error, } @@ -137,16 +121,11 @@ impl ToDiagnostic for ParseError { } } - Diagnostic::error() - .with_message("unexpected eof") - .with_labels(vec![ - Label::primary(at.source_id(), SourceSpan::new(at, at)) - .with_message(message), - ]) - } - Self::UnrecognizedToken { - span, ref expected, .. - } => { + Diagnostic::error().with_message("unexpected eof").with_labels(vec![ + Label::primary(at.source_id(), SourceSpan::new(at, at)).with_message(message), + ]) + }, + Self::UnrecognizedToken { span, ref expected, .. } => { let mut message = "expected one of: ".to_string(); for (i, t) in expected.iter().enumerate() { if i == 0 { @@ -158,10 +137,8 @@ impl ToDiagnostic for ParseError { Diagnostic::error() .with_message("unexpected token") - .with_labels(vec![ - Label::primary(span.source_id(), span).with_message(message), - ]) - } + .with_labels(vec![Label::primary(span.source_id(), span).with_message(message)]) + }, Self::ExtraToken { span, .. 
} => Diagnostic::error() .with_message("extraneous token") .with_labels(vec![Label::primary(span.source_id(), span)]), @@ -206,7 +183,7 @@ impl miden_parsing::Parse for ast::Source { return Err(ParseError::Failed); } Ok(ast) - } + }, Err(lalrpop_util::ParseError::User { error }) => Err(error), Err(err) => Err(err.into()), } @@ -249,7 +226,7 @@ impl miden_parsing::Parse for ast::Program { return Err(ParseError::Failed); } Ok(ast) - } + }, Err(lalrpop_util::ParseError::User { error }) => Err(error), Err(err) => Err(err.into()), } @@ -292,7 +269,7 @@ impl miden_parsing::Parse for ast::Module { return Err(ParseError::Failed); } Ok(ast) - } + }, Err(lalrpop_util::ParseError::User { error }) => Err(error), Err(err) => Err(err.into()), } diff --git a/parser/src/parser/tests/arithmetic_ops.rs b/parser/src/parser/tests/arithmetic_ops.rs index 239d3fa27..c7e0ae970 100644 --- a/parser/src/parser/tests/arithmetic_ops.rs +++ b/parser/src/parser/tests/arithmetic_ops.rs @@ -1,8 +1,7 @@ use miden_diagnostics::SourceSpan; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; // EXPRESSIONS // ================================================================================================ @@ -17,13 +16,13 @@ fn single_addition() { enf clk' + clk = 0; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce!(eq!(add!(access!(clk, 1), access!(clk)), int!(0)))], ), ); @@ -40,17 +39,14 @@ fn multi_addition() { enf clk' + clk + 2 = 0; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); 
expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], - vec![enforce!(eq!( - add!(add!(access!(clk, 1), access!(clk)), int!(2)), - int!(0) - ))], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], + vec![enforce!(eq!(add!(add!(access!(clk, 1), access!(clk)), int!(2)), int!(0)))], ), ); ParseTest::new().expect_module_ast(source, expected); @@ -66,13 +62,13 @@ fn single_subtraction() { enf clk' - clk = 0; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce!(eq!(sub!(access!(clk, 1), access!(clk)), int!(0)))], ), ); @@ -89,17 +85,14 @@ fn multi_subtraction() { enf clk' - clk - 1 = 0; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], - vec![enforce!(eq!( - sub!(sub!(access!(clk, 1), access!(clk)), int!(1)), - int!(0) - ))], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], + vec![enforce!(eq!(sub!(sub!(access!(clk, 1), access!(clk)), int!(1)), int!(0)))], ), ); ParseTest::new().expect_module_ast(source, expected); @@ -115,13 +108,13 @@ fn single_multiplication() { enf clk' * clk = 0; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), 
EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce!(eq!(mul!(access!(clk, 1), access!(clk)), int!(0)))], ), ); @@ -138,17 +131,14 @@ fn multi_multiplication() { enf clk' * clk * 2 = 0; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], - vec![enforce!(eq!( - mul!(mul!(access!(clk, 1), access!(clk)), int!(2)), - int!(0) - ))], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], + vec![enforce!(eq!(mul!(mul!(access!(clk, 1), access!(clk)), int!(2)), int!(0)))], ), ); ParseTest::new().expect_module_ast(source, expected); @@ -164,13 +154,13 @@ fn unit_with_parens() { enf (2) + 1 = 3; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce!(eq!(add!(int!(2), int!(1)), int!(3)))], ), ); @@ -187,17 +177,14 @@ fn ops_with_parens() { enf (clk' + clk) * 2 = 4; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], - vec![enforce!(eq!( - mul!(add!(access!(clk, 1), access!(clk)), int!(2)), - int!(4) - ))], + 
vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], + vec![enforce!(eq!(mul!(add!(access!(clk, 1), access!(clk)), int!(2)), int!(4)))], ), ); ParseTest::new().expect_module_ast(source, expected); @@ -213,13 +200,13 @@ fn const_exponentiation() { enf clk'^2 = 1; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce!(eq!(exp!(access!(clk, 1), int!(2)), int!(1)))], ), ); @@ -236,17 +223,14 @@ fn non_const_exponentiation() { enf clk'^(clk + 2) = 1; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], - vec![enforce!(eq!( - exp!(access!(clk, 1), add!(access!(clk), int!(2))), - int!(1) - ))], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], + vec![enforce!(eq!(exp!(access!(clk, 1), add!(access!(clk), int!(2))), int!(1)))], ), ); ParseTest::new().expect_module_ast(source, expected); @@ -286,13 +270,13 @@ fn multi_arithmetic_ops_same_precedence() { enf clk' - clk - 2 + 1 = 0; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce!(eq!( add!(sub!(sub!(access!(clk, 
1), access!(clk)), int!(2)), int!(1)), int!(0) @@ -318,18 +302,15 @@ fn multi_arithmetic_ops_different_precedence() { // 3. Addition/Subtraction // These operations are evaluated in the order of decreasing precedence. - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce!(eq!( - sub!( - sub!(exp!(access!(clk, 1), int!(2)), mul!(access!(clk), int!(2))), - int!(1) - ), + sub!(sub!(exp!(access!(clk, 1), int!(2)), mul!(access!(clk), int!(2))), int!(1)), int!(0) ))], ), @@ -353,18 +334,15 @@ fn multi_arithmetic_ops_different_precedence_w_parens() { // 3. Multiplication // 4. Addition/Subtraction // These operations are evaluated in the order of decreasing precedence. 
- let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce!(eq!( - sub!( - access!(clk, 1), - mul!(exp!(access!(clk), int!(2)), sub!(int!(2), int!(1))) - ), + sub!(access!(clk, 1), mul!(exp!(access!(clk), int!(2)), sub!(int!(2), int!(1)))), int!(0) ))], ), diff --git a/parser/src/parser/tests/boundary_constraints.rs b/parser/src/parser/tests/boundary_constraints.rs index 6eb543a4f..6b66853c3 100644 --- a/parser/src/parser/tests/boundary_constraints.rs +++ b/parser/src/parser/tests/boundary_constraints.rs @@ -1,8 +1,7 @@ use miden_diagnostics::{SourceSpan, Span}; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; // BOUNDARY STATEMENTS // ================================================================================================ @@ -53,26 +52,21 @@ integrity_constraints { /// /// This is used as a common base for most tests in this module fn test_module() -> Module { - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); - expected.buses.insert( - ident!(p), - Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Multiset), - ); - expected.buses.insert( - ident!(q), - Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Logup), - ); - expected.integrity_constraints = Some(Span::new( - SourceSpan::UNKNOWN, - vec![enforce!(eq!(access!(clk), int!(0)))], - )); + 
.push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1)])); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected + .buses + .insert(ident!(p), Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Multiset)); + expected + .buses + .insert(ident!(q), Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Logup)); + expected.integrity_constraints = + Some(Span::new(SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(clk), int!(0)))])); expected } @@ -90,10 +84,7 @@ fn boundary_constraint_at_first() { let mut expected = test_module(); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); ParseTest::new().expect_module_ast(&source, expected); } @@ -112,10 +103,7 @@ fn boundary_constraint_at_last() { let mut expected = test_module(); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::Last), - int!(15) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::Last), int!(15)))], )); ParseTest::new().expect_module_ast(&source, expected); } @@ -194,10 +182,7 @@ fn boundary_constraint_with_pub_input() { let mut expected = test_module(); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - access!(inputs[0]) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), access!(inputs[0])))], )); ParseTest::new().expect_module_ast(&source, expected); } @@ -242,9 +227,7 @@ fn boundary_constraint_with_const() { let mut expected = test_module(); expected.constants.insert(ident!(A), constant!(A = 1)); expected.constants.insert(ident!(B), constant!(B = [0, 1])); - expected - .constants - .insert(ident!(C), constant!(C = [[0, 1], [1, 0]])); + expected.constants.insert(ident!(C), constant!(C = 
[[0, 1], [1, 0]])); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!( diff --git a/parser/src/parser/tests/buses.rs b/parser/src/parser/tests/buses.rs index f9d65e5e4..2e4f1c95a 100644 --- a/parser/src/parser/tests/buses.rs +++ b/parser/src/parser/tests/buses.rs @@ -1,86 +1,170 @@ -use miden_diagnostics::SourceSpan; +use miden_diagnostics::{SourceSpan, Span}; +use super::ParseTest; use crate::ast::*; -use super::ParseTest; +const BASE_MODULE: &str = r#" +def test + +trace_columns { + main: [clk], +} + +public_inputs { + inputs: [2], +}"#; + +fn add_base_expectations(expected: &mut Module) { + expected + .trace_columns + .push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1)])); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); +} + +fn add_base_boundary_expectation(expected: &mut Module) { + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], + )); +} + +fn add_base_integrity_expectation(expected: &mut Module) { + expected.integrity_constraints = + Some(Span::new(SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(clk), int!(0)))])); +} #[test] fn buses() { - let source = " - mod test + let source = format!( + " + {BASE_MODULE} - buses { + buses {{ multiset p, logup q, - }"; + }} + + boundary_constraints {{ + enf clk.first = 0; + }} - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); - expected.buses.insert( - ident!(p), - Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Multiset), - ); - expected.buses.insert( - ident!(q), - Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Logup), + integrity_constraints {{ + enf clk = 0; + }} + " ); - ParseTest::new().expect_module_ast(source, expected); + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + add_base_expectations(&mut expected); + 
add_base_boundary_expectation(&mut expected); + add_base_integrity_expectation(&mut expected); + expected + .buses + .insert(ident!(p), Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Multiset)); + expected + .buses + .insert(ident!(q), Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Logup)); + ParseTest::new().expect_module_ast(&source, expected); } #[test] fn boundary_constraints_buses() { - let _source = " - mod test + let source = format!( + " + {BASE_MODULE} - buses { + buses {{ multiset p, logup q, - } + }} - boundary_constraints { + boundary_constraints {{ enf p.first = null; enf q.last = null; - }"; + }} - /*let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); - expected.buses.insert( - ident!(p), - Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Multiset), - ); - expected.buses.insert( - ident!(q), - Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Logup), + integrity_constraints {{ + enf clk = 0; + }} + " ); - ParseTest::new().expect_module_ast(source, expected);*/ + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + add_base_expectations(&mut expected); + add_base_integrity_expectation(&mut expected); + expected + .buses + .insert(ident!(p), Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Multiset)); + expected + .buses + .insert(ident!(q), Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Logup)); + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![ + enforce!(eq!(bounded_access!(p, Boundary::First), null!())), + enforce!(eq!(bounded_access!(q, Boundary::Last), null!())), + ], + )); + ParseTest::new().expect_module_ast(&source, expected); } #[test] fn integrity_constraints_buses() { - let _source = " - mod test + let source = format!( + " + {BASE_MODULE} - buses { + buses {{ multiset p, logup q, - } + }} + + boundary_constraints {{ + enf clk.first = 0; + }} - integrity_constraints { + integrity_constraints {{ p.insert(1) when 1; 
p.remove(1) when 1; q.insert(1, 2) when 1; q.insert(1, 2) when 1; q.remove(1, 2) with 2; - }"; - - /*let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); - expected.buses.insert( - ident!(p), - Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Multiset), - ); - expected.buses.insert( - ident!(q), - Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Logup), + }} + " ); - ParseTest::new().expect_module_ast(source, expected);*/ + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + add_base_expectations(&mut expected); + add_base_boundary_expectation(&mut expected); + expected + .buses + .insert(ident!(p), Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Multiset)); + expected + .buses + .insert(ident!(q), Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Logup)); + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![ + bus_enforce!( + lc!((("%0", range!(0..1))) => bus_insert!(p, vec![expr!(int!(1))]), when int!(1)) + ), + bus_enforce!( + lc!((("%1", range!(0..1))) => bus_remove!(p, vec![expr!(int!(1))]), when int!(1)) + ), + bus_enforce!( + lc!((("%2", range!(0..1))) => bus_insert!(q, vec![expr!(int!(1)), expr!(int!(2))]), when int!(1)) + ), + bus_enforce!( + lc!((("%3", range!(0..1))) => bus_insert!(q, vec![expr!(int!(1)), expr!(int!(2))]), when int!(1)) + ), + bus_enforce!( + lc!((("%4", range!(0..1))) => bus_remove!(q, vec![expr!(int!(1)), expr!(int!(2))]), with int!(2)) + ), + ], + )); + ParseTest::new().expect_module_ast(&source, expected); } #[test] diff --git a/parser/src/parser/tests/calls.rs b/parser/src/parser/tests/calls.rs index cb329b6a2..90cb8aa82 100644 --- a/parser/src/parser/tests/calls.rs +++ b/parser/src/parser/tests/calls.rs @@ -1,8 +1,7 @@ use miden_diagnostics::SourceSpan; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; #[test] fn call_fold_identifier() { @@ -15,7 +14,7 @@ fn call_fold_identifier() { enf a = x + y; }"; - let 
mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); let body = vec![let_!(x = expr!(call!(sum(expr!(access!(c))))) => let_!(y = expr!(call!(prod(expr!(access!(c))))) => enforce!(eq!(access!(a), add!(access!(x), access!(y))))))]; @@ -24,7 +23,7 @@ fn call_fold_identifier() { EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(a, 1), (c, 2)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(a, 1), (c, 2)])], body, ), ); @@ -43,18 +42,16 @@ fn call_fold_vector_literal() { enf a = x + y; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); - let body = vec![ - let_!(x = expr!(call!(sum(vector!(access!(a), access!(b), access!(c[0]))))) => + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); + let body = vec![let_!(x = expr!(call!(sum(vector!(access!(a), access!(b), access!(c[0]))))) => let_!(y = expr!(call!(prod(vector!(access!(a), access!(b), access!(c[0]))))) => - enforce!(eq!(access!(a), add!(access!(x), access!(y)))))), - ]; + enforce!(eq!(access!(a), add!(access!(x), access!(y))))))]; expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(a, 1), (b, 1), (c, 4)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(a, 1), (b, 1), (c, 4)])], body, ), ); @@ -73,7 +70,7 @@ fn call_fold_list_comprehension() { enf a = x + y; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); let body = vec![ let_!(x = expr!(call!(sum(lc!(((col, expr!(access!(c)))) => exp!(access!(col), int!(7))).into()))) => let_!(y = expr!(call!(prod(lc!(((col, expr!(access!(c)))) => exp!(access!(col), int!(7))).into()))) => @@ -84,7 
+81,7 @@ fn call_fold_list_comprehension() { EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(a, 1), (b, 1), (c, 4)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(a, 1), (b, 1), (c, 4)])], body, ), ); diff --git a/parser/src/parser/tests/computed_indices.rs b/parser/src/parser/tests/computed_indices.rs new file mode 100644 index 000000000..f05b0865c --- /dev/null +++ b/parser/src/parser/tests/computed_indices.rs @@ -0,0 +1,143 @@ +use miden_diagnostics::{SourceSpan, Span}; + +use super::ParseTest; +use crate::ast::*; + +#[test] +fn basic_computed_indices() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + + enf a = x[1 + 1]; + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], + )); + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![let_!(x = vector!(int!(0), int!(1), int!(2), int!(3), int!(4)) => + enforce!(eq!(access!(a), access!(x[Box::new(add!(int!(1), int!(1)))]))))], + )); + + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn basic_computed_indices_in_lc() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + let y = [i * x[1 + 1] for i in 0..5]; + + enf a = y[1 + 1]; + }"; + + let mut expected = 
Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], + )); + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![let_!(x = vector!(int!(0), int!(1), int!(2), int!(3), int!(4)) => + let_!(y = lc!(((i, range!(0usize, 5usize))) => mul!(access!(i), access!(x[Box::new(add!(int!(1), int!(1)))]))).into() => + enforce!(eq!(access!(a), access!(y[Box::new(add!(int!(1), int!(1)))])))))], + )); + + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn computed_indices_in_lc() { + let source = " + def test + + trace_columns { + main: [a, b, c[4]], + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let x = [0, 1, 2, 3, 4]; + let y = [i * x[i + 1] for i in 0..4]; + + enf a = y[1 + 1]; + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], + )); + expected.integrity_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![let_!(x = vector!(int!(0), int!(1), int!(2), int!(3), int!(4)) => + let_!(y = lc!(((i, range!(0usize, 4usize))) => mul!(access!(i), access!(x[Box::new(add!(access!(i), int!(1)))]))).into() => + enforce!(eq!(access!(a), 
access!(y[Box::new(add!(int!(1), int!(1)))])))))], + )); + + ParseTest::new().expect_module_ast(source, expected); +} diff --git a/parser/src/parser/tests/constant_propagation.rs b/parser/src/parser/tests/constant_propagation.rs index c0c0d0308..faaeb9d92 100644 --- a/parser/src/parser/tests/constant_propagation.rs +++ b/parser/src/parser/tests/constant_propagation.rs @@ -1,11 +1,9 @@ use air_pass::Pass; use miden_diagnostics::SourceSpan; - use pretty_assertions::assert_eq; -use crate::{ast::*, transforms::ConstantPropagation}; - use super::ParseTest; +use crate::{ast::*, transforms::ConstantPropagation}; #[test] fn test_constant_propagation() { @@ -57,7 +55,7 @@ fn test_constant_propagation() { Err(err) => { test.diagnostics.emit(err); panic!("expected parsing to succeed, see diagnostics for details"); - } + }, Ok(ast) => ast, }; @@ -66,38 +64,27 @@ fn test_constant_propagation() { let mut expected = Program::new(ident!(root)); expected.trace_columns.push(trace_segment!( - 0, + TraceSegmentId::Main, "$main", [(clk, 1), (a, 1), (b, 2), (c, 1)] )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); expected - .constants - .insert(ident!(root, A), constant!(A = [2, 4, 6, 8])); - expected - .constants - .insert(ident!(root, B), constant!(B = [[1, 1], [2, 2]])); - expected - .constants - .insert(ident!(lib, EXP), constant!(EXP = 2)); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0)); + expected.constants.insert(ident!(root, A), constant!(A = [2, 4, 6, 8])); + expected.constants.insert(ident!(root, B), constant!(B = [[1, 1], [2, 2]])); + expected.constants.insert(ident!(lib, EXP), constant!(EXP = 2)); // When constant propagation is done, the boundary constraints should look like: // enf a.first = 1 - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(a, Boundary::First, Type::Felt), - int!(1) - ))); + expected + 
.boundary_constraints + .push(enforce!(eq!(bounded_access!(a, Boundary::First, Type::Felt), int!(1)))); // When constant propagation is done, the integrity constraints should look like: // enf test_constraint(b) // enf a + 4 = c + 5 expected .integrity_constraints - .push(enforce!(call!(lib::test_constraint(expr!(access!( - b, - Type::Vector(2) - )))))); + .push(enforce!(call!(lib::test_constraint(expr!(access!(b, Type::Vector(2))))))); expected.integrity_constraints.push(enforce!(eq!( add!(access!(a, Type::Felt), int!(4)), add!(access!(c, Type::Felt), int!(5)) @@ -113,7 +100,7 @@ fn test_constant_propagation() { EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test_constraint), - vec![trace_segment!(0, "%0", [(b0, 1), (b1, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(b0, 1), (b1, 1)])], body, ), ); diff --git a/parser/src/parser/tests/constants.rs b/parser/src/parser/tests/constants.rs index 607510910..e9da31bb5 100644 --- a/parser/src/parser/tests/constants.rs +++ b/parser/src/parser/tests/constants.rs @@ -1,8 +1,7 @@ use miden_diagnostics::SourceSpan; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; // CONSTANTS // ================================================================================================ @@ -15,7 +14,7 @@ fn constants_scalars() { const A = 1; const B = 2;"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.constants.insert( ident!(A), Constant::new(SourceSpan::UNKNOWN, ident!(A), ConstantExpr::Scalar(1)), @@ -35,22 +34,14 @@ fn constants_vectors() { const A = [1, 2, 3, 4]; const B = [5, 6, 7, 8];"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.constants.insert( ident!(A), - Constant::new( - SourceSpan::UNKNOWN, - 
ident!(A), - ConstantExpr::Vector(vec![1, 2, 3, 4]), - ), + Constant::new(SourceSpan::UNKNOWN, ident!(A), ConstantExpr::Vector(vec![1, 2, 3, 4])), ); expected.constants.insert( ident!(B), - Constant::new( - SourceSpan::UNKNOWN, - ident!(B), - ConstantExpr::Vector(vec![5, 6, 7, 8]), - ), + Constant::new(SourceSpan::UNKNOWN, ident!(B), ConstantExpr::Vector(vec![5, 6, 7, 8])), ); ParseTest::new().expect_module_ast(source, expected); } @@ -63,7 +54,7 @@ fn constants_matrices() { const A = [[1, 2], [3, 4]]; const B = [[5, 6], [7, 8]];"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.constants.insert( ident!(A), Constant::new( diff --git a/parser/src/parser/tests/evaluators.rs b/parser/src/parser/tests/evaluators.rs index b92371f99..ebe332c3f 100644 --- a/parser/src/parser/tests/evaluators.rs +++ b/parser/src/parser/tests/evaluators.rs @@ -1,8 +1,7 @@ use miden_diagnostics::{SourceSpan, Span}; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; // EVALUATOR FUNCTIONS // ================================================================================================ @@ -16,13 +15,13 @@ fn ev_fn_main_cols() { enf clk' = clk + 1; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.evaluators.insert( ident!(advance_clock), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(advance_clock), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce!(eq!(access!(clk, 1), add!(access!(clk), int!(1))))], ), ); @@ -50,14 +49,13 @@ fn ev_fn_call_simple() { enf advance_clock([clk]); }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = 
Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1)])); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], @@ -91,14 +89,15 @@ fn ev_fn_call() { enf advance_clock([a, b[1..3], c[2..4]]); }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 2), (b, 4), (c, 6)] + )); expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 2), (b, 4), (c, 6)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], @@ -124,14 +123,14 @@ fn ev_fn_call_inside_ev_fn() { enf advance_clock([clk]); }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); let body = vec![enforce!(call!(advance_clock(vector!(access!(clk)))))]; expected.evaluators.insert( ident!(ev_func), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(ev_func), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], body, ), ); 
@@ -160,14 +159,15 @@ fn ev_fn_call_with_more_than_two_args() { enf advance_clock([a], [b], [c]); }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 1)] + )); expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 1)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], @@ -188,11 +188,23 @@ fn ev_fn_call_with_more_than_two_args() { // ================================================================================================ #[test] -fn ev_fn_def_with_empty_final_arg() { +fn ev_fn_def_with_multiple_args() { + let source = " + mod test + + ev ev_func([clk], [a, b]) { + enf clk' = clk + 1 + }"; + ParseTest::new() + .expect_module_diagnostic(source, "evaluators must have exactly one trace segment"); +} + +#[test] +fn ev_fn_def_with_empty_arg() { let source = " mod test - ev ev_func([clk], []) { + ev ev_func([]) { enf clk' = clk + 1 }"; ParseTest::new().expect_module_diagnostic(source, "the last trace segment cannot be empty"); diff --git a/parser/src/parser/tests/functions.rs b/parser/src/parser/tests/functions.rs index d8f286b8a..f47e42efa 100644 --- a/parser/src/parser/tests/functions.rs +++ b/parser/src/parser/tests/functions.rs @@ -1,8 +1,7 @@ use miden_diagnostics::{SourceSpan, Span}; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; // PURE FUNCTIONS // ================================================================================================ @@ -16,7 +15,7 @@ fn 
fn_def_with_scalars() { return a + b; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.functions.insert( ident!(fn_with_scalars), Function::new( @@ -39,7 +38,7 @@ fn fn_def_with_vectors() { return [x + y for (x, y) in (a, b)]; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.functions.insert( ident!(fn_with_vectors), Function::new( @@ -47,10 +46,8 @@ fn fn_def_with_vectors() { function_ident!(fn_with_vectors), vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], Type::Vector(12), - vec![return_!(expr!( - lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => - add!(access!(x), access!(y))) - ))], + vec![return_!(expr!(lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => + add!(access!(x), access!(y)))))], ), ); ParseTest::new().expect_module_ast(source, expected); @@ -81,7 +78,7 @@ fn fn_use_scalars_and_vectors() { enf a' = fn_with_scalars_and_vectors(a, b); }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(root)); expected.functions.insert( ident!(fn_with_scalars_and_vectors), @@ -98,7 +95,7 @@ fn fn_use_scalars_and_vectors() { expected .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 12)])); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(a, 1), (b, 12)])); expected.public_inputs.insert( ident!(stack_inputs), @@ -113,10 +110,7 @@ fn fn_use_scalars_and_vectors() { SourceSpan::UNKNOWN, vec![enforce!(eq!( access!(a, 1), - call!(fn_with_scalars_and_vectors( - expr!(access!(a)), - expr!(access!(b)) - )) + call!(fn_with_scalars_and_vectors(expr!(access!(a)), expr!(access!(b)))) ))], )); 
ParseTest::new().expect_module_ast(source, expected); @@ -151,7 +145,7 @@ fn fn_call_in_fn() { enf a' = fold_scalar_and_vec(a, b); }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(root)); expected.functions.insert( ident!(fold_vec), @@ -160,9 +154,7 @@ fn fn_call_in_fn() { function_ident!(fold_vec), vec![(ident!(a), Type::Vector(12))], Type::Felt, - vec![return_!(expr!(call!(sum(expr!( - lc!(((x, expr!(access!(a)))) => access!(x)) - )))))], + vec![return_!(expr!(call!(sum(expr!(lc!(((x, expr!(access!(a)))) => access!(x)))))))], ), ); @@ -173,16 +165,13 @@ fn fn_call_in_fn() { function_ident!(fold_scalar_and_vec), vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], Type::Felt, - vec![return_!(expr!(add!( - access!(a), - call!(fold_vec(expr!(access!(b)))) - )))], + vec![return_!(expr!(add!(access!(a), call!(fold_vec(expr!(access!(b)))))))], ), ); expected .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 12)])); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(a, 1), (b, 12)])); expected.public_inputs.insert( ident!(stack_inputs), @@ -238,7 +227,7 @@ fn fn_call_in_ev() { enf evaluator(a, b); }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(root)); expected.functions.insert( ident!(fold_vec), @@ -247,9 +236,7 @@ fn fn_call_in_ev() { function_ident!(fold_vec), vec![(ident!(a), Type::Vector(12))], Type::Felt, - vec![return_!(expr!(call!(sum(expr!( - lc!(((x, expr!(access!(a)))) => access!(x)) - )))))], + vec![return_!(expr!(call!(sum(expr!(lc!(((x, expr!(access!(a)))) => access!(x)))))))], ), ); @@ -260,10 +247,7 @@ fn fn_call_in_ev() { function_ident!(fold_scalar_and_vec), vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], Type::Felt, - vec![return_!(expr!(add!( - access!(a), - 
call!(fold_vec(expr!(access!(b)))) - )))], + vec![return_!(expr!(add!(access!(a), call!(fold_vec(expr!(access!(b)))))))], ), ); @@ -272,7 +256,7 @@ fn fn_call_in_ev() { EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(evaluator), - vec![trace_segment!(0, "%0", [(a, 1), (b, 12)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(a, 1), (b, 12)])], vec![enforce!(eq!( access!(a, 1), call!(fold_scalar_and_vec(expr!(access!(a)), expr!(access!(b)))) @@ -282,7 +266,7 @@ fn fn_call_in_ev() { expected .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 12)])); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(a, 1), (b, 12)])); expected.public_inputs.insert( ident!(stack_inputs), @@ -296,10 +280,7 @@ fn fn_call_in_ev() { expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(call!(evaluator( - expr!(access!(a)), - expr!(access!(b)) - )))], + vec![enforce!(call!(evaluator(expr!(access!(a)), expr!(access!(b)))))], )); ParseTest::new().expect_module_ast(source, expected); @@ -331,7 +312,7 @@ fn fn_as_lc_iterables() { enf a' = sum([operation(x, y) for (x, y) in (a, b)]); }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(root)); expected.functions.insert( ident!(operation), @@ -340,16 +321,14 @@ fn fn_as_lc_iterables() { function_ident!(operation), vec![(ident!(a), Type::Felt), (ident!(b), Type::Felt)], Type::Felt, - vec![ - let_!(x = expr!(add!(exp!(access!(a), access!(b)), int!(1))) => - return_!(expr!(exp!(access!(b), access!(x))))), - ], + vec![let_!(x = expr!(add!(exp!(access!(a), access!(b)), int!(1))) => + return_!(expr!(exp!(access!(b), access!(x)))))], ), ); expected .trace_columns - .push(trace_segment!(0, "$main", [(a, 12), (b, 12)])); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(a, 12), (b, 12)])); expected.public_inputs.insert( ident!(stack_inputs), @@ -404,7 +383,7 @@ fn 
fn_call_in_binary_ops() { enf b[0]' = b[0] * operation(a, b); }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(root)); expected.functions.insert( ident!(operation), @@ -424,7 +403,7 @@ fn fn_call_in_binary_ops() { expected .trace_columns - .push(trace_segment!(0, "$main", [(a, 12), (b, 12)])); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(a, 12), (b, 12)])); expected.public_inputs.insert( ident!(stack_inputs), @@ -433,10 +412,7 @@ fn fn_call_in_binary_ops() { expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(a[0], Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(a[0], Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( @@ -444,17 +420,11 @@ fn fn_call_in_binary_ops() { vec![ enforce!(eq!( access!(a[0], 1), - mul!( - access!(a[0], 0), - call!(operation(expr!(access!(a)), expr!(access!(b)))) - ) + mul!(access!(a[0], 0), call!(operation(expr!(access!(a)), expr!(access!(b))))) )), enforce!(eq!( access!(b[0], 1), - mul!( - access!(b[0], 0), - call!(operation(expr!(access!(a)), expr!(access!(b)))) - ) + mul!(access!(b[0], 0), call!(operation(expr!(access!(a)), expr!(access!(b))))) )), ], )); @@ -489,7 +459,7 @@ fn fn_call_in_vector_def() { enf b[0]' = d[1]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(root)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(root)); expected.functions.insert( ident!(operation), @@ -498,18 +468,16 @@ fn fn_call_in_vector_def() { function_ident!(operation), vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], Type::Vector(12), - vec![return_!(expr!( - lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => add!( - access!(x), - access!(y) - )) - ))], + vec![return_!(expr!(lc!(((x, expr!(access!(a))), (y, 
expr!(access!(b)))) => add!( + access!(x), + access!(y) + ))))], ), ); expected .trace_columns - .push(trace_segment!(0, "$main", [(a, 12), (b, 12)])); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(a, 12), (b, 12)])); expected.public_inputs.insert( ident!(stack_inputs), @@ -518,10 +486,7 @@ fn fn_call_in_vector_def() { expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(a[0], Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(a[0], Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( diff --git a/parser/src/parser/tests/inlining.rs b/parser/src/parser/tests/inlining.rs deleted file mode 100644 index 370ce83c0..000000000 --- a/parser/src/parser/tests/inlining.rs +++ /dev/null @@ -1,1280 +0,0 @@ -use air_pass::Pass; -use miden_diagnostics::SourceSpan; - -use pretty_assertions::assert_eq; - -use crate::{ - ast::*, - transforms::{ConstantPropagation, Inlining}, -}; - -use super::ParseTest; - -/// This test inlines an evaluator function into the root -/// integrity constraints. The evaluator is called with a -/// single trace column binding representing two columns -/// in the main trace, which is split into two individual -/// bindings via the evaluator function signature. 
-/// -/// It is expected that the resulting evaluator function -/// body will have its references to those parameters rewritten -/// to refer to the input binding, but with appropriate accesses -/// inserted to match the semantics of the function signature -#[test] -fn test_inlining_with_evaluator_split_input_binding() { - let root = r#" - def root - - use lib::*; - - trace_columns { - main: [clk, a, b[2], c], - } - - public_inputs { - inputs: [0], - } - - const A = [2, 4, 6, 8]; - const B = [[1, 1], [2, 2]]; - - integrity_constraints { - enf test_constraint(b); - let x = 2^EXP; - let y = A[0..2]; - enf a + y[1] = c + (x + 1); - } - - boundary_constraints { - let x = B[0]; - enf a.first = x[0]; - } - - "#; - let lib = r#" - mod lib - - const EXP = 2; - - ev test_constraint([b0, b1]) { - let x = EXP; - let y = 2^x; - enf b0 + x = b1 + y; - }"#; - - let test = ParseTest::new(); - let path = std::env::current_dir().unwrap().join("lib.air"); - test.add_virtual_file(path, lib.to_string()); - - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 2), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - expected - .constants - .insert(ident!(root, A), constant!(A = [2, 4, 6, 8])); - expected - .constants - .insert(ident!(root, B), constant!(B = [[1, 1], [2, 2]])); - expected - .constants - .insert(ident!(lib, EXP), constant!(EXP = 2)); - // When constant propagation and inlining is done, the boundary constraints should look like: - // enf a.first = 1 - 
expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(a, Boundary::First, Type::Felt), - int!(1) - ))); - // When constant propagation and inlining is done, the integrity constraints should look like: - // enf b[0] + 2 = b[1] + 4 - // enf a + 4 = c + 5 - expected.integrity_constraints.push(enforce!(eq!( - add!(access!(b[0], Type::Felt), int!(2)), - add!(access!(b[1], Type::Felt), int!(4)) - ))); - expected.integrity_constraints.push(enforce!(eq!( - add!(access!(a, Type::Felt), int!(4)), - add!(access!(c, Type::Felt), int!(5)) - ))); - // The test_constraint function before inlining should look like: - // enf b0 + 2 = b1 + 4 - let body = vec![enforce!(eq!( - add!(access!(b0, Type::Felt), int!(2)), - add!(access!(b1, Type::Felt), int!(4)) - ))]; - expected.evaluators.insert( - function_ident!(lib, test_constraint), - EvaluatorFunction::new( - SourceSpan::UNKNOWN, - ident!(test_constraint), - vec![trace_segment!(0, "%0", [(b0, 1), (b1, 1)])], - body, - ), - ); - - assert_eq!(program, expected); -} - -/// This test inlines an evaluator function into the root -/// integrity constraints. The evaluator is called with two -/// disjoint bindings representing three columns from the main -/// trace, packed using a vector literal. The evaluator function -/// then destructures that vector into a set of two bindings which -/// recombines the input columns into different groupings, and then -/// expresses a constraint using accesses into those groups. -/// -/// It is expected that the resulting evaluator function -/// body will have its references to those parameters rewritten -/// to accesses relative to the input bindings, or to direct accesses -/// to the corresponding trace segment declaration. 
-#[test] -fn test_inlining_with_vector_literal_binding_regrouped() { - let root = r#" - def root - - use lib::*; - - trace_columns { - main: [clk, a, b[2], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - enf test_constraint([clk, b]); - } - - boundary_constraints { - enf clk.first = 0; - } - - "#; - let lib = r#" - mod lib - - ev test_constraint([pair[2], b1]) { - enf pair[0] + pair[1] = b1; - }"#; - - let test = ParseTest::new(); - let path = std::env::current_dir().unwrap().join("lib.air"); - test.add_virtual_file(path, lib.to_string()); - - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 2), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant propagation and inlining is done, integrity_constraints should look like: - // enf clk + b[0] = b[1] - expected.integrity_constraints.push(enforce!(eq!( - add!(access!(clk, Type::Felt), access!(b[0], Type::Felt)), - access!(b[1], Type::Felt) - ))); - // The test_constraint function before inlining should look like: - // enf pair[0] + pair[1] = b1 - let body = vec![enforce!(eq!( - add!(access!(pair[0], Type::Felt), access!(pair[1], Type::Felt)), - access!(b1, Type::Felt) - ))]; - expected.evaluators.insert( - function_ident!(lib, test_constraint), - EvaluatorFunction::new( - 
SourceSpan::UNKNOWN, - ident!(test_constraint), - vec![trace_segment!(0, "%0", [(pair, 2), (b1, 1)])], - body, - ), - ); - - assert_eq!(program, expected); -} - -/// This test checks that there are no assumptions about the ordering of -/// arguments to an evaluator, i.e. there is no assumption that two consecutive -/// columns necessarily appear in that order in the trace_columns declaration -#[test] -fn test_inlining_with_vector_literal_binding_unordered() { - let root = r#" - def root - - use lib::*; - - trace_columns { - main: [clk, a, b[2], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - enf test_constraint([b, clk]); - } - - boundary_constraints { - enf clk.first = 0; - } - "#; - let lib = r#" - mod lib - - ev test_constraint([b0, pair[2]]) { - enf pair[1] + b0 = pair[0]; - }"#; - - let test = ParseTest::new(); - let path = std::env::current_dir().unwrap().join("lib.air"); - test.add_virtual_file(path, lib.to_string()); - - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 2), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant propagation and inlining is done, integrity_constraints should look like: - // enf clk + b[0] = b[1] - expected.integrity_constraints.push(enforce!(eq!( - add!(access!(clk, Type::Felt), access!(b[0], 
Type::Felt)), - access!(b[1], Type::Felt) - ))); - // The test_constraint function before inlining should look like: - // enf pair[1] + b0 = pair[0] - let body = vec![enforce!(eq!( - add!(access!(pair[1], Type::Felt), access!(b0, Type::Felt)), - access!(pair[0], Type::Felt) - ))]; - expected.evaluators.insert( - function_ident!(lib, test_constraint), - EvaluatorFunction::new( - SourceSpan::UNKNOWN, - ident!(test_constraint), - vec![trace_segment!(0, "%0", [(b0, 1), (pair, 2)])], - body, - ), - ); - - assert_eq!(program, expected); -} - -/// This test checks the behavior when there are not only disjoint args/params -/// in a call to an evaluator, but that the number of arguments and parameters -/// is different, with more input arguments than parameter bindings. -#[test] -fn test_inlining_with_vector_literal_binding_different_arity_many_to_few() { - let root = r#" - def root - - use lib::*; - - trace_columns { - main: [clk, a, b[2], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - enf test_constraint([clk, b, a]); - } - - boundary_constraints { - enf clk.first = 0; - } - "#; - let lib = r#" - mod lib - - ev test_constraint([pair[3], foo]) { - enf pair[0] + pair[1] = foo + pair[2]; - }"#; - - let test = ParseTest::new(); - let path = std::env::current_dir().unwrap().join("lib.air"); - test.add_virtual_file(path, lib.to_string()); - - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 2), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 
0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant propagation and inlining is done, integrity_constraints should look like: - // enf clk + b[0] = a + b[1] - expected.integrity_constraints.push(enforce!(eq!( - add!(access!(clk, Type::Felt), access!(b[0], Type::Felt)), - add!(access!(a, Type::Felt), access!(b[1], Type::Felt)) - ))); - // The test_constraint function before inlining should look like: - // enf pair[0] + pair[1] = a + pair[2] - let body = vec![enforce!(eq!( - add!(access!(pair[0], Type::Felt), access!(pair[1], Type::Felt)), - add!(access!(foo, Type::Felt), access!(pair[2], Type::Felt)) - ))]; - expected.evaluators.insert( - function_ident!(lib, test_constraint), - EvaluatorFunction::new( - SourceSpan::UNKNOWN, - ident!(test_constraint), - vec![trace_segment!(0, "%0", [(pair, 3), (foo, 1)])], - body, - ), - ); - - assert_eq!(program, expected); -} - -/// This test checks the behavior when there are not only disjoint args/params -/// in a call to an evaluator, but that the number of arguments and parameters -/// is different, with more parameter bindings than input arguments. 
-#[test] -fn test_inlining_with_vector_literal_binding_different_arity_few_to_many() { - let root = r#" - def root - - use lib::*; - - trace_columns { - main: [clk, a, b[2], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - enf test_constraint([b, a]); - } - - boundary_constraints { - enf clk.first = 0; - } - "#; - let lib = r#" - mod lib - - ev test_constraint([x, y, z]) { - enf x + y = z; - }"#; - - let test = ParseTest::new(); - let path = std::env::current_dir().unwrap().join("lib.air"); - test.add_virtual_file(path, lib.to_string()); - - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 2), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant propagation and inlining is done, integrity_constraints should look like: - // enf b[0] + b[1] = a - expected.integrity_constraints.push(enforce!(eq!( - add!(access!(b[0], Type::Felt), access!(b[1], Type::Felt)), - access!(a, Type::Felt) - ))); - // The test_constraint function before inlining should look like: - // enf x + y = z - let body = vec![enforce!(eq!( - add!(access!(x, Type::Felt), access!(y, Type::Felt)), - access!(z, Type::Felt) - ))]; - expected.evaluators.insert( - function_ident!(lib, test_constraint), - EvaluatorFunction::new( - SourceSpan::UNKNOWN, - ident!(test_constraint), - 
vec![trace_segment!(0, "%0", [(x, 1), (y, 1), (z, 1)])], - body, - ), - ); - - assert_eq!(program, expected); -} - -/// This test checks the behavior when inlining across multiple modules with -/// nested calls to evaluators, with a mix of parameter/argument binding configurations -#[test] -fn test_inlining_across_modules_with_nested_evaluators_variant1() { - let root = r#" - def root - - use lib1::test_constraint; - - trace_columns { - main: [clk, a, b[2], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - enf test_constraint([clk, b, a]); - } - - boundary_constraints { - enf clk.first = 0; - } - "#; - let lib1 = r#" - mod lib1 - - use lib2::*; - - ev test_constraint([tuple[3], z]) { - enf helper_constraint([z, tuple[1..3]]); - }"#; - let lib2 = r#" - mod lib2 - - ev helper_constraint([x[2], y]) { - enf x[0] + x[1] = y; - }"#; - - let test = ParseTest::new(); - let path = std::env::current_dir().unwrap().join("lib1.air"); - test.add_virtual_file(path, lib1.to_string()); - let path = std::env::current_dir().unwrap().join("lib2.air"); - test.add_virtual_file(path, lib2.to_string()); - - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 2), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant propagation and inlining is done, 
integrity_constraints should look like: - // enf a + b[0] = b[1] - expected.integrity_constraints.push(enforce!(eq!( - add!(access!(a, Type::Felt), access!(b[0], Type::Felt)), - access!(b[1], Type::Felt) - ))); - // The test_constraint function before inlining should look like: - // enf helper_constraint([z, tuple[1..3]]) - let body = vec![enforce!(call!(lib2::helper_constraint(vector!( - access!(z, Type::Felt), - slice!(tuple, 1..3, Type::Vector(2)) - ))))]; - expected.evaluators.insert( - function_ident!(lib1, test_constraint), - EvaluatorFunction::new( - SourceSpan::UNKNOWN, - ident!(test_constraint), - vec![trace_segment!(0, "%0", [(tuple, 3), (z, 1)])], - body, - ), - ); - // The helper_constraint function before inlining should look like: - // enf x[0] + x[1] = y - let body = vec![enforce!(eq!( - add!(access!(x[0], Type::Felt), access!(x[1], Type::Felt)), - access!(y, Type::Felt) - ))]; - expected.evaluators.insert( - function_ident!(lib2, helper_constraint), - EvaluatorFunction::new( - SourceSpan::UNKNOWN, - ident!(helper_constraint), - vec![trace_segment!(0, "%0", [(x, 2), (y, 1)])], - body, - ), - ); - - assert_eq!(program, expected); -} - -/// This test is like *_variant1, but with a different mix of parameter/argument configurations -#[test] -fn test_inlining_across_modules_with_nested_evaluators_variant2() { - let root = r#" - def root - - use lib1::test_constraint; - - trace_columns { - main: [clk, a, b[2], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - enf test_constraint([clk, b[0..2], a]); - } - - boundary_constraints { - enf clk.first = 0; - } - "#; - let lib1 = r#" - mod lib1 - - use lib2::*; - - ev test_constraint([tuple[3], z]) { - enf helper_constraint([z, tuple[1], tuple[2..3]]); - }"#; - let lib2 = r#" - mod lib2 - - ev helper_constraint([x[2], y]) { - enf x[0] + x[1] = y; - }"#; - - let test = ParseTest::new(); - let path = std::env::current_dir().unwrap().join("lib1.air"); - test.add_virtual_file(path, 
lib1.to_string()); - let path = std::env::current_dir().unwrap().join("lib2.air"); - test.add_virtual_file(path, lib2.to_string()); - - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 2), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant propagation and inlining is done, integrity_constraints should look like: - // enf a + b[0] = b[1] - expected.integrity_constraints.push(enforce!(eq!( - add!(access!(a, Type::Felt), access!(b[0], Type::Felt)), - access!(b[1], Type::Felt) - ))); - // The test_constraint function before inlining should look like: - // enf helper_constraint([z, tuple[1..3]]) - let body = vec![enforce!(call!(lib2::helper_constraint(vector!( - access!(z, Type::Felt), - access!(tuple[1], Type::Felt), - slice!(tuple, 2..3, Type::Vector(1)) - ))))]; - expected.evaluators.insert( - function_ident!(lib1, test_constraint), - EvaluatorFunction::new( - SourceSpan::UNKNOWN, - ident!(test_constraint), - vec![trace_segment!(0, "%0", [(tuple, 3), (z, 1)])], - body, - ), - ); - // The helper_constraint function before inlining should look like: - // enf x[0] + x[1] = y - let body = vec![enforce!(eq!( - add!(access!(x[0], Type::Felt), access!(x[1], Type::Felt)), - access!(y, Type::Felt) - ))]; - expected.evaluators.insert( - function_ident!(lib2, 
helper_constraint), - EvaluatorFunction::new( - SourceSpan::UNKNOWN, - ident!(helper_constraint), - vec![trace_segment!(0, "%0", [(x, 2), (y, 1)])], - body, - ), - ); - - assert_eq!(program, expected); -} - -/// This test verifies that constraint comprehensions (without a selector) are unrolled properly during inlining -/// -/// In this variant, we do not involve other modules to keep the test focused on just the -/// comprehension unrolling behavior. Other tests will expand on this to test it when combined -/// with other inlining behavior. -#[test] -fn test_inlining_constraint_comprehensions_no_selector() { - let root = r#" - def root - - const YS = [2, 4, 6, 8]; - - trace_columns { - main: [clk, a, b[2], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - # We're expecting this to expand to: - # - # enf b[0]' = 2 - # enf b[1]' = 4 - # - enf x' = y for (x, y) in (b, YS[0..2]); - } - - boundary_constraints { - enf clk.first = 0; - } - "#; - - let test = ParseTest::new(); - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected - .constants - .insert(ident!(root, YS), constant!(YS = [2, 4, 6, 8])); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 2), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant propagation and inlining is done, integrity_constraints should look like: - // enf 
b[0]' = 2 - // enf b[1]' = 4 - expected - .integrity_constraints - .push(enforce!(eq!(access!(b[0], 1, Type::Felt), int!(2)))); - expected - .integrity_constraints - .push(enforce!(eq!(access!(b[1], 1, Type::Felt), int!(4)))); - - assert_eq!(program, expected); -} - -/// This test verifies that constraint comprehensions (with a selector) are unrolled properly during inlining -/// -/// In this variant, we do not involve other modules to keep the test focused on just the -/// comprehension unrolling behavior. Other tests will expand on this to test it when combined -/// with other inlining behavior. -#[test] -fn test_inlining_constraint_comprehensions_with_selector() { - let root = r#" - def root - - const YS = [2, 4, 6, 8]; - - trace_columns { - main: [clk, a, b[2], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - # We're expecting this to expand to: - # - # enf b[0]' = 2 when c - # enf b[1]' = 4 when c - # - enf x' = y for (x, y) in (b, YS[0..2]) when c; - } - - boundary_constraints { - enf clk.first = 0; - } - "#; - - let test = ParseTest::new(); - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected - .constants - .insert(ident!(root, YS), constant!(YS = [2, 4, 6, 8])); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 2), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant 
propagation and inlining is done, integrity_constraints should look like: - // enf b[0]' = 2 when c - // enf b[1]' = 4 when c - expected - .integrity_constraints - .push(enforce!(eq!(access!(b[0], 1, Type::Felt), int!(2)), when access!(c, Type::Felt))); - expected - .integrity_constraints - .push(enforce!(eq!(access!(b[1], 1, Type::Felt), int!(4)), when access!(c, Type::Felt))); - - assert_eq!(program, expected); -} - -/// This test verifies that constraint comprehensions (with a selector) are unrolled properly during inlining. -/// Specifically, in this case we expect that because the selector is constant, only constraints for which the -/// selector is "truthy" (i.e. non-zero) remain, and that the selector has been elided. -/// -/// In this variant, we do not involve other modules to keep the test focused on just the -/// comprehension unrolling behavior. Other tests will expand on this to test it when combined -/// with other inlining behavior. -#[test] -fn test_inlining_constraint_comprehensions_with_constant_selector() { - let root = r#" - def root - - const YS = [0, 4, 0, 8]; - - trace_columns { - main: [clk, a, b[4], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - # We're expecting this to expand to: - # - # enf b[1]' = 4 - # enf b[3]' = 8 - # - enf x' = y for (x, y) in (b, YS) when y; - } - - boundary_constraints { - enf clk.first = 0; - } - "#; - - let test = ParseTest::new(); - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected - .constants - .insert(ident!(root, YS), constant!(YS = [0, 4, 0, 8])); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), 
(a, 1), (b, 4), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant propagation and inlining is done, integrity_constraints should look like: - // enf b[1]' = 4 - // enf b[3]' = 8 - expected - .integrity_constraints - .push(enforce!(eq!(access!(b[1], 1, Type::Felt), int!(4)))); - expected - .integrity_constraints - .push(enforce!(eq!(access!(b[3], 1, Type::Felt), int!(8)))); - - assert_eq!(program, expected); -} - -/// This test verifies that constraint comprehensions present in evaluators are inlined into call sites correctly -/// -/// This test exercises multiple edge cases in constant propagation/inlining in conjunction to make sure that all -/// of the pieces integrate together even in odd scenarios -#[test] -fn test_inlining_constraint_comprehensions_in_evaluator() { - let root = r#" - def root - - const YS = [0, 4, 0, 8]; - - trace_columns { - main: [clk, a, b[4], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - enf test_constraint(b[1..4]); - } - - boundary_constraints { - enf clk.first = 0; - } - - ev test_constraint([i, j[2]]) { - let ys = [x^2 for x in YS]; - let k = j[0]; - let l = j[1]; - let xs = [i, k, l]; - enf x' = y for (x, y) in (xs, ys[1..4]) when y; - }"#; - - let test = ParseTest::new(); - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected - .constants - .insert(ident!(root, YS), constant!(YS 
= [0, 4, 0, 8])); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 4), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant propagation and inlining is done, integrity_constraints should look like: - // enf b[1]' = 16 - // enf b[3]' = 64 - expected - .integrity_constraints - .push(enforce!(eq!(access!(b[1], 1, Type::Felt), int!(16)))); - expected - .integrity_constraints - .push(enforce!(eq!(access!(b[3], 1, Type::Felt), int!(64)))); - // The evaluator definition is never modified by inlining, but is by constant propagation: - // - // ev test_constraint([i, j[2]]) { - // let k = j[0] - // let l = j[1] - // let xs = [i, k, l] - // enf x' = y for (x, y) in (xs, [16, 0, 64]) when y - // } - let body = vec![let_!(k = expr!(access!(j[0], Type::Felt)) - => let_!(l = expr!(access!(j[1], Type::Felt)) - => let_!(xs = vector!(access!(i, Type::Felt), access!(k, Type::Felt), access!(l, Type::Felt)) - => enforce_all!(lc!(((x, expr!(access!(xs, Type::Vector(3)))), (y, vector!(16, 0, 64))) - => eq!(access!(x, 1, Type::Felt), access!(y, Type::Felt)), when access!(y, Type::Felt))))) - )]; - expected.evaluators.insert( - function_ident!(root, test_constraint), - EvaluatorFunction::new( - SourceSpan::UNKNOWN, - ident!(test_constraint), - vec![trace_segment!(0, "%0", [(i, 1), (j, 2)])], - body, - ), - ); - - assert_eq!(program, expected); -} - -/// This test verifies that constraints involving let-bound, folded comprehensions behave as expected -#[test] -fn test_inlining_constraints_with_folded_comprehensions_in_evaluator() { - let root = r#" - def root - - trace_columns { - main: [clk, a, b[4], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - enf 
test_constraint(b[1..4]); - } - - boundary_constraints { - enf clk.first = 0; - } - - ev test_constraint([x, ys[2]]) { - let y = sum([col^7 for col in ys]); - let z = prod([col^7 for col in ys]); - enf x = y + z; - }"#; - - let test = ParseTest::new(); - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 4), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // When constant propagation and inlining is done, integrity_constraints should look like: - // let y = - // let %lc0 = b[2]^7 - // let %lc1 = b[3]^7 - // %lc0 + %lc1 - // in - // let z = - // let %lc2 = b[2]^7 - // let %lc3 = b[3]^7 - // %lc2 * %lc3 - // in - // enf b[1] = y + z - expected - .integrity_constraints - .push(let_!(y = expr!( - let_!("%lc0" = expr!(exp!(access!(b[2], Type::Felt), int!(7))) - => let_!("%lc1" = expr!(exp!(access!(b[3], Type::Felt), int!(7))) - => statement!(add!(access!("%lc0", Type::Felt), access!("%lc1", Type::Felt))))) - ) => - let_!(z = expr!( - let_!("%lc2" = expr!(exp!(access!(b[2], Type::Felt), int!(7))) - => let_!("%lc3" = expr!(exp!(access!(b[3], Type::Felt), int!(7))) - => statement!(mul!(access!("%lc2", Type::Felt), access!("%lc3", Type::Felt))))) - ) => - enforce!(eq!(access!(b[1], Type::Felt), add!(access!(y, Type::Felt), access!(z, Type::Felt)))) - ) - )); - // The evaluator definition 
is never modified by constant propagation or inlining - let body = vec![ - let_!(y = expr!(call!(sum(expr!(lc!(((col, expr!(access!(ys, Type::Vector(2))))) => exp!(access!(col, Type::Felt), int!(7))))))) - => let_!(z = expr!(call!(prod(expr!(lc!(((col, expr!(access!(ys, Type::Vector(2))))) => exp!(access!(col, Type::Felt), int!(7))))))) - => enforce!(eq!(access!(x, Type::Felt), add!(access!(y, Type::Felt), access!(z, Type::Felt)))))), - ]; - expected.evaluators.insert( - function_ident!(root, test_constraint), - EvaluatorFunction::new( - SourceSpan::UNKNOWN, - ident!(test_constraint), - vec![trace_segment!(0, "%0", [(x, 1), (ys, 2)])], - body, - ), - ); - - assert_eq!(program, expected); -} - -#[test] -fn test_inlining_with_function_call_as_binary_operand() { - let root = r#" - def root - - trace_columns { - main: [clk, a, b[4], c], - } - - public_inputs { - inputs: [0], - } - - integrity_constraints { - let complex_fold = fold_sum(b) * fold_vec(b); - enf complex_fold = 1; - } - - boundary_constraints { - enf clk.first = 0; - } - - fn fold_sum(a: felt[4]) -> felt { - return a[0] + a[1] + a[2] + a[3]; - } - - fn fold_vec(a: felt[4]) -> felt { - let m = a[0] * a[1]; - let n = m * a[2]; - let o = n * a[3]; - return o; - } - "#; - - let test = ParseTest::new(); - let program = match test.parse_program(root) { - Err(err) => { - test.diagnostics.emit(err); - panic!("expected parsing to succeed, see diagnostics for details"); - } - Ok(ast) => ast, - }; - - let mut pipeline = - ConstantPropagation::new(&test.diagnostics).chain(Inlining::new(&test.diagnostics)); - let program = pipeline.run(program).unwrap(); - - let mut expected = Program::new(ident!(root)); - expected.trace_columns.push(trace_segment!( - 0, - "$main", - [(clk, 1), (a, 1), (b, 4), (c, 1)] - )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 0), - ); - expected.functions.insert( - function_ident!(root, fold_sum), - Function::new( - 
SourceSpan::UNKNOWN, - ident!(fold_sum), - vec![(ident!(a), Type::Vector(4))], - Type::Felt, - vec![return_!(expr!(add!( - add!( - add!(access!(a[0], Type::Felt), access!(a[1], Type::Felt)), - access!(a[2], Type::Felt) - ), - access!(a[3], Type::Felt) - )))], - ), - ); - expected.functions.insert( - function_ident!(root, fold_vec), - Function::new( - SourceSpan::UNKNOWN, - ident!(fold_vec), - vec![(ident!(a), Type::Vector(4))], - Type::Felt, - vec![ - let_!("m" = expr!(mul!(access!(a[0], Type::Felt), access!(a[1], Type::Felt))) - => let_!("n" = expr!(mul!(access!(m, Type::Felt), access!(a[2], Type::Felt))) - => let_!("o" = expr!(mul!(access!(n, Type::Felt), access!(a[3], Type::Felt))) - => return_!(expr!(access!(o, Type::Felt))) - ))), - ], - ), - ); - // The sole boundary constraint is already minimal - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); - // With constant propagation and inlining done - // - // let complex_fold = - // (b[0] + b[1] + b[2] + b[3]) * - // (let m = b[0] * b[1] - // let n = m * b[2] - // let o = n * b[3] in o) - // enf complex_fold = 1 - expected.integrity_constraints.push( - let_!(complex_fold = expr!(mul!( - add!(add!(add!(access!(b[0], Type::Felt), access!(b[1], Type::Felt)), access!(b[2], Type::Felt)), access!(b[3], Type::Felt)), - scalar!(let_!(m = expr!(mul!(access!(b[0], Type::Felt), access!(b[1], Type::Felt))) - => let_!(n = expr!(mul!(access!(m, Type::Felt), access!(b[2], Type::Felt))) - => let_!(o = expr!(mul!(access!(n, Type::Felt), access!(b[3], Type::Felt))) => return_!(expr!(access!(o, Type::Felt))))))) - )) => enforce!(eq!(access!(complex_fold, Type::Felt), int!(1)))) - ); - - assert_eq!(program, expected); -} diff --git a/parser/src/parser/tests/input/import_example.air b/parser/src/parser/tests/input/import_example.air index f63e58055..ca88f43f4 100644 --- a/parser/src/parser/tests/input/import_example.air +++ 
b/parser/src/parser/tests/input/import_example.air @@ -6,6 +6,9 @@ use foo::*; # Import just `bar_constraint` from bar use bar::bar_constraint; +# Import just `are_all_binary` from utils/binary +use utils::binary::are_all_binary; + trace_columns { main: [clk, fmp, ctx], } @@ -17,6 +20,8 @@ public_inputs { integrity_constraints { enf foo_constraint([clk]); enf bar_constraint([clk]); + enf are_all_binary([clk, fmp, ctx]); + } boundary_constraints { diff --git a/parser/src/parser/tests/input/utils/binary.air b/parser/src/parser/tests/input/utils/binary.air new file mode 100644 index 000000000..0e3470550 --- /dev/null +++ b/parser/src/parser/tests/input/utils/binary.air @@ -0,0 +1,10 @@ + +mod binary + +ev is_binary([x]) { + enf x^2 = x; +} + +ev are_all_binary([c[3]]) { + enf is_binary([c]) for c in c; +} diff --git a/parser/src/parser/tests/integrity_constraints.rs b/parser/src/parser/tests/integrity_constraints.rs index 43f5d4662..e84eb2eef 100644 --- a/parser/src/parser/tests/integrity_constraints.rs +++ b/parser/src/parser/tests/integrity_constraints.rs @@ -1,8 +1,7 @@ use miden_diagnostics::{SourceSpan, Span}; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; // INTEGRITY STATEMENTS // ================================================================================================ @@ -28,20 +27,16 @@ fn integrity_constraints() { enf clk' = clk + 1; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1)])); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = 
Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, @@ -81,14 +76,13 @@ fn integrity_constraints_with_buses() { q.remove(1, 2) with 2; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1)])); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![ @@ -96,14 +90,12 @@ fn integrity_constraints_with_buses() { enforce!(eq!(bounded_access!(q, Boundary::Last), null!())), ], )); - expected.buses.insert( - ident!(p), - Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Multiset), - ); - expected.buses.insert( - ident!(q), - Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Logup), - ); + expected + .buses + .insert(ident!(p), Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Multiset)); + expected + .buses + .insert(ident!(q), Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Logup)); let mut bus_enforces = Vec::new(); @@ -188,20 +180,16 @@ fn multiple_integrity_constraints() { enf clk' - clk = 1; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, 
ident!(inputs), 2), - ); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1)])); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, @@ -238,24 +226,19 @@ fn integrity_constraint_with_periodic_col() { enf k0 + b = 0; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", [(b, 1)])); - expected.periodic_columns.insert( - ident!(k0), - PeriodicColumn::new(SourceSpan::UNKNOWN, ident!(k0), vec![1, 0]), - ); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(b, 1)])); + expected + .periodic_columns + .insert(ident!(k0), PeriodicColumn::new(SourceSpan::UNKNOWN, ident!(k0), vec![1, 0])); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, @@ -289,25 +272,19 @@ fn integrity_constraint_with_constants() { enf clk + A = B[1] + C[1][1]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - 
.push(trace_segment!(0, "$main", [(clk, 1)])); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1)])); expected.constants.insert(ident!(A), constant!(A = 0)); expected.constants.insert(ident!(B), constant!(B = [0, 1])); + expected.constants.insert(ident!(C), constant!(C = [[0, 1], [1, 0]])); expected - .constants - .insert(ident!(C), constant!(C = [[0, 1], [1, 0]])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, @@ -343,20 +320,16 @@ fn integrity_constraint_with_variables() { enf clk + a = b[1] + c[1][1]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1)])); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, @@ -389,24 +362,20 @@ fn integrity_constraint_with_indexed_trace_access() { enf $main[0]' = $main[1] + 1; }"; - let mut expected = 
Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(a, 1), (b, 1)])); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - access!("$main"[0], 1), - add!(access!("$main"[1]), int!(1)) - ))], + vec![enforce!(eq!(access!("$main"[0], 1), add!(access!("$main"[1]), int!(1))))], )); ParseTest::new().expect_module_ast(source, expected); } @@ -435,20 +404,18 @@ fn ic_comprehension_one_iterable_identifier() { enf x = a + b for x in c; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( 
SourceSpan::UNKNOWN, @@ -480,20 +447,18 @@ fn ic_comprehension_one_iterable_range() { enf x = a + b for x in (1..4); }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, @@ -525,20 +490,18 @@ fn ic_comprehension_with_selectors() { enf x = a + b for x in c when s[0] & s[1]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(s, 2), (a, 1), (b, 1), (c, 4)] + )); expected - .trace_columns - .push(trace_segment!(0, "$main", [(s, 2), (a, 1), (b, 1), (c, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); 
expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, @@ -574,29 +537,27 @@ fn ic_comprehension_with_evaluator_call() { enf is_binary([x]) for x in c; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); - expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4), (d, 4)])); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4), (d, 4)] + )); expected.evaluators.insert( ident!(is_binary), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(is_binary), - vec![trace_segment!(0, "%0", [(x, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(x, 1)])], vec![enforce!(eq!(exp!(access!(x), int!(2)), access!(x)))], ), ); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, @@ -632,9 +593,9 @@ fn ic_comprehension_with_evaluator_and_selectors() { enf is_binary([x]) for x in c when s[0] & s[1]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected.trace_columns.push(trace_segment!( - 0, + TraceSegmentId::Main, "$main", [(s, 2), (a, 1), (b, 1), (c, 4), (d, 4)] )); @@ -643,20 +604,16 @@ fn ic_comprehension_with_evaluator_and_selectors() { EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(is_binary), - vec![trace_segment!(0, "%0", [(x, 
1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(x, 1)])], vec![enforce!(eq!(exp!(access!(x), int!(2)), access!(x)))], ), ); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, @@ -695,9 +652,9 @@ fn ic_match_constraint() { }; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected.trace_columns.push(trace_segment!( - 0, + TraceSegmentId::Main, "$main", [(s, 2), (a, 1), (b, 1), (c, 4), (d, 4)] )); @@ -706,31 +663,26 @@ fn ic_match_constraint() { EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(is_binary), - vec![trace_segment!(0, "%0", [(x, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(x, 1)])], vec![enforce!(eq!(exp!(access!(x), int!(2)), access!(x)))], ), ); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![ - enforce_all!( - lc!((("%1", range!(0..1))) => call!(is_binary(vector!(access!(c[0])))), when and!(access!(s[0]), access!(s[1]))) - ), - 
enforce_all!( - lc!((("%2", range!(0..1))) => eq!(access!(c[1]), access!(c[2])), when access!(s[0])) + vec![enforce_if!( + match_arm!( + call!(is_binary(vector!(access!(c[0])))), + and!(access!(s[0]), access!(s[1])) ), - ], + match_arm!(eq!(access!(c[1]), access!(c[2])), access!(s[0])) + )], )); ParseTest::new().expect_module_ast(source, expected); } diff --git a/parser/src/parser/tests/list_comprehension.rs b/parser/src/parser/tests/list_comprehension.rs index 55a5e46aa..f94c074de 100644 --- a/parser/src/parser/tests/list_comprehension.rs +++ b/parser/src/parser/tests/list_comprehension.rs @@ -1,8 +1,7 @@ use miden_diagnostics::{SourceSpan, Span}; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; // LIST COMPREHENSION // ================================================================================================ @@ -31,18 +30,17 @@ fn bc_one_iterable_identifier_lc() { enf a.first = x[0] + x[1] + x[2] + x[3]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); - expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); - expected.integrity_constraints = Some(Span::new( - SourceSpan::UNKNOWN, - vec![enforce!(eq!(access!(a), int!(0)))], + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.integrity_constraints = + Some(Span::new(SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(a), int!(0)))])); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![ @@ -78,21 +76,18 @@ fn bc_identifier_and_range_lc() { enf a.first = x[0] + x[1] + x[2] + x[3]; }"; - let mut expected 
= Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); - expected - .constants - .insert(ident!(THREE), constant!(THREE = 3)); - expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); - expected.integrity_constraints = Some(Span::new( - SourceSpan::UNKNOWN, - vec![enforce!(eq!(access!(a), int!(0)))], + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.constants.insert(ident!(THREE), constant!(THREE = 3)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.integrity_constraints = + Some(Span::new(SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(a), int!(0)))])); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![ @@ -126,24 +121,21 @@ fn bc_iterable_slice_lc() { enf a.first = x[0] + x[1] + x[2] + x[3]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); - expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); - expected.integrity_constraints = Some(Span::new( - SourceSpan::UNKNOWN, - vec![enforce!(eq!(access!(a), int!(0)))], + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.integrity_constraints = + Some(Span::new(SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(a), 
int!(0)))])); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![ - let_!(x = lc!(((c, expr!(slice!(c, 0..3)))) => access!(c)).into() => - enforce!(eq!(bounded_access!(a, Boundary::First), add!(add!(add!(access!(x[0]), access!(x[1])), access!(x[2])), access!(x[3]))))), - ], + vec![let_!(x = lc!(((c, expr!(slice!(c, 0..3)))) => access!(c)).into() => + enforce!(eq!(bounded_access!(a, Boundary::First), add!(add!(add!(access!(x[0]), access!(x[1])), access!(x[2])), access!(x[3])))))], )); ParseTest::new().expect_module_ast(source, expected); @@ -171,18 +163,17 @@ fn bc_two_iterable_identifier_lc() { enf a.first = x[0] + x[1] + x[2] + x[3]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); - expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4), (d, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); - expected.integrity_constraints = Some(Span::new( - SourceSpan::UNKNOWN, - vec![enforce!(eq!(access!(a), int!(0)))], + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4), (d, 4)] )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.integrity_constraints = + Some(Span::new(SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(a), int!(0)))])); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![ @@ -216,18 +207,17 @@ fn bc_multiple_iterables_lc() { enf a.first = x[0] + x[1] + x[2] + x[3]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); - expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 3), (c, 4), (d, 4)])); - expected.public_inputs.insert( - ident!(inputs), - 
PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); - expected.integrity_constraints = Some(Span::new( - SourceSpan::UNKNOWN, - vec![enforce!(eq!(access!(a), int!(0)))], + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 3), (c, 4), (d, 4)] )); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); + expected.integrity_constraints = + Some(Span::new(SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(a), int!(0)))])); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![ @@ -266,14 +256,15 @@ fn ic_one_iterable_identifier_lc() { enf a = x[0] + x[1] + x[2] + x[3]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], @@ -312,14 +303,15 @@ fn ic_iterable_identifier_range_lc() { enf a = x[0] + x[1] + x[2] + x[3]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); expected - .trace_columns - 
.push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], @@ -357,24 +349,23 @@ fn ic_iterable_slice_lc() { enf a = x[0] + x[1] + x[2] + x[3]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4)] + )); expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![ - let_!(x = lc!(((c, expr!(slice!(c, 0..3)))) => access!(c)).into() => - enforce!(eq!(access!(a), add!(add!(add!(access!(x[0]), access!(x[1])), access!(x[2])), access!(x[3]))))), - ], + vec![let_!(x = lc!(((c, expr!(slice!(c, 0..3)))) => access!(c)).into() => + enforce!(eq!(access!(a), add!(add!(add!(access!(x[0]), access!(x[1])), access!(x[2])), access!(x[3])))))], )); ParseTest::new().expect_module_ast(source, expected); @@ -402,14 +393,15 @@ fn ic_two_iterable_identifier_lc() { enf a = x[0] + x[1] + x[2] + x[3]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = 
Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 1), (c, 4), (d, 4)] + )); expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 1), (c, 4), (d, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], @@ -447,14 +439,15 @@ fn ic_multiple_iterables_lc() { enf a = x[0] + x[1] + x[2] + x[3]; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(a, 1), (b, 3), (c, 4), (d, 4)] + )); expected - .trace_columns - .push(trace_segment!(0, "$main", [(a, 1), (b, 3), (c, 4), (d, 4)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!(bounded_access!(a, Boundary::First), int!(0)))], diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index 5f8805a9d..aa1a162e1 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -27,6 +27,54 @@ macro_rules! assert_module_error { }; } +macro_rules! 
module_ident { + ($name:ident) => { + ModuleId::new( + vec![Identifier::new( + miden_diagnostics::SourceSpan::UNKNOWN, + crate::Symbol::intern(stringify!($name)), + )], + miden_diagnostics::SourceSpan::UNKNOWN, + ) + }; + + ($($names:ident),+) => { + ModuleId::new( + vec![ + $( + Identifier::new( + miden_diagnostics::SourceSpan::UNKNOWN, + crate::Symbol::intern(stringify!($names)), + ) + ),+ + ], + miden_diagnostics::SourceSpan::UNKNOWN, + ) + }; +} + +macro_rules! import_all { + ($module:ident) => { + Import::All { module: module_ident!($module) } + }; + ($($names:ident),+) => { + Import::All { module: module_ident!($($names),+) } + }; +} + +macro_rules! import { + ($module:ident, $item:ident) => {{ + let mut items: std::collections::HashSet = std::collections::HashSet::default(); + items.insert(ident!($item)); + Import::Partial { module: module_ident!($module), items } + }}; + (($($names:ident),+), $item:ident) => {{ + let mut items: std::collections::HashSet = std::collections::HashSet::default(); + items.insert(ident!($item)); + Import::Partial { module: module_ident!($($names),+), items } + }}; +} + macro_rules! ident { ($name:ident) => { Identifier::new( @@ -36,15 +84,12 @@ macro_rules! ident { }; ($name:literal) => { - Identifier::new( - miden_diagnostics::SourceSpan::UNKNOWN, - crate::Symbol::intern($name), - ) + Identifier::new(miden_diagnostics::SourceSpan::UNKNOWN, crate::Symbol::intern($name)) }; ($module:ident, $name:ident) => { QualifiedIdentifier::new( - ident!($module), + module_ident!($module), NamespacedIdentifier::Binding(ident!($name)), ) }; @@ -57,7 +102,13 @@ macro_rules! function_ident { ($module:ident, $name:ident) => { QualifiedIdentifier::new( - ident!($module), + module_ident!($module), + NamespacedIdentifier::Function(ident!($name)), + ) + }; + (($($modules:ident),+), $name:ident) => { + QualifiedIdentifier::new( + module_ident!($($modules),+), NamespacedIdentifier::Function(ident!($name)), ) }; @@ -174,6 +225,15 @@ macro_rules! 
access { }; ($name:ident [ $idx:literal ]) => { + ScalarExpr::SymbolAccess(SymbolAccess::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($name), + AccessType::Index(Box::new(int!($idx))), + 0, + )) + }; + + ($name:ident [ $idx:expr ]) => { ScalarExpr::SymbolAccess(SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), @@ -183,6 +243,15 @@ macro_rules! access { }; ($name:literal [ $idx:literal ]) => { + ScalarExpr::SymbolAccess(SymbolAccess::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($name), + AccessType::Index(Box::new(int!($idx))), + 0, + )) + }; + + ($name:literal [ $idx:expr ]) => { ScalarExpr::SymbolAccess(SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), @@ -192,6 +261,15 @@ macro_rules! access { }; ($name:ident [ $row:literal ] [ $col:literal ]) => { + ScalarExpr::SymbolAccess(SymbolAccess::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($name), + AccessType::Matrix(Box::new(int!($row)), Box::new(int!($col))), + 0, + )) + }; + + ($name:ident [ $row:expr ] [ $col:expr ]) => { ScalarExpr::SymbolAccess(SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), @@ -201,6 +279,16 @@ macro_rules! access { }; ($name:ident [ $row:literal ] [ $col:literal ], $ty:expr) => { + ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(ident!($name)), + access_type: AccessType::Matrix(Box::new(int!($row)), Box::new(int!($col))), + offset: 0, + ty: Some($ty), + }) + }; + + ($name:ident [ $row:expr ] [ $col:expr ], $ty:expr) => { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ResolvableIdentifier::Local(ident!($name)), @@ -211,6 +299,16 @@ macro_rules! 
access { }; ($module:ident, $name:ident [ $idx:literal ], $ty:expr) => { + ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ident!($module, $name).into(), + access_type: AccessType::Index(Box::new(int!($idx))), + offset: 0, + ty: Some($ty), + }) + }; + + ($module:ident, $name:ident [ $idx:expr ], $ty:expr) => { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ident!($module, $name).into(), @@ -221,6 +319,16 @@ macro_rules! access { }; ($module:ident, $name:ident [ $row:literal ] [ $col:literal ], $ty:expr) => { + ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ident!($module, $name).into(), + access_type: AccessType::Matrix(Box::new(int!($row)), Box::new(int!($col))), + offset: 0, + ty: Some($ty), + }) + }; + + ($module:ident, $name:ident [ $row:expr ] [ $col:expr ], $ty:expr) => { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ident!($module, $name).into(), @@ -231,6 +339,15 @@ macro_rules! access { }; ($name:ident [ $idx:literal ], $offset:literal) => { + ScalarExpr::SymbolAccess(SymbolAccess::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($name), + AccessType::Index(Box::new(int!($idx))), + $offset, + )) + }; + + ($name:ident [ $idx:expr ], $offset:literal) => { ScalarExpr::SymbolAccess(SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), @@ -240,6 +357,16 @@ macro_rules! 
access { }; ($name:ident [ $idx:literal ], $ty:expr) => { + ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(ident!($name)), + access_type: AccessType::Index(Box::new(int!($idx))), + offset: 0, + ty: Some($ty), + }) + }; + + ($name:ident [ $idx:expr ], $ty:expr) => { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ResolvableIdentifier::Local(ident!($name)), @@ -253,13 +380,32 @@ macro_rules! access { ScalarExpr::SymbolAccess(SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ResolvableIdentifier::Local(ident!($name)), - access_type: AccessType::Index($idx), + access_type: AccessType::Index(Box::new(int!($idx))), + offset: $offset, + ty: Some($ty), + }) + }; + + ($name:ident [ $idx:literal ], $offset:literal, $ty:expr) => { + ScalarExpr::SymbolAccess(SymbolAccess { + span: miden_diagnostics::SourceSpan::UNKNOWN, + name: ResolvableIdentifier::Local(ident!($name)), + access_type: AccessType::Index(Box::new(int!($idx))), offset: $offset, ty: Some($ty), }) }; ($name:literal [ $idx:literal ], $offset:literal) => { + ScalarExpr::SymbolAccess(SymbolAccess::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($name), + AccessType::Index(Box::new(int!($idx))), + $offset, + )) + }; + + ($name:literal [ $idx:expr ], $offset:literal) => { ScalarExpr::SymbolAccess(SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), @@ -275,18 +421,6 @@ macro_rules! expr { }; } -macro_rules! scalar { - ($expr:expr) => { - ScalarExpr::try_from($expr).unwrap() - }; -} - -macro_rules! statement { - ($expr:expr) => { - Statement::try_from($expr).unwrap() - }; -} - macro_rules! slice { ($name:ident, $range:expr) => { ScalarExpr::SymbolAccess(SymbolAccess { @@ -343,7 +477,7 @@ macro_rules! 
bounded_access { SymbolAccess::new( miden_diagnostics::SourceSpan::UNKNOWN, ident!($name), - AccessType::Index($idx), + AccessType::Index(Box::new(int!($idx))), 0, ), $bound, @@ -356,7 +490,7 @@ macro_rules! bounded_access { SymbolAccess { span: miden_diagnostics::SourceSpan::UNKNOWN, name: ResolvableIdentifier::Local(ident!($name)), - access_type: AccessType::Index($idx), + access_type: AccessType::Index(Box::new(int!($idx))), offset: 0, ty: Some($ty), }, @@ -383,10 +517,7 @@ macro_rules! int { macro_rules! null { () => { - ScalarExpr::Null(miden_diagnostics::Span::new( - miden_diagnostics::SourceSpan::UNKNOWN, - (), - )) + ScalarExpr::Null(miden_diagnostics::Span::new(miden_diagnostics::SourceSpan::UNKNOWN, ())) }; } @@ -402,11 +533,21 @@ macro_rules! call { args: vec![$($param),+], ty: None, }) - } + }; + + (($($modules:ident),+) :: $callee:ident ($($param:expr),+)) => { + ScalarExpr::Call(Call { + span: miden_diagnostics::SourceSpan::UNKNOWN, + callee: ResolvableIdentifier::Resolved(function_ident!(($($modules),+), $callee)), + args: vec![$($param),+], + ty: None, + }) + }; + } macro_rules! trace_segment { - ($idx:literal, $name:literal, [$(($binding_name:ident, $binding_size:literal)),*]) => { + ($idx:expr, $name:literal, [$(($binding_name:ident, $binding_size:literal)),*]) => { TraceSegment::new(miden_diagnostics::SourceSpan::UNKNOWN, $idx, ident!($name), vec![ $(miden_diagnostics::Span::new(miden_diagnostics::SourceSpan::UNKNOWN, (ident!($binding_name), $binding_size))),* ]) @@ -469,7 +610,22 @@ macro_rules! enforce { }; ($expr:expr, when $selector:expr) => { - Statement::EnforceIf($expr, $selector) + Statement::EnforceIf(Match::new( + miden_diagnostics::SourceSpan::UNKNOWN, + vec![MatchArm::new(miden_diagnostics::SourceSpan::UNKNOWN, $expr, $selector)], + )) + }; +} + +macro_rules! enforce_if { + ($($match_arms:expr),+) => { + Statement::EnforceIf(Match::new(miden_diagnostics::SourceSpan::UNKNOWN, vec![$($match_arms),+])) + }; +} + +macro_rules! 
match_arm { + ($expr:expr, $selector:expr) => { + MatchArm::new(miden_diagnostics::SourceSpan::UNKNOWN, $expr, $selector) }; } @@ -648,35 +804,16 @@ macro_rules! exp { }; } -macro_rules! import_all { - ($module:ident) => { - Import::All { - module: ident!($module), - } - }; -} - -macro_rules! import { - ($module:ident, $item:ident) => {{ - let mut items: std::collections::HashSet = std::collections::HashSet::default(); - items.insert(ident!($item)); - Import::Partial { - module: ident!($module), - items, - } - }}; -} - mod arithmetic_ops; mod boundary_constraints; mod buses; mod calls; +mod computed_indices; mod constant_propagation; mod constants; mod evaluators; mod functions; mod identifiers; -mod inlining; mod integrity_constraints; mod list_comprehension; mod modules; @@ -704,9 +841,11 @@ fn full_air_file() { // trace_columns { // main: [clk, fmp, ctx] // } - expected - .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1), (fmp, 1), (ctx, 1)])); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(clk, 1), (fmp, 1), (ctx, 1)] + )); // integrity_constraints { // enf clk' = clk + 1 // } @@ -717,10 +856,9 @@ fn full_air_file() { // boundary_constraints { // enf clk.first = 0 // } - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); + expected + .boundary_constraints + .push(enforce!(eq!(bounded_access!(clk, Boundary::First, Type::Felt), int!(0)))); ParseTest::new().expect_program_ast_from_file("src/parser/tests/input/system.air", expected); } diff --git a/parser/src/parser/tests/modules.rs b/parser/src/parser/tests/modules.rs index 6f3d34da2..0fa792831 100644 --- a/parser/src/parser/tests/modules.rs +++ b/parser/src/parser/tests/modules.rs @@ -1,8 +1,7 @@ use miden_diagnostics::SourceSpan; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; #[test] fn use_declaration() { @@ -10,9 +9,11 @@ fn use_declaration() { mod test use foo::*; + use 
bar::baz::*; "; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); - expected.imports.insert(ident!(foo), import_all!(foo)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); + expected.imports.insert(module_ident!(foo), import_all!(foo)); + expected.imports.insert(module_ident!(bar, baz), import_all!(bar, baz)); ParseTest::new().expect_module_ast(source, expected); } @@ -22,9 +23,11 @@ fn import_declaration() { mod test use foo::bar; + use baz::bat::bot; "; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); - expected.imports.insert(ident!(foo), import!(foo, bar)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); + expected.imports.insert(module_ident!(foo), import!(foo, bar)); + expected.imports.insert(module_ident!(baz, bat), import!((baz, bat), bot)); ParseTest::new().expect_module_ast(source, expected); } @@ -41,9 +44,11 @@ fn import_declaration() { #[test] fn modules_integration_test() { let mut expected = Program::new(ident!(import_example)); - expected - .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1), (fmp, 1), (ctx, 1)])); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(clk, 1), (fmp, 1), (ctx, 1)] + )); expected.periodic_columns.insert( ident!(foo, k0), PeriodicColumn::new(SourceSpan::UNKNOWN, ident!(k0), vec![1, 1, 0, 0]), @@ -58,6 +63,32 @@ fn modules_integration_test() { // defines an evaluator `other_constraint`, that evaluator is never called // so it is treated as dead code and stripped from the program + // ev is_binary([x]) { + // enf x^2 = x; + // } + expected.evaluators.insert( + function_ident!((utils, binary), is_binary), + EvaluatorFunction::new( + SourceSpan::UNKNOWN, + ident!(is_binary), + vec![trace_segment!(TraceSegmentId::Main, "%0", [(x, 1)])], + vec![enforce!(eq!(exp!(access!(x, Type::Felt), int!(2)), access!(x, 
Type::Felt)))], + ), + ); + // ev are_all_binary([c[3]]) { + // enf is_binary([c]) for c in c; + // } + expected.evaluators.insert( + function_ident!((utils, binary), are_all_binary), + EvaluatorFunction::new( + SourceSpan::UNKNOWN, + ident!(are_all_binary), + vec![trace_segment!(TraceSegmentId::Main, "%1", [(c, 3)])], + vec![enforce_all!( + lc!(((c, expr!(access!(c, Type::Vector(3))))) => call!((utils, binary)::is_binary(vector!(access!(c, Type::Felt))))) + )], + ), + ); // ev bar_constraint([clk]) { // enf clk' = clk + k0 when k0 // } @@ -66,11 +97,14 @@ fn modules_integration_test() { EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(bar_constraint), - vec![trace_segment!(0, "%0", [(clk, 1)])], - vec![enforce_all!(lc!((("%1", range!(0..1))) => eq!( - access!(clk, 1, Type::Felt), - add!(access!(clk, Type::Felt), access!(bar, k0, Type::Felt)) - ), when access!(bar, k0, Type::Felt)))], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], + vec![enforce_if!(match_arm!( + eq!( + access!(clk, 1, Type::Felt), + add!(access!(clk, Type::Felt), access!(bar, k0, Type::Felt)) + ), + access!(bar, k0, Type::Felt) + ))], ), ); // ev foo_constraint([clk]) { @@ -81,30 +115,28 @@ fn modules_integration_test() { EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(foo_constraint), - vec![trace_segment!(0, "%0", [(clk, 1)])], - vec![enforce_all!(lc!((("%1", range!(0..1))) => eq!(access!(clk, 1, Type::Felt), add!(access!(clk, Type::Felt), int!(1))), when access!(foo, k0, Type::Felt)))], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], + vec![enforce_if!(match_arm!( + eq!(access!(clk, 1, Type::Felt), add!(access!(clk, Type::Felt), int!(1))), + access!(foo, k0, Type::Felt) + ))], ), ); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected .integrity_constraints - 
.push(enforce!(call!(foo::foo_constraint(vector!(access!( - clk, - Type::Felt - )))))); + .push(enforce!(call!(foo::foo_constraint(vector!(access!(clk, Type::Felt)))))); expected .integrity_constraints - .push(enforce!(call!(bar::bar_constraint(vector!(access!( - clk, - Type::Felt - )))))); - expected.boundary_constraints.push(enforce!(eq!( - bounded_access!(clk, Boundary::First, Type::Felt), - int!(0) - ))); + .push(enforce!(call!(bar::bar_constraint(vector!(access!(clk, Type::Felt)))))); + expected + .integrity_constraints + .push(enforce!(call!((utils, binary)::are_all_binary(vector!(access!(clk, Type::Felt), access!(fmp, Type::Felt), access!(ctx, Type::Felt)))))); + expected + .boundary_constraints + .push(enforce!(eq!(bounded_access!(clk, Boundary::First, Type::Felt), int!(0)))); ParseTest::new() .expect_program_ast_from_file("src/parser/tests/input/import_example.air", expected); diff --git a/parser/src/parser/tests/periodic_columns.rs b/parser/src/parser/tests/periodic_columns.rs index 6b34673e1..b2178258d 100644 --- a/parser/src/parser/tests/periodic_columns.rs +++ b/parser/src/parser/tests/periodic_columns.rs @@ -1,8 +1,7 @@ use miden_diagnostics::SourceSpan; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; #[test] fn periodic_columns() { @@ -14,18 +13,14 @@ fn periodic_columns() { k1: [0, 0, 0, 0, 0, 0, 0, 1], }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); expected.periodic_columns.insert( ident!(k0), PeriodicColumn::new(SourceSpan::UNKNOWN, ident!(k0), vec![1, 0, 0, 0]), ); expected.periodic_columns.insert( ident!(k1), - PeriodicColumn::new( - SourceSpan::UNKNOWN, - ident!(k1), - vec![0, 0, 0, 0, 0, 0, 0, 1], - ), + PeriodicColumn::new(SourceSpan::UNKNOWN, ident!(k1), vec![0, 0, 0, 0, 0, 0, 0, 1]), ); ParseTest::new().expect_module_ast(source, expected); } @@ -37,7 +32,7 @@ fn 
empty_periodic_columns() { periodic_columns{}"; - let expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); ParseTest::new().expect_module_ast(source, expected); } diff --git a/parser/src/parser/tests/pub_inputs.rs b/parser/src/parser/tests/pub_inputs.rs index a04cee604..72de558cc 100644 --- a/parser/src/parser/tests/pub_inputs.rs +++ b/parser/src/parser/tests/pub_inputs.rs @@ -1,8 +1,7 @@ use miden_diagnostics::{SourceSpan, Span}; -use crate::{ast::*, parser::ParseError}; - use super::ParseTest; +use crate::{ast::*, parser::ParseError}; // PUBLIC INPUTS // ================================================================================================ @@ -29,10 +28,10 @@ fn public_inputs_vec() { enf clk = 0; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1)])); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1)])); expected.public_inputs.insert( ident!(program_hash), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(program_hash), 4), @@ -43,15 +42,10 @@ fn public_inputs_vec() { ); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], - )); - expected.integrity_constraints = Some(Span::new( - SourceSpan::UNKNOWN, - vec![enforce!(eq!(access!(clk), int!(0)))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); + expected.integrity_constraints = + Some(Span::new(SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(clk), int!(0)))])); ParseTest::new().expect_module_ast(source, expected); } @@ -77,29 +71,22 @@ fn public_inputs_table() { enf clk = 0; }"; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, 
ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1)])); - expected.public_inputs.insert( - ident!(a), - PublicInput::new_table(SourceSpan::UNKNOWN, ident!(a), 4), - ); - expected.public_inputs.insert( - ident!(b), - PublicInput::new_table(SourceSpan::UNKNOWN, ident!(b), 32), - ); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1)])); + expected + .public_inputs + .insert(ident!(a), PublicInput::new_table(SourceSpan::UNKNOWN, ident!(a), 4)); + expected + .public_inputs + .insert(ident!(b), PublicInput::new_table(SourceSpan::UNKNOWN, ident!(b), 32)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], - )); - expected.integrity_constraints = Some(Span::new( - SourceSpan::UNKNOWN, - vec![enforce!(eq!(access!(clk), int!(0)))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); + expected.integrity_constraints = + Some(Span::new(SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(clk), int!(0)))])); ParseTest::new().expect_module_ast(source, expected); } diff --git a/parser/src/parser/tests/selectors.rs b/parser/src/parser/tests/selectors.rs index 40a72b8de..e2144ff15 100644 --- a/parser/src/parser/tests/selectors.rs +++ b/parser/src/parser/tests/selectors.rs @@ -1,8 +1,7 @@ use miden_diagnostics::{SourceSpan, Span}; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; // SELECTORS // ================================================================================================ @@ -27,26 +26,20 @@ fn single_selector() { integrity_constraints { enf clk' = clk when n1; }"#; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected .trace_columns - .push(trace_segment!(0, "$main", 
[(clk, 1), (n1, 1)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .push(trace_segment!(TraceSegmentId::Main, "$main", [(clk, 1), (n1, 1)])); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce_all!( - lc!((("%0", range!(0..1))) => eq!(access!(clk, 1), access!(clk)), when access!(n1)) - )], + vec![enforce_if!(match_arm!(eq!(access!(clk, 1), access!(clk)), access!(n1)))], )); ParseTest::new().expect_module_ast(source, expected); } @@ -71,28 +64,25 @@ fn chained_selectors() { integrity_constraints { enf clk' = clk when (n1 & !n2) | !n3; }"#; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected.trace_columns.push(trace_segment!( - 0, + TraceSegmentId::Main, "$main", [(clk, 1), (n1, 1), (n2, 1), (n3, 1)] )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce_all!( - lc!((("%0", range!(0..1))) => eq!(access!(clk, 1), access!(clk)), when or!(and!(access!(n1), not!(access!(n2))), not!(access!(n3)))) 
- )], + vec![enforce_if!(match_arm!( + eq!(access!(clk, 1), access!(clk)), + or!(and!(access!(n1), not!(access!(n2))), not!(access!(n3))) + ))], )); ParseTest::new().expect_module_ast(source, expected); diff --git a/parser/src/parser/tests/trace_columns.rs b/parser/src/parser/tests/trace_columns.rs index 2110b5829..9322e0934 100644 --- a/parser/src/parser/tests/trace_columns.rs +++ b/parser/src/parser/tests/trace_columns.rs @@ -1,8 +1,7 @@ use miden_diagnostics::{SourceSpan, Span}; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; // TRACE COLUMNS // ================================================================================================ @@ -27,25 +26,21 @@ fn trace_columns() { integrity_constraints { enf clk = 0; }"#; - let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); + expected.trace_columns.push(trace_segment!( + TraceSegmentId::Main, + "$main", + [(clk, 1), (fmp, 1), (ctx, 1)] + )); expected - .trace_columns - .push(trace_segment!(0, "$main", [(clk, 1), (fmp, 1), (ctx, 1)])); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], - )); - expected.integrity_constraints = Some(Span::new( - SourceSpan::UNKNOWN, - vec![enforce!(eq!(access!(clk), int!(0)))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); + expected.integrity_constraints = + Some(Span::new(SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(clk), int!(0)))])); ParseTest::new().expect_module_ast(source, expected); } @@ -70,22 +65,18 @@ fn trace_columns_groups() { enf a[1]' = 1; enf clk' = clk - 1; }"#; - let 
mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, module_ident!(test)); expected.trace_columns.push(trace_segment!( - 0, + TraceSegmentId::Main, "$main", [(clk, 1), (fmp, 1), (ctx, 1), (a, 3)] )); - expected.public_inputs.insert( - ident!(inputs), - PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2), - ); + expected + .public_inputs + .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected.boundary_constraints = Some(Span::new( SourceSpan::UNKNOWN, - vec![enforce!(eq!( - bounded_access!(clk, Boundary::First), - int!(0) - ))], + vec![enforce!(eq!(bounded_access!(clk, Boundary::First), int!(0)))], )); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, diff --git a/parser/src/parser/tests/utils.rs b/parser/src/parser/tests/utils.rs index 7e7a44c81..0a4d0af8e 100644 --- a/parser/src/parser/tests/utils.rs +++ b/parser/src/parser/tests/utils.rs @@ -74,44 +74,34 @@ impl ParseTest { no_warn: false, display: Default::default(), }; - let diagnostics = Arc::new(DiagnosticsHandler::new( - config, - codemap.clone(), - emitter.clone(), - )); + let diagnostics = + Arc::new(DiagnosticsHandler::new(config, codemap.clone(), emitter.clone())); let parser = Parser::new((), codemap); - Self { - diagnostics, - emitter, - parser, - } + Self { diagnostics, emitter, parser } } /// This adds a new in-memory file to the [CodeMap] for this test. 
/// - /// This is used when we want to write a test with imports, without having to place files on disk + /// This is used when we want to write a test with imports, without having to place files on + /// disk pub fn add_virtual_file>(&self, name: P, content: String) { self.parser.codemap.add(name.as_ref(), content); } pub fn parse_module_from_file(&self, path: &str) -> Result { - self.parser - .parse_file::(&self.diagnostics, path) + self.parser.parse_file::(&self.diagnostics, path) } pub fn parse_program_from_file(&self, path: &str) -> Result { - self.parser - .parse_file::(&self.diagnostics, path) + self.parser.parse_file::(&self.diagnostics, path) } pub fn parse_module(&self, source: &str) -> Result { - self.parser - .parse_string::(&self.diagnostics, source) + self.parser.parse_string::(&self.diagnostics, source) } pub fn parse_program(&self, source: &str) -> Result { - self.parser - .parse_string::(&self.diagnostics, source) + self.parser.parse_string::(&self.diagnostics, source) } // TEST METHODS @@ -153,8 +143,8 @@ impl ParseTest { } } - /// Parses a [Program] from the given source string and asserts that executing the test will result - /// in the expected AST. + /// Parses a [Program] from the given source string and asserts that executing the test will + /// result in the expected AST. #[allow(unused)] #[track_caller] pub fn expect_program_ast(&self, source: &str, expected: Program) { @@ -162,33 +152,33 @@ impl ParseTest { Err(err) => { self.diagnostics.emit(err); panic!("expected parsing to succeed, see diagnostics for details"); - } + }, Ok(ast) => assert_eq!(ast, expected), } } - /// Parses a [Module] from the given source string and asserts that executing the test will result - /// in the expected AST. + /// Parses a [Module] from the given source string and asserts that executing the test will + /// result in the expected AST. 
#[track_caller] pub fn expect_module_ast(&self, source: &str, expected: Module) { match self.parse_module(source) { Err(err) => { self.diagnostics.emit(err); panic!("expected parsing to succeed, see diagnostics for details"); - } + }, Ok(ast) => assert_eq!(ast, expected), } } - /// Parses a [Program] from the given source path and asserts that executing the test will result - /// in the expected AST. + /// Parses a [Program] from the given source path and asserts that executing the test will + /// result in the expected AST. #[track_caller] pub fn expect_program_ast_from_file(&self, path: &str, expected: Program) { match self.parse_program_from_file(path) { Err(err) => { self.diagnostics.emit(err); panic!("expected parsing to succeed, see diagnostics for details"); - } + }, Ok(ast) => assert_eq!(ast, expected), } } @@ -202,7 +192,7 @@ impl ParseTest { Err(err) => { self.diagnostics.emit(err); panic!("expected parsing to succeed, see diagnostics for details"); - } + }, Ok(ast) => assert_eq!(ast, expected), } } diff --git a/parser/src/parser/tests/variables.rs b/parser/src/parser/tests/variables.rs index 2005b2be9..3b84ffd42 100644 --- a/parser/src/parser/tests/variables.rs +++ b/parser/src/parser/tests/variables.rs @@ -1,8 +1,7 @@ use miden_diagnostics::SourceSpan; -use crate::ast::*; - use super::ParseTest; +use crate::ast::*; // VARIABLES // ================================================================================================ @@ -16,17 +15,17 @@ fn variables_with_and_operators() { enf clk' = clk + 1 when flag; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); // The constraint is converted into a comprehension constraint by the parser, which // involves generating an iterable with one element and giving it a generated binding let body = vec![let_!(flag = expr!(and!(access!(n1), not!(access!(n2)))) => - 
enforce_all!(lc!((("%1", range!(0..1))) => eq!(access!(clk, 1), add!(access!(clk), int!(1))), when access!(flag))))]; + enforce_if!(match_arm!(eq!(access!(clk, 1), add!(access!(clk), int!(1))), access!(flag))))]; expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], body, ), ); @@ -44,17 +43,16 @@ fn variables_with_or_operators() { enf clk' = clk + 1 when flag; }"; - let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); - let body = vec![ - let_!(flag = expr!(or!(access!(n1), not!(access!(n2, 1)))) => - enforce_all!(lc!((("%1", range!(0..1))) => eq!(access!(clk, 1), add!(access!(clk), int!(1))), when access!(flag)))), - ]; + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, module_ident!(test)); + let body = vec![let_!(flag = expr!(or!(access!(n1), not!(access!(n2, 1)))) => + enforce_if!(match_arm!(eq!(access!(clk, 1), add!(access!(clk), int!(1))), access!(flag))))]; + expected.evaluators.insert( ident!(test), EvaluatorFunction::new( SourceSpan::UNKNOWN, ident!(test), - vec![trace_segment!(0, "%0", [(clk, 1)])], + vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], body, ), ); diff --git a/parser/src/sema/binding_type.rs b/parser/src/sema/binding_type.rs index 38a22c3aa..d9bf5eed9 100644 --- a/parser/src/sema/binding_type.rs +++ b/parser/src/sema/binding_type.rs @@ -1,6 +1,9 @@ -use crate::ast::{AccessType, BusType, FunctionType, InvalidAccessError, TraceBinding, Type}; use std::fmt; +use crate::ast::{ + AccessType, BusType, FunctionType, InvalidAccessError, ScalarExpr, TraceBinding, Type, +}; + /// This type provides type and contextual information about a binding, /// i.e. not only does it tell us the type of a binding, but what type /// of value was bound. 
This is used during analysis to check whether a @@ -46,156 +49,31 @@ impl BindingType { } } - /// Returns true if this binding type is a trace binding - pub fn is_trace_binding(&self) -> bool { - match self { - Self::TraceColumn(_) | Self::TraceParam(_) => true, - Self::Vector(elems) => elems.iter().all(|e| e.is_trace_binding()), - _ => false, - } - } - - /// This function is used to split the current binding into two parts, the - /// first of which contains `n` trace columns, the second of which contains - /// what remains of the original binding. This function returns `Ok` when - /// there were `n` columns in the input binding type, otherwise `Err` with - /// a binding that contains as many columns as possible. - /// - /// If the input binding type is a single logical binding, then the resulting - /// binding types will be of the same type. If however, the input binding type - /// is a vector of bindings, then the first part of the split will be a vector - /// containing `n` elements, where each element is a single logical binding of - /// size 1. 
This corresponds to the way trace column bindings are packed/unpacked - /// using vectors/lists in AirScript - pub fn split_columns(&self, n: usize) -> Result<(Self, Option), Self> { - use core::cmp::Ordering; - - if n == 1 { - return Ok(self.pop_column()); - } - - match self { - Self::TraceColumn(tb) => match n.cmp(&tb.size) { - Ordering::Equal => Ok((self.clone(), None)), - Ordering::Less => { - let remaining = tb.size - n; - let first = Self::TraceColumn(TraceBinding { size: n, ..*tb }); - let rest = Self::TraceColumn(TraceBinding { - size: remaining, - offset: tb.offset + n, - ..*tb - }); - Ok((first, Some(rest))) - } - Ordering::Greater => Err(self.clone()), - }, - Self::Vector(elems) if elems.len() == 1 => elems[0].split_columns(n), - Self::Vector(elems) => { - let mut index = 0; - let mut remaining = n; - let mut set = Vec::with_capacity(elems.len()); - let mut next = elems.get(index).cloned(); - while remaining > 0 { - match next.take() { - None => return Err(Self::Vector(set)), - Some(binding_ty) => { - let (col, rest) = binding_ty.pop_column(); - set.push(col); - remaining -= 1; - next = rest.or_else(|| { - index += 1; - elems.get(index).cloned() - }); - } - } - } - let leftover = elems.len() - (index + 1); - match next { - None => Ok((Self::Vector(set), None)), - Some(mid) => { - index += 1; - let mut rest = Vec::with_capacity(leftover + 1); - rest.push(mid); - rest.extend_from_slice(&elems[index..]); - Ok((Self::Vector(set), Some(Self::Vector(rest)))) - } - } - } - invalid => panic!("invalid trace column(s) binding type: {invalid:#?}"), - } - } - - /// This function is like `split`, for the use case in which only a single - /// column is desired. This is used internally by `split` to handle those - /// cases, but may be used directly as well. 
- pub fn pop_column(&self) -> (Self, Option) { - match self { - // If we have a single logical binding, return the first half as - // a binding containing the first column of that binding, and the - // second half as a binding representing whatever was left, or `None` - // if it is empty. - Self::TraceColumn(tb) if tb.is_scalar() => (Self::TraceColumn(*tb), None), - Self::TraceColumn(tb) => { - let first = Self::TraceColumn(TraceBinding { - size: 1, - ty: Type::Felt, - ..*tb - }); - let remaining = tb.size - 1; - if remaining == 0 { - (first, None) - } else { - let rest = Self::TraceColumn(TraceBinding { - size: remaining, - ty: Type::Vector(remaining), - offset: tb.offset + 1, - ..*tb - }); - (first, Some(rest)) - } - } - // If the vector has only one element, remove the vector and - // return the result of popping a column on the first element. - Self::Vector(elems) if elems.len() == 1 => elems[0].pop_column(), - // If the vector has multiple elements, then we're going to return - // a vector for the remainder of the split. 
- Self::Vector(elems) => { - // Take the first element out of the vector - let (popped, rest) = elems.split_first().unwrap(); - // Pop a single trace column from that element - let (first, mid) = popped.pop_column(); - // The `popped` binding must have been a TraceColumn type, as - // as nested binding vectors are not permitted in calls to evaluators - match mid { - None => (first, Some(Self::Vector(rest.to_vec()))), - Some(mid) => { - let mut mid_and_rest = Vec::with_capacity(rest.len() + 1); - mid_and_rest.push(mid); - mid_and_rest.extend_from_slice(rest); - (first, Some(Self::Vector(mid_and_rest))) - } - } - } - invalid => panic!("invalid trace column(s) binding type: {invalid:#?}"), - } - } - /// Produce a new [BindingType] which represents accessing the current binding via `access_type` pub fn access(&self, access_type: AccessType) -> Result { match self { Self::Alias(aliased) => aliased.access(access_type), Self::Local(ty) => ty.access(access_type).map(Self::Local), - Self::Constant(ty) => ty - .access(access_type) - .map(|t| Self::Alias(Box::new(Self::Constant(t)))), + Self::Constant(ty) => { + ty.access(access_type).map(|t| Self::Alias(Box::new(Self::Constant(t)))) + }, Self::TraceColumn(tb) => tb.access(access_type).map(Self::TraceColumn), Self::TraceParam(tb) => tb.access(access_type).map(Self::TraceParam), Self::Vector(elems) => match access_type { AccessType::Default => Ok(Self::Vector(elems.clone())), - AccessType::Index(idx) if idx >= elems.len() => { - Err(InvalidAccessError::IndexOutOfBounds) - } - AccessType::Index(idx) => Ok(elems[idx].clone()), + AccessType::Index(idx) => { + if let ScalarExpr::Const(idx) = *idx { + if idx.item as usize >= elems.len() { + Err(InvalidAccessError::IndexOutOfBounds) + } else { + Ok(elems[idx.item as usize].clone()) + } + } else { + // Items are all of the same type, we can just return the first one for now, + // as we cannot determine its value for now. 
+ Ok(elems[0].clone()) + } + }, AccessType::Slice(range) => { let slice_range = range.to_slice_range(); if slice_range.end > elems.len() { @@ -203,11 +81,20 @@ impl BindingType { } else { Ok(Self::Vector(elems[slice_range].to_vec())) } - } - AccessType::Matrix(row, _) if row >= elems.len() => { - Err(InvalidAccessError::IndexOutOfBounds) - } - AccessType::Matrix(row, col) => elems[row].access(AccessType::Index(col)), + }, + AccessType::Matrix(row, col) => { + if let ScalarExpr::Const(row) = *row { + if row.item as usize >= elems.len() { + Err(InvalidAccessError::IndexOutOfBounds) + } else { + elems[row.item as usize].access(AccessType::Index(col)) + } + } else { + // Items are all of the same type, we can just return the first one for now, + // as we cannot determine its value for now. + elems[0].access(AccessType::Index(col)) + } + }, }, Self::PublicInput(ty) => ty.access(access_type).map(Self::PublicInput), Self::PeriodicColumn(period) => match access_type { diff --git a/parser/src/sema/dependencies.rs b/parser/src/sema/dependencies.rs index 9e647512d..e97180e69 100644 --- a/parser/src/sema/dependencies.rs +++ b/parser/src/sema/dependencies.rs @@ -6,12 +6,12 @@ use crate::ast::{ModuleId, QualifiedIdentifier}; /// The dependency graph is used to construct the final [Program] representation, /// containing only those parts of the program which are referenced from the root /// module. -pub type DependencyGraph = petgraph::graphmap::DiGraphMap; +pub type DependencyGraph = petgraph::graph::DiGraph; /// Represents the graph of dependencies between modules, with no regard to what /// items in those modules are actually used. In other words, this graph tells us /// what modules depend on what other modules in the program. 
-pub type ModuleGraph = petgraph::graphmap::DiGraphMap; +pub type ModuleGraph = petgraph::graph::DiGraph; /// Represents the type of edges in the dependency graph #[derive(Debug, Copy, Clone, PartialEq, Eq)] diff --git a/parser/src/sema/errors.rs b/parser/src/sema/errors.rs index d686f7072..637e5a6a3 100644 --- a/parser/src/sema/errors.rs +++ b/parser/src/sema/errors.rs @@ -44,7 +44,7 @@ impl PartialEq for SemanticAnalysisError { (Self::ImportUndefined(lm), Self::ImportUndefined(rm)) => lm == rm, (Self::ImportConflict { item: li, .. }, Self::ImportConflict { item: ri, .. }) => { li == ri - } + }, (Self::InvalidExpr(l), Self::InvalidExpr(r)) => l == r, _ => core::mem::discriminant(self) == core::mem::discriminant(other), } diff --git a/parser/src/sema/import_resolver.rs b/parser/src/sema/import_resolver.rs index ed31470c9..44b46294d 100644 --- a/parser/src/sema/import_resolver.rs +++ b/parser/src/sema/import_resolver.rs @@ -1,10 +1,9 @@ -use std::collections::HashMap; -use std::ops::ControlFlow; +use std::{collections::HashMap, ops::ControlFlow}; use miden_diagnostics::{DiagnosticsHandler, Severity, Spanned}; use crate::{ - ast::{visit::VisitMut, *}, + ast::{Export, visit::VisitMut, *}, sema::SemanticAnalysisError, }; @@ -48,28 +47,28 @@ impl VisitMut for ImportResolver<'_> { for import in imports.values_mut() { match import { Import::All { module: from } => { - let imported_from = match self - .library - .get(from) - .ok_or(SemanticAnalysisError::ImportUndefined(*from)) - { - Ok(value) => value, - Err(err) => return ControlFlow::Break(err), - }; - for export in imported_from.exports() { - let name = export.name(); - let item = Identifier::new(from.span(), name.name()); - self.import(module, *from, item, export)?; + let submodules = self.library.get_submodules_of(from); + for submodule in submodules { + let imported_from = match self + .library + .get(&submodule) + .ok_or(SemanticAnalysisError::ImportUndefined(submodule.clone())) + { + Ok(value) => value, + Err(err) 
=> return ControlFlow::Break(err), + }; + for export in imported_from.exports() { + let name = export.name(); + let item = Identifier::new(from.span(), name.name()); + self.import(module, from.clone(), item, export)?; + } } - } - Import::Partial { - module: from, - items, - } => { + }, + Import::Partial { module: from, items } => { let imported_from = match self .library .get(from) - .ok_or(SemanticAnalysisError::ImportUndefined(*from)) + .ok_or(SemanticAnalysisError::ImportUndefined(from.clone())) { Ok(value) => value, Err(err) => return ControlFlow::Break(err), @@ -81,10 +80,10 @@ impl VisitMut for ImportResolver<'_> { // with the item in the set, not the span associated with the // export. if let Some(item) = items.get(&name) { - self.import(module, *from, *item, export)?; + self.import(module, from.clone(), *item, export)?; } } - } + }, } } @@ -106,6 +105,7 @@ impl ImportResolver<'_> { match export { Export::Constant(_) => self.import_constant(module, from, item), Export::Evaluator(_) => self.import_evaluator(module, from, item), + Export::Function(_) => self.import_function(module, from, item), } } @@ -148,13 +148,13 @@ impl ImportResolver<'_> { prev: id.span(), }) } - } + }, Entry::Vacant(entry) => { entry.insert(from); ControlFlow::Continue(()) - } + }, } - } + }, } } @@ -197,13 +197,57 @@ impl ImportResolver<'_> { prev: id.span(), }) } - } + }, Entry::Vacant(entry) => { entry.insert(from); ControlFlow::Continue(()) - } + }, } - } + }, + } + } + + /// Imports a function into the current module + fn import_function( + &mut self, + module: &mut Module, + from: ModuleId, + item: Identifier, + ) -> ControlFlow { + use std::collections::hash_map::Entry; + + let namespaced_name = NamespacedIdentifier::Function(item); + match module.functions.get(&item) { + Some(exists) => ControlFlow::Break(SemanticAnalysisError::ImportConflict { + item, + prev: exists.name.span(), + }), + None => match self.imported.entry(namespaced_name) { + Entry::Occupied(entry) => { + let id 
= entry.key(); + let originally_imported_from = entry.get(); + if originally_imported_from == &from { + // Warn about redundant import + self.diagnostics + .diagnostic(Severity::Warning) + .with_message("redundant import") + .with_primary_label(item.span(), "this import is unnecessary") + .with_secondary_label(id.span(), "because it was already imported here") + .emit(); + ControlFlow::Continue(()) + } else { + // Conflict is with another import, raise an error + ControlFlow::Break(SemanticAnalysisError::ImportConflict { + item, + prev: id.span(), + }) + } + }, + Entry::Vacant(entry) => { + entry.insert(from); + ControlFlow::Continue(()) + }, + }, } } } diff --git a/parser/src/sema/mod.rs b/parser/src/sema/mod.rs index 314004bd2..fc7be6501 100644 --- a/parser/src/sema/mod.rs +++ b/parser/src/sema/mod.rs @@ -6,8 +6,10 @@ mod scope; mod semantic_analysis; pub(crate) use self::binding_type::BindingType; -pub use self::dependencies::*; -pub use self::errors::SemanticAnalysisError; -pub use self::import_resolver::{ImportResolver, Imported}; -pub use self::scope::LexicalScope; -pub use self::semantic_analysis::SemanticAnalysis; +pub use self::{ + dependencies::*, + errors::SemanticAnalysisError, + import_resolver::{ImportResolver, Imported}, + scope::LexicalScope, + semantic_analysis::SemanticAnalysis, +}; diff --git a/parser/src/sema/scope.rs b/parser/src/sema/scope.rs index 713b787d6..0c7b373f3 100644 --- a/parser/src/sema/scope.rs +++ b/parser/src/sema/scope.rs @@ -79,7 +79,7 @@ where Self::Empty | Self::Root(_) => (), Self::Nested(parent, _) => { *self = Rc::unwrap_or_clone(parent); - } + }, } } } @@ -99,7 +99,7 @@ where env.insert(k, v); *self = Self::Root(env); None - } + }, Self::Root(env) => env.insert(k, v), Self::Nested(_, env) => env.insert(k, v), } @@ -125,9 +125,9 @@ where match self { Self::Empty => None, Self::Root(env) => env.get_mut(key), - Self::Nested(parent, env) => env - .get_mut(key) - .or_else(|| Rc::get_mut(parent).and_then(|p| p.get_mut(key))), + 
Self::Nested(parent, env) => { + env.get_mut(key).or_else(|| Rc::get_mut(parent).and_then(|p| p.get_mut(key))) + }, } } @@ -141,7 +141,7 @@ where Self::Root(env) => env.get_key_value(key), Self::Nested(parent, env) => { env.get_key_value(key).or_else(|| parent.get_key_value(key)) - } + }, } } diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 232dbc2c8..09d89426f 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -6,14 +6,13 @@ use std::{ use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; +use super::*; use crate::{ ast::{visit::VisitMut, *}, sema::SemanticAnalysisError, symbols::{self, Symbol}, }; -use super::*; - /// A helper enum for representing what constraint mode is active #[derive(Copy, Clone, PartialEq, Eq)] enum ConstraintMode { @@ -43,24 +42,30 @@ impl fmt::Display for ConstraintMode { } } -/// This pass is used to perform a variety of semantic analysis tasks in a single traversal of a module AST +/// This pass is used to perform a variety of semantic analysis tasks in a single traversal of a +/// module AST /// -/// * Resolves all identifiers to their fully-qualified names, or raises appropriate errors if unable -/// * Warns/errors as appropriate when declarations/bindings shadow or conflict with previous declarations/bindings -/// * Assigns binding context and type information to identifiers, or raises appropriate errors if unable +/// * Resolves all identifiers to their fully-qualified names, or raises appropriate errors if +/// unable +/// * Warns/errors as appropriate when declarations/bindings shadow or conflict with previous +/// declarations/bindings +/// * Assigns binding context and type information to identifiers, or raises appropriate errors if +/// unable /// * Performs type checking /// * Tracks references to imported items, and updates the dependency graph with that information /// * Ensures constraints are valid in the 
context they are defined in /// * Verifies comprehension invariants /// -/// These could each be done as separate passes, but since we don't have good facilities presently for fusing -/// multiple traversals into a single traversal, or for sharing analyses, it is best for us to take advantage -/// of the information being gathered to perform many of these tasks simultaneously. +/// These could each be done as separate passes, but since we don't have good facilities presently +/// for fusing multiple traversals into a single traversal, or for sharing analyses, it is best for +/// us to take advantage of the information being gathered to perform many of these tasks +/// simultaneously. pub struct SemanticAnalysis<'a> { diagnostics: &'a DiagnosticsHandler, program: &'a Program, library: &'a Library, - deps: &'a mut DependencyGraph, + deps_graph: &'a mut DependencyGraph, + deps_nodes: &'a mut BTreeMap, imported: Imported, globals: HashMap, constants: BTreeMap, @@ -71,6 +76,9 @@ pub struct SemanticAnalysis<'a> { has_undefined_variables: bool, has_type_errors: bool, in_constraint_comprehension: bool, + /// Tracks the qualified identifier of the function or evaluator currently being analyzed. + /// This is used to build the dependency graph for function-to-function calls. 
+ current_function: Option, } impl<'a> SemanticAnalysis<'a> { /// Create a new instance of the semantic analyzer @@ -79,13 +87,15 @@ impl<'a> SemanticAnalysis<'a> { program: &'a Program, library: &'a Library, deps: &'a mut DependencyGraph, + deps_nodes: &'a mut BTreeMap, imported: Imported, ) -> Self { Self { diagnostics, program, library, - deps, + deps_graph: deps, + deps_nodes, imported, globals: Default::default(), constants: Default::default(), @@ -96,6 +106,7 @@ impl<'a> SemanticAnalysis<'a> { has_undefined_variables: false, has_type_errors: false, in_constraint_comprehension: false, + current_function: None, } } @@ -106,21 +117,23 @@ impl<'a> SemanticAnalysis<'a> { } // If this is the root module, we may have top-level dependencies - if module.name == self.program.name { + if module.path.0.item == vec![self.program.name] { // Update the dependency graph with the collected information // // We use a special node to represent the references which occur in // the top-level boundary_constraints and integrity_constraints sections let root_node = QualifiedIdentifier::new( - self.program.name, + ModuleId::new(vec![self.program.name], self.program.name.span()), NamespacedIdentifier::Binding(Identifier::new( SourceSpan::UNKNOWN, Symbol::intern("$$root"), )), ); - for (referenced_item, ref_type) in self.referenced.iter() { - let referenced_item = self.deps.add_node(*referenced_item); - self.deps.add_edge(root_node, referenced_item, *ref_type); + let root_node_index = self.get_node_index_or_add(&root_node); + for (referenced_item, ref_type) in self.referenced.clone().iter() { + let referenced_item_node_index = + self.get_node_index_or_add(&referenced_item.clone()); + self.deps_graph.add_edge(root_node_index, referenced_item_node_index, *ref_type); } } else { // We should never have top-level dependencies here @@ -136,17 +149,21 @@ impl<'a> SemanticAnalysis<'a> { impl VisitMut for SemanticAnalysis<'_> { fn visit_mut_module(&mut self, module: &mut Module) -> ControlFlow 
{ - self.current_module = Some(module.name); + self.current_module = Some(module.path.clone()); // Collect the values of all named constants that can be referenced in range declarations - self.constants.extend( - module - .constants - .iter() - .map(|(id, c)| (*id, c.value.clone())), - ); + self.constants + .extend(module.constants.iter().map(|(id, c)| (*id, c.value.clone()))); - // Next, add all the top-level root module declarations as locals, if this is the root module + if !module.is_root() { + for (qid, bus) in self.program.buses.iter() { + // Add the bus defined in the root to the locals bindings + self.locals.insert(qid.id(), BindingType::Bus(bus.bus_type)); + } + } + + // Next, add all the top-level root module declarations as locals, if this is the root + // module // // As above, we are guaranteed that these names have no conflicts, but we assert that anyway if module.is_root() { @@ -161,11 +178,12 @@ impl VisitMut for SemanticAnalysis<'_> { offset: 0, size: segment.size, ty: Type::Vector(segment.size), + access: AccessType::Default, }) ), None ); - for binding in segment.bindings.iter().copied() { + for binding in segment.bindings.iter().cloned() { assert_eq!( self.locals.insert( NamespacedIdentifier::Binding(binding.name.unwrap()), @@ -176,6 +194,7 @@ impl VisitMut for SemanticAnalysis<'_> { offset: binding.offset, size: binding.size, ty: binding.ty, + access: binding.access, }) ), None @@ -207,8 +226,7 @@ impl VisitMut for SemanticAnalysis<'_> { } // It should be impossible for there to be a local by this name at this point assert_eq!( - self.locals - .insert(namespaced_name, BindingType::Constant(constant.ty())), + self.locals.insert(namespaced_name, BindingType::Constant(constant.ty())), None ); } @@ -254,11 +272,7 @@ impl VisitMut for SemanticAnalysis<'_> { if let Some((prev, _)) = self.imported.get_key_value(&namespaced_name) { self.declaration_import_conflict(namespaced_name.span(), prev.span())?; } - assert_eq!( - self.locals - 
.insert(namespaced_name, BindingType::Bus(bus.bus_type)), - None - ); + assert_eq!(self.locals.insert(namespaced_name, BindingType::Bus(bus.bus_type)), None); } // Next, we add any periodic columns to the set of local bindings. @@ -302,16 +316,16 @@ impl VisitMut for SemanticAnalysis<'_> { self.visit_mut_bus(bus)?; } - if let Some(boundary_constraints) = module.boundary_constraints.as_mut() { - if !boundary_constraints.is_empty() { - self.visit_mut_boundary_constraints(boundary_constraints)?; - } + if let Some(boundary_constraints) = module.boundary_constraints.as_mut() + && !boundary_constraints.is_empty() + { + self.visit_mut_boundary_constraints(boundary_constraints)?; } - if let Some(integrity_constraints) = module.integrity_constraints.as_mut() { - if !integrity_constraints.is_empty() { - self.visit_mut_integrity_constraints(integrity_constraints)?; - } + if let Some(integrity_constraints) = module.integrity_constraints.as_mut() + && !integrity_constraints.is_empty() + { + self.visit_mut_integrity_constraints(integrity_constraints)?; } self.current_module = None; @@ -328,6 +342,14 @@ impl VisitMut for SemanticAnalysis<'_> { &mut self, function: &mut EvaluatorFunction, ) -> ControlFlow { + // Set the current function context for dependency tracking + let prev_function = self.current_function.clone(); + let current_item = QualifiedIdentifier::new( + self.current_module.clone().unwrap(), + NamespacedIdentifier::Function(function.name), + ); + self.current_function = Some(current_item.clone()); + // Only allow integrity constraints in this context self.constraint_mode = ConstraintMode::Integrity; // Start a new lexical scope @@ -350,6 +372,7 @@ impl VisitMut for SemanticAnalysis<'_> { offset: trace_binding.offset, size: trace_binding.size, ty: trace_binding.ty, + access: trace_binding.access.clone(), }), ); } @@ -360,12 +383,17 @@ impl VisitMut for SemanticAnalysis<'_> { // Update the dependency graph for this function let current_item = QualifiedIdentifier::new( 
- self.current_module.unwrap(), + self.current_module.clone().unwrap(), NamespacedIdentifier::Function(function.name), ); - for (referenced_item, ref_type) in self.referenced.iter() { - let referenced_item = self.deps.add_node(*referenced_item); - self.deps.add_edge(current_item, referenced_item, *ref_type); + let current_item_node_index = self.get_node_index_or_add(¤t_item); + for (referenced_item, ref_type) in self.referenced.clone().iter() { + let referenced_item_node_index = self.get_node_index_or_add(referenced_item); + self.deps_graph.add_edge( + current_item_node_index, + referenced_item_node_index, + *ref_type, + ); } // Restore the original references metadata @@ -374,6 +402,8 @@ impl VisitMut for SemanticAnalysis<'_> { self.locals.exit(); // Disallow constraints self.constraint_mode = ConstraintMode::None; + // Restore previous function context + self.current_function = prev_function; ControlFlow::Continue(()) } @@ -382,6 +412,14 @@ impl VisitMut for SemanticAnalysis<'_> { &mut self, function: &mut Function, ) -> ControlFlow { + // Set the current function context for dependency tracking + let prev_function = self.current_function.clone(); + let current_item = QualifiedIdentifier::new( + self.current_module.clone().unwrap(), + NamespacedIdentifier::Function(function.name), + ); + self.current_function = Some(current_item.clone()); + // constraints are not allowed in pure functions self.constraint_mode = ConstraintMode::None; @@ -395,8 +433,7 @@ impl VisitMut for SemanticAnalysis<'_> { // Add the set of parameters to the current scope, check for conflicts for (param, param_type) in function.params.iter_mut() { let namespaced_name = NamespacedIdentifier::Binding(*param); - self.locals - .insert(namespaced_name, BindingType::Local(*param_type)); + self.locals.insert(namespaced_name, BindingType::Local(*param_type)); } // Visit all of the statements in the body @@ -404,18 +441,25 @@ impl VisitMut for SemanticAnalysis<'_> { // Update the dependency graph for 
this function let current_item = QualifiedIdentifier::new( - self.current_module.unwrap(), + self.current_module.clone().unwrap(), NamespacedIdentifier::Function(function.name), ); + let current_item_node_index = self.deps_graph.add_node(current_item); for (referenced_item, ref_type) in self.referenced.iter() { - let referenced_item = self.deps.add_node(*referenced_item); - self.deps.add_edge(current_item, referenced_item, *ref_type); + let referenced_item_node_index = self.deps_graph.add_node(referenced_item.clone()); + self.deps_graph.add_edge( + current_item_node_index, + referenced_item_node_index, + *ref_type, + ); } // Restore the original references metadata self.referenced = referenced; // Restore the original lexical scope self.locals.exit(); + // Restore previous function context + self.current_function = prev_function; ControlFlow::Continue(()) } @@ -471,16 +515,18 @@ impl VisitMut for SemanticAnalysis<'_> { ) .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) - } + }, ConstraintMode::Boundary => self.visit_mut_boundary_constraint(expr), ConstraintMode::Integrity => self.visit_mut_integrity_constraint(expr), } } - /// Comprehension constraints are very similar to those handled by `visit_mut_enforce`, except that they occur in - /// the body of a list comprehension. The comprehension itself is validated normally, but the body of the comprehension - /// must be checked using `visit_mut_enforce`, rather than `visit_mut_scalar_expr`. We do this by setting a flag in the - /// state that is checked in `visit_mut_list_comprehension` to enable checks that are specific to constraints. + /// Comprehension constraints are very similar to those handled by `visit_mut_enforce`, except + /// that they occur in the body of a list comprehension. The comprehension itself is + /// validated normally, but the body of the comprehension must be checked using + /// `visit_mut_enforce`, rather than `visit_mut_scalar_expr`. 
We do this by setting a flag in + /// the state that is checked in `visit_mut_list_comprehension` to enable checks that are + /// specific to constraints. fn visit_mut_enforce_all( &mut self, expr: &mut ListComprehension, @@ -516,8 +562,7 @@ impl VisitMut for SemanticAnalysis<'_> { self.warn_declaration_shadowed(expr.name.span(), prev.span()); } else { let binding_ty = self.expr_binding_type(&expr.value).unwrap(); - self.locals - .insert(NamespacedIdentifier::Binding(expr.name), binding_ty); + self.locals.insert(NamespacedIdentifier::Binding(expr.name), binding_ty); } // Visit the let body @@ -543,7 +588,8 @@ impl VisitMut for SemanticAnalysis<'_> { // Track the result type of this comprehension expression let mut result_ty = None; - // Add all of the bindings to the local scope, warn on shadowing, error on conflicting bindings + // Add all of the bindings to the local scope, warn on shadowing, error on conflicting + // bindings let mut bound = HashSet::::default(); // Track the successfully typed check bindings for validation let mut binding_tys: Vec<(Identifier, SourceSpan, Option)> = vec![]; @@ -565,26 +611,30 @@ impl VisitMut for SemanticAnalysis<'_> { let iterable = &expr.iterables[i]; let iterable_ty = iterable.ty().unwrap(); - if let Some(expected_ty) = result_ty.replace(iterable_ty) { - if expected_ty != iterable_ty { - self.has_type_errors = true; - // Note: We don't break here but at the end of the module's compilation, as we want to continue to gather as many errors as possible - let _ = self.type_mismatch( - Some(&iterable_ty), - iterable.span(), - &expected_ty, - expr.iterables[0].span(), - expr.span(), - ); - } + if let Some(expected_ty) = result_ty.replace(iterable_ty) + && expected_ty != iterable_ty + { + self.has_type_errors = true; + // Note: We don't break here but at the end of the module's compilation, as we + // want to continue to gather as many errors as possible + let _ = self.type_mismatch( + Some(&iterable_ty), + iterable.span(), + 
&expected_ty, + expr.iterables[0].span(), + expr.span(), + ); } match self.expr_binding_type(iterable) { Ok(iterable_binding_ty) => { let binding_ty = iterable_binding_ty - .access(AccessType::Index(0)) + .access(AccessType::Index(Box::new(ScalarExpr::Const(Span::new( + iterable.span(), + 0, + ))))) .expect("unexpected scalar iterable"); binding_tys.push((binding, iterable.span(), Some(binding_ty))); - } + }, Err(InvalidAccessError::InvalidBinding) => { // We tried to call an evaluator function self.diagnostics @@ -596,7 +646,7 @@ impl VisitMut for SemanticAnalysis<'_> { ) .emit(); return ControlFlow::Break(SemanticAnalysisError::Invalid); - } + }, Err(_) => { // The iterable type is undefined/unresolvable // @@ -609,11 +659,12 @@ impl VisitMut for SemanticAnalysis<'_> { // // For now, we record `None` until all iterables have been visited binding_tys.push((binding, iterable.span(), None)); - } + }, } } - // If we were unable to determine a type for any of the bindings, use a large vector as a placeholder + // If we were unable to determine a type for any of the bindings, use a large vector as a + // placeholder let expected = BindingType::Local(result_ty.unwrap_or(Type::Vector(u32::MAX as usize))); // Bind everything now, resolving any deferred types using our fallback expected type @@ -668,12 +719,24 @@ impl VisitMut for SemanticAnalysis<'_> { FunctionType::Evaluator(_) => DependencyType::Evaluator, _ => DependencyType::Function, }; + + // If we're currently analyzing a function, add a direct dependency edge + // from the current function to the called function. This ensures that + // transitive dependencies across module boundaries are properly tracked. 
+ if let Some(caller) = self.current_function.clone() { + let caller_node = self.get_node_index_or_add(&caller); + let callee_node = self.get_node_index_or_add(&qid); + self.deps_graph.add_edge(caller_node, callee_node, dependency_type); + } + let prev = self.referenced.insert(qid, dependency_type); if prev.is_some() { assert_eq!(prev, Some(dependency_type)); } - // TODO: When we have non-evaluator functions, we must fetch the type in its signature here, - // and store it as the type of the Call expression + + // TODO: When we have non-evaluator functions, we must fetch the type in its + // signature here, and store it as the type of the + // Call expression expr.ty = fty.result(); } } else { @@ -689,11 +752,11 @@ impl VisitMut for SemanticAnalysis<'_> { .emit(); return ControlFlow::Break(SemanticAnalysisError::Invalid); } - } + }, Err(_) => { // We've already raised a diagnostic for this when visiting the access expression assert!(self.has_undefined_variables || self.has_type_errors); - } + }, } // Visit the call arguments @@ -701,7 +764,8 @@ impl VisitMut for SemanticAnalysis<'_> { self.visit_mut_expr(expr)?; } - // Validate arguments for builtin functions, which currently consist only of the sum/prod reducers + // Validate arguments for builtin functions, which currently consist only of the sum/prod + // reducers if expr.is_builtin() { self.validate_call_to_builtin(expr)?; } @@ -710,11 +774,11 @@ impl VisitMut for SemanticAnalysis<'_> { // // * Must be trace bindings or aliases of same // * Must match the type signature of the callee - if let Ok(ty) = callee_binding_ty { - if let BindingType::Function(FunctionType::Evaluator(ref params)) = ty.item { - for (arg, param) in expr.args.iter().zip(params.iter()) { - self.validate_evaluator_argument(expr.span(), arg, param)?; - } + if let Ok(ty) = callee_binding_ty + && let BindingType::Function(FunctionType::Evaluator(ref params)) = ty.item + { + for (arg, param) in expr.args.iter().zip(params.iter()) { + 
self.validate_evaluator_argument(expr.span(), arg, param)?; } } @@ -733,7 +797,8 @@ impl VisitMut for SemanticAnalysis<'_> { (Ok(Some(lty)), Ok(Some(rty))) => { if lty != rty { self.has_type_errors = true; - // Note: We don't break here but at the end of the module's compilation, as we want to continue to gather as many errors as possible + // Note: We don't break here but at the end of the module's compilation, as we + // want to continue to gather as many errors as possible let _ = self.type_mismatch( Some(<y), expr.lhs.span(), @@ -741,9 +806,19 @@ impl VisitMut for SemanticAnalysis<'_> { expr.rhs.span(), expr.span(), ); + } else if lty != Type::Felt { + self.has_type_errors = true; + self.diagnostics + .diagnostic(Severity::Error) + .with_message("unexpected type") + .with_primary_label( + expr.span(), + "binary operations are only allowed on scalar values", + ) + .emit(); } ControlFlow::Continue(()) - } + }, _ => ControlFlow::Continue(()), } } @@ -780,10 +855,10 @@ impl VisitMut for SemanticAnalysis<'_> { *expr = crate::ast::RangeBound::Const(Span::new(expr.span(), value)); ControlFlow::Continue(()) - } + }, Err(err) => ControlFlow::Break(err), } - } + }, const_expr => { self.diagnostics .diagnostic(Severity::Error) @@ -796,9 +871,9 @@ impl VisitMut for SemanticAnalysis<'_> { ) .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) - } + }, } - } + }, } } @@ -817,13 +892,13 @@ impl VisitMut for SemanticAnalysis<'_> { .with_primary_label(expr.span, format!("invalid constant identifier: {err}")) .emit(); return ControlFlow::Break(SemanticAnalysisError::Invalid); - } + }, }; match binding_ty.item { BindingType::Constant(ty) => { expr.ty = Some(ty); ControlFlow::Continue(()) - } + }, binding_ty => { self.diagnostics .diagnostic(Severity::Error) @@ -836,7 +911,7 @@ impl VisitMut for SemanticAnalysis<'_> { ) .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) - } + }, } } @@ -890,13 +965,13 @@ impl VisitMut for SemanticAnalysis<'_> { .with_note("It is 
not allowed to access trace columns with an offset in boundary constraints.") .emit(); } - } + }, ty @ BindingType::PeriodicColumn(_) if self.constraint_mode.is_boundary() => { self.invalid_access_in_constraint(expr.span(), ty); - } + }, ty @ BindingType::PublicInput(_) if self.constraint_mode.is_integrity() => { self.invalid_access_in_constraint(expr.span(), ty); - } + }, _ => (), } @@ -907,14 +982,14 @@ impl VisitMut for SemanticAnalysis<'_> { BindingType::PeriodicColumn(_) => Some(DependencyType::PeriodicColumn), BindingType::Function(_) => { panic!("unexpected function binding in symbol access context") - } + }, _ => None, }; // Update the dependency graph if let Some(dep_type) = dep_type { // If the item is already in the referenced set, it should have the same type - let prev = self.referenced.insert(*qid, dep_type); + let prev = self.referenced.insert(qid.clone(), dep_type); if prev.is_some() { assert_eq!(prev, Some(dep_type)); } @@ -927,18 +1002,19 @@ impl VisitMut for SemanticAnalysis<'_> { match resolved_binding_ty.access(expr.access_type.clone()) { Ok(binding_ty) => { match expr.access_type { - // The only way to distinguish trace bindings of size 1 that are single columns vs vectors - // with a single column is dependent on the access type. A slice of columns of size 1 must + // The only way to distinguish trace bindings of size 1 that are single columns + // vs vectors with a single column is dependent on the + // access type. 
A slice of columns of size 1 must // be captured as a vector of size 1 AccessType::Slice(ref range) => { let range = range.to_slice_range(); assert_eq!(expr.ty.replace(Type::Vector(range.len())), None) - } + }, // All other access types can be derived from the binding type _ => assert_eq!(expr.ty.replace(binding_ty.ty().unwrap()), None), } ControlFlow::Continue(()) - } + }, Err(err) => { self.has_type_errors = true; self.diagnostics @@ -952,12 +1028,12 @@ impl VisitMut for SemanticAnalysis<'_> { AccessType::Slice(range) => { let range = range.to_slice_range(); Type::Vector(range.len()) - } + }, _ => Type::Felt, }; assert_eq!(expr.ty.replace(ty), None); ControlFlow::Continue(()) - } + }, } } @@ -965,13 +1041,14 @@ impl VisitMut for SemanticAnalysis<'_> { &mut self, expr: &mut ResolvableIdentifier, ) -> ControlFlow { - let current_module = self.current_module.unwrap(); + let current_module = self.current_module.clone().unwrap(); match expr { // If already resolved, and referencing a local variable, there is nothing to do ResolvableIdentifier::Local(_) => ControlFlow::Continue(()), // If already resolved, and referencing a global declaration, there is nothing to do ResolvableIdentifier::Global(_) => ControlFlow::Continue(()), - // If already resolved, and not referencing the current module or the root module, add it to the referenced set + // If already resolved, and not referencing the current module or the root module, add + // it to the referenced set ResolvableIdentifier::Resolved(id) => { // Ignore references to functions in the builtin module if id.is_builtin() { @@ -979,14 +1056,15 @@ impl VisitMut for SemanticAnalysis<'_> { } ControlFlow::Continue(()) - } + }, ResolvableIdentifier::Unresolved(namespaced_id) => { // If locally defined, resolve it to the current module let namespaced_id = *namespaced_id; if let Some(binding_ty) = self.locals.get(&namespaced_id) { match binding_ty { - // This identifier is a local variable, alias to a declaration, or a function 
parameter + // This identifier is a local variable, alias to a declaration, or a + // function parameter BindingType::Alias(_) | BindingType::Local(_) | BindingType::Vector(_) @@ -994,17 +1072,24 @@ impl VisitMut for SemanticAnalysis<'_> { | BindingType::TraceColumn(_) | BindingType::TraceParam(_) => { *expr = ResolvableIdentifier::Local(namespaced_id.id()); - } + }, // These binding types are module-local declarations BindingType::Constant(_) | BindingType::Function(_) - | BindingType::PeriodicColumn(_) - | BindingType::Bus(_) => { + | BindingType::PeriodicColumn(_) => { *expr = ResolvableIdentifier::Resolved(QualifiedIdentifier::new( current_module, namespaced_id, )); - } + }, + BindingType::Bus(_) => { + // We use the program name to resolve the bus, as it is a globally + // defined item in the root module + *expr = ResolvableIdentifier::Resolved(QualifiedIdentifier::new( + ModuleId::new(vec![self.program.name], self.program.name.span()), + namespaced_id, + )); + }, } return ControlFlow::Continue(()); } @@ -1020,7 +1105,8 @@ impl VisitMut for SemanticAnalysis<'_> { if let Some((imported_id, imported_from)) = self.imported.get_key_value(&namespaced_id) { - let qualified_id = QualifiedIdentifier::new(*imported_from, *imported_id); + let qualified_id = + QualifiedIdentifier::new(imported_from.clone(), *imported_id); *expr = ResolvableIdentifier::Resolved(qualified_id); return ControlFlow::Continue(()); @@ -1038,7 +1124,7 @@ impl VisitMut for SemanticAnalysis<'_> { "no function by this name is declared in scope", ) .emit(); - } + }, NamespacedIdentifier::Binding(_) => { self.diagnostics .diagnostic(Severity::Error) @@ -1048,20 +1134,22 @@ impl VisitMut for SemanticAnalysis<'_> { "this variable / bus is not defined", ) .emit(); - } + }, } ControlFlow::Continue(()) - } + }, } } } impl SemanticAnalysis<'_> { - /// Validate arguments for builtin functions, which currently consist only of the sum/prod reducers + /// Validate arguments for builtin functions, which 
currently consist only of the sum/prod + /// reducers fn validate_call_to_builtin(&mut self, call: &Call) -> ControlFlow { match call.callee.as_ref().name() { - // The known reducers - each takes a single argument, which must be an aggregate or comprehension + // The known reducers - each takes a single argument, which must be an aggregate or + // comprehension symbols::Sum | symbols::Prod => { match call.args.as_slice() { [arg] => { @@ -1082,13 +1170,14 @@ impl SemanticAnalysis<'_> { ) .emit(); } - } + }, Err(_) => { - // We've already raised a diagnostic for this when visiting the access expression + // We've already raised a diagnostic for this when visiting the + // access expression assert!(self.has_undefined_variables || self.has_type_errors); - } + }, } - } + }, _ => { self.has_type_errors = true; self.diagnostics @@ -1102,9 +1191,9 @@ impl SemanticAnalysis<'_> { ), ) .emit(); - } + }, } - } + }, other => unimplemented!("unrecognized builtin function: {}", other), } ControlFlow::Continue(()) @@ -1121,7 +1210,7 @@ impl SemanticAnalysis<'_> { Expr::SymbolAccess(access) => { match self.access_binding_type(access) { Ok(BindingType::TraceColumn(tr) | BindingType::TraceParam(tr)) => { - if tr.size == param.size { + if tr.tb_size() == param.size { // Success, the argument and parameter types match up, but // we must make sure the segments also match let same_segment = tr.segment == param.id; @@ -1130,57 +1219,57 @@ impl SemanticAnalysis<'_> { let segment_name = segment_id_to_name(tr.segment); self.has_type_errors = true; self.diagnostics - .diagnostic(Severity::Error) - .with_message("invalid evaluator function argument") - .with_primary_label( - arg.span(), - format!( - "callee expects columns from the {expected_segment} trace"), - ) - .with_secondary_label( - tr.span, - format!( - "but this column is from the {segment_name} trace"), - ) - .emit(); + .diagnostic(Severity::Error) + .with_message("invalid evaluator function argument") + .with_primary_label( + 
arg.span(), + format!( + "callee expects columns from the {expected_segment} trace"), + ) + .with_secondary_label( + tr.span, + format!( + "but this column is from the {segment_name} trace"), + ) + .emit(); } } else { self.has_type_errors = true; self.diagnostics.diagnostic(Severity::Error) - .with_message("invalid call") - .with_primary_label(span, "type mismatch in function argument") - .with_secondary_label(arg.span(), format!("callee expects {} trace columns here, but this binding provides {}", param.size, tr.size)) - .emit(); + .with_message("invalid call") + .with_primary_label(span, "type mismatch in function argument") + .with_secondary_label(arg.span(), format!("callee expects {} trace columns here, but this binding provides {}", param.size, tr.tb_size())) + .emit(); } - } + }, Ok(BindingType::Vector(ref elems)) => { let mut size = 0; for elem in elems.iter() { match elem { BindingType::TraceColumn(tr) | BindingType::TraceParam(tr) => { if tr.segment == param.id { - size += tr.size; + size += tr.tb_size(); } else { let expected_segment = segment_id_to_name(param.id); let segment_name = segment_id_to_name(tr.segment); self.has_type_errors = true; self.diagnostics - .diagnostic(Severity::Error) - .with_message("invalid evaluator function argument") - .with_primary_label( - arg.span(), - format!( - "callee expects columns from the {expected_segment} trace"), - ) - .with_secondary_label( - tr.span, - format!( - "but this column is from the {segment_name} trace"), - ) - .emit(); + .diagnostic(Severity::Error) + .with_message("invalid evaluator function argument") + .with_primary_label( + arg.span(), + format!( + "callee expects columns from the {expected_segment} trace"), + ) + .with_secondary_label( + tr.span, + format!( + "but this column is from the {segment_name} trace"), + ) + .emit(); return ControlFlow::Continue(()); } - } + }, invalid => { self.has_type_errors = true; self.diagnostics @@ -1196,13 +1285,14 @@ impl SemanticAnalysis<'_> { ) .emit(); return 
ControlFlow::Continue(()); - } + }, } } if size != param.size { self.has_type_errors = true; - // Note: We don't break here but at the end of the module's compilation, as we want to continue to gather as many errors as possible + // Note: We don't break here but at the end of the module's compilation, + // as we want to continue to gather as many errors as possible let _ = self.type_mismatch( Some(&Type::Vector(param.size)), arg.span(), @@ -1211,7 +1301,7 @@ impl SemanticAnalysis<'_> { span, ); } - } + }, Ok(binding_ty) => { self.has_type_errors = true; let expected = BindingType::TraceParam(TraceBinding::new( @@ -1222,7 +1312,8 @@ impl SemanticAnalysis<'_> { param.size, Type::Vector(param.size), )); - // Note: We don't break here but at the end of the module's compilation, as we want to continue to gather as many errors as possible + // Note: We don't break here but at the end of the module's compilation, as + // we want to continue to gather as many errors as possible let _ = self.binding_mismatch( &binding_ty, arg.span(), @@ -1230,23 +1321,25 @@ impl SemanticAnalysis<'_> { param.span(), span, ); - } + }, Err(_) => { - // We've already raised a diagnostic for this when visiting the access expression + // We've already raised a diagnostic for this when visiting the access + // expression assert!(self.has_undefined_variables || self.has_type_errors); - } + }, } - } + }, Expr::Vector(elems) => { - // We need to make sure that the number of columns represented by the vector corresponds to those - // expected by the callee, which requires us to first check each element of the vector, and then - // at the end determine if the sizes line up + // We need to make sure that the number of columns represented by the vector + // corresponds to those expected by the callee, which requires us to + // first check each element of the vector, and then at the end + // determine if the sizes line up let mut size = 0; for elem in elems.iter() { match self.expr_binding_type(elem) { 
Ok(BindingType::TraceColumn(tr) | BindingType::TraceParam(tr)) => { if tr.segment == param.id { - size += tr.size; + size += tr.tb_size(); } else { let expected_segment = segment_id_to_name(param.id); let segment_name = segment_id_to_name(tr.segment); @@ -1267,7 +1360,7 @@ impl SemanticAnalysis<'_> { .emit(); return ControlFlow::Continue(()); } - } + }, Ok(invalid) => { self.has_type_errors = true; self.diagnostics @@ -1285,11 +1378,12 @@ impl SemanticAnalysis<'_> { ), ) .emit(); - } + }, Err(_) => { - // We've already raised a diagnostic for this when visiting the access expression + // We've already raised a diagnostic for this when visiting the access + // expression assert!(self.has_undefined_variables || self.has_type_errors); - } + }, } } if size != param.size { @@ -1300,7 +1394,7 @@ impl SemanticAnalysis<'_> { .with_secondary_label(arg.span(), format!("callee expects {} trace columns here, but this argument only provides {size}", param.size)) .emit(); } - } + }, wrong => { self.has_type_errors = true; self.diagnostics.diagnostic(Severity::Error) @@ -1308,7 +1402,7 @@ impl SemanticAnalysis<'_> { .with_primary_label(span, "invalid argument for evaluator function") .with_secondary_label(arg.span(), format!("expected a trace binding, or vector of trace bindings here, but got a {wrong}")) .emit(); - } + }, } ControlFlow::Continue(()) @@ -1328,16 +1422,21 @@ impl SemanticAnalysis<'_> { // Visit the expression operands self.visit_mut_symbol_access(&mut access.column)?; - // Ensure the referenced symbol was a trace column, and that it produces a scalar value, or a bus - let (found, _segment) = - match self.resolvable_binding_type(&access.column.name) { - Ok(ty) => match ty.item.access(access.column.access_type.clone()) { + // Ensure the referenced symbol was a trace column, and that it produces a + // scalar value, or a bus + let (found, _segment) = match self + .resolvable_binding_type(&access.column.name) + { + Ok(ty) => { + let accessed_ty = 
ty.item.access(access.column.access_type.clone()); + match accessed_ty.clone() { Ok(BindingType::TraceColumn(tb)) | Ok(BindingType::TraceParam(tb)) => { - if tb.is_scalar() { - (ty, tb.segment) + let tb_type = tb.ty(); + if tb_type.is_scalar() { + (Span::new(ty.span(), accessed_ty.unwrap()), tb.segment) } else { - let inferred = tb.ty(); + let inferred = tb_type; return self.type_mismatch( Some(&inferred), access.span(), @@ -1346,16 +1445,16 @@ impl SemanticAnalysis<'_> { constraint_span, ); } - } + }, Ok(BindingType::Bus(_)) => { // Buses are valid in boundary constraints - (ty, 0) - } + (ty, TraceSegmentId::Main) + }, Ok(aty) => { let expected = BindingType::TraceColumn(TraceBinding::new( constraint_span, Identifier::new(constraint_span, symbols::Main), - 0, + TraceSegmentId::Main, 0, 1, Type::Felt, @@ -1367,21 +1466,23 @@ impl SemanticAnalysis<'_> { ty.span(), constraint_span, ); - } + }, _ => return ControlFlow::Break(SemanticAnalysisError::Invalid), - }, - Err(_) => { - // We've already raised a diagnostic for the undefined variable - return ControlFlow::Break(SemanticAnalysisError::Invalid); } - }; + }, + Err(_) => { + // We've already raised a diagnostic for the undefined variable + return ControlFlow::Break(SemanticAnalysisError::Invalid); + }, + }; match (found.clone().item, expr.rhs.as_mut()) { - // Buses boundaries can be constrained by null or set to be unconstrained + // Buses boundaries can be constrained by null or set to be + // unconstrained ( BindingType::Bus(_), ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_), - ) => {} + ) => {}, (BindingType::Bus(_), ScalarExpr::SymbolAccess(access)) => { self.visit_mut_resolvable_identifier(&mut access.name)?; self.visit_mut_access_type(&mut access.access_type)?; @@ -1389,16 +1490,19 @@ impl SemanticAnalysis<'_> { let resolved_binding_ty = match self.resolvable_binding_type(&access.name) { Ok(ty) => ty, - // An unresolved identifier at this point means that it is undefined, + // An unresolved identifier at 
this point means that it is + // undefined, // but we've already raised a diagnostic // - // There is nothing useful we can do here other than continue traversing the module - // gathering as many undefined variable usages as possible before bailing + // There is nothing useful we can do here other than + // continue traversing the module + // gathering as many undefined variable usages as possible + // before bailing Err(_) => return ControlFlow::Continue(()), }; match resolved_binding_ty.item { - BindingType::PublicInput(_) => {} + BindingType::PublicInput(_) => {}, _ => { self.has_type_errors = true; self.invalid_constraint( @@ -1410,9 +1514,9 @@ impl SemanticAnalysis<'_> { "this is not a reference to a public input", ) .emit(); - } + }, } - } + }, (BindingType::Bus(_), _) => { // Buses cannot be constrained otherwise self.has_type_errors = true; @@ -1425,7 +1529,7 @@ impl SemanticAnalysis<'_> { "Only the null value is valid for constraining buses", ) .emit(); - } + }, (_, ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_)) => { // Only buses can be constrained to null or set to be unconstrained self.has_type_errors = true; @@ -1439,30 +1543,33 @@ impl SemanticAnalysis<'_> { ) .with_note("The null / unconstrained keywords are only valid for constraining buses, not columns") .emit(); - } + }, _ => { // Validate that the symbol access produces a scalar value // - // If no type is known, a diagnostic is already emitted, so proceed as if it is valid - if let Some(ty) = access.column.ty.as_ref() { - if !ty.is_scalar() { - // Invalid constraint, only scalar values are allowed - self.type_mismatch( - Some(ty), - access.span(), - &Type::Felt, - found.span(), - constraint_span, - )?; - } + // If no type is known, a diagnostic is already emitted, so proceed + // as if it is valid + if let Some(ty) = access.column.ty.as_ref() + && !ty.is_scalar() + { + // Invalid constraint, only scalar values are allowed + self.type_mismatch( + Some(ty), + access.span(), + &Type::Felt, + 
found.span(), + constraint_span, + )?; } // Verify that the right-hand expression evaluates to a scalar // - // The only way this is not the case, is if it is a a symbol access which produces an aggregate + // The only way this is not the case, is if it is a a symbol access + // which produces an aggregate self.visit_mut_scalar_expr(expr.rhs.as_mut())?; if let ScalarExpr::SymbolAccess(access) = expr.rhs.as_ref() { - // Ensure this access produces a scalar, or if the type is unknown, assume it is valid + // Ensure this access produces a scalar, or if the type is + // unknown, assume it is valid // because a diagnostic will have already been emitted if !access.ty.as_ref().map(|t| t.is_scalar()).unwrap_or(true) { self.type_mismatch( @@ -1474,25 +1581,25 @@ impl SemanticAnalysis<'_> { )?; } } - } + }, } ControlFlow::Continue(()) - } + }, other => { self.invalid_constraint(other.span(), "expected this to be a reference to a trace column or bus boundary, e.g. `a.first`") .with_note("The given constraint is not a boundary constraint, and only boundary constraints are valid here.") .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) - } + }, } - } + }, ScalarExpr::BusOperation(expr) => { self.invalid_constraint(expr.span(), "expected an equality expression here") .with_note("Bus operations are only permitted in integrity constraints") .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) - } + }, ScalarExpr::Call(expr) => { self.invalid_constraint(expr.span(), "expected an equality expression here") .with_note( @@ -1500,7 +1607,7 @@ impl SemanticAnalysis<'_> { ) .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) - } + }, expr => { self.invalid_constraint(expr.span(), "expected an equality expression here") .with_note( @@ -1508,7 +1615,7 @@ impl SemanticAnalysis<'_> { ) .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) - } + }, } } @@ -1522,7 +1629,8 @@ impl SemanticAnalysis<'_> { // However, we do need to validate two things: // // 1. 
That the constraint produces a scalar value - // 2. That the expression is either an equality, a call to an evaluator function, or a bus operation + // 2. That the expression is either an equality, a call to an evaluator function, or a bus + // operation // match expr { ScalarExpr::Binary(expr) if expr.op == BinaryOp::Eq => self.visit_mut_binary_expr(expr), @@ -1533,7 +1641,7 @@ impl SemanticAnalysis<'_> { // Check that the call references an evaluator // // If unresolved, we've already raised a diagnostic for the invalid call - match expr.callee { + match &expr.callee { ResolvableIdentifier::Resolved(callee) => { match callee.id() { id @ NamespacedIdentifier::Function(_) => { @@ -1577,7 +1685,7 @@ impl SemanticAnalysis<'_> { } ResolvableIdentifier::Unresolved(_) => ControlFlow::Continue(()), } - } + }, ScalarExpr::BusOperation(expr) => { // Visit the call normally, so we can resolve the callee identifier self.visit_mut_bus_operation(expr)?; @@ -1585,7 +1693,7 @@ impl SemanticAnalysis<'_> { // Check that the call references an evaluator // // If unresolved, we've already raised a diagnostic for the invalid call - match expr.bus { + match &expr.bus { ResolvableIdentifier::Resolved(bus) => { match bus.id() { id @ NamespacedIdentifier::Binding(_) => { @@ -1623,13 +1731,13 @@ impl SemanticAnalysis<'_> { } ResolvableIdentifier::Unresolved(_) => ControlFlow::Continue(()), } - } + }, expr => { self.invalid_constraint(expr.span(), "expected either an equality expression, a call to an evaluator, or a bus operation here") .with_note("Integrity constraints must be expressed as an equality, e.g. `a = 0`, a call, e.g. `evaluator(a)`, or a bus operation, e.g. 
`p.insert(a) when 1`") .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) - } + }, } } @@ -1734,18 +1842,16 @@ impl SemanticAnalysis<'_> { .diagnostic(Severity::Error) .with_message("invalid access") .with_primary_label(span, format!("cannot access {ty} here")) - .with_note(format!( - "It is not allowed to access {ty} in {mode} constraints." - )) + .with_note(format!("It is not allowed to access {ty} in {mode} constraints.")) .emit(); } fn expr_binding_type(&self, expr: &Expr) -> Result { match expr { Expr::Const(constant) => Ok(BindingType::Local(constant.ty())), - Expr::Range(range) => Ok(BindingType::Local(Type::Vector( - range.to_slice_range().len(), - ))), + Expr::Range(range) => { + Ok(BindingType::Local(Type::Vector(range.to_slice_range().len()))) + }, Expr::Vector(elems) => { let mut binding_tys = Vec::with_capacity(elems.len()); for elem in elems.iter() { @@ -1753,12 +1859,12 @@ impl SemanticAnalysis<'_> { } Ok(BindingType::Vector(binding_tys)) - } + }, Expr::Matrix(expr) => { let rows = expr.len(); let columns = expr[0].len(); Ok(BindingType::Local(Type::Matrix(rows, columns))) - } + }, Expr::SymbolAccess(expr) => self.access_binding_type(expr), Expr::Call(Call { ty: None, .. }) => Err(InvalidAccessError::InvalidBinding), Expr::Call(Call { ty: Some(ty), .. }) => Ok(BindingType::Local(*ty)), @@ -1771,9 +1877,9 @@ impl SemanticAnalysis<'_> { // the comprehension is given by the type of the iterables. 
We // just pick the first iterable to tell us the type self.expr_binding_type(&lc.iterables[0]) - } + }, } - } + }, Expr::Let(expr) => { self.diagnostics .diagnostic(Severity::Bug) @@ -1781,7 +1887,7 @@ impl SemanticAnalysis<'_> { .with_primary_label(expr.span(), "let expressions are not valid here") .emit(); Err(InvalidAccessError::InvalidBinding) - } + }, Expr::BusOperation(_expr) => Ok(BindingType::Local(Type::Felt)), Expr::Null(_) | Expr::Unconstrained(_) => Ok(BindingType::Local(Type::Felt)), } @@ -1804,14 +1910,14 @@ impl SemanticAnalysis<'_> { .get_key_value(&namespaced_id) .map(|(nid, ty)| Span::new(nid.span(), ty.clone())) .ok_or(InvalidAccessError::UndefinedVariable) - } + }, ResolvableIdentifier::Global(id) => { // The item is a declaration in the root module self.globals .get_key_value(id) .map(|(nid, ty)| Span::new(nid.span(), ty.clone())) .ok_or(InvalidAccessError::UndefinedVariable) - } + }, ResolvableIdentifier::Resolved(qid) => self.resolved_binding_type(qid), ResolvableIdentifier::Unresolved(_) => Err(InvalidAccessError::UndefinedVariable), } @@ -1821,7 +1927,7 @@ impl SemanticAnalysis<'_> { &self, qid: &QualifiedIdentifier, ) -> Result, InvalidAccessError> { - if qid.module == self.program.name { + if qid.module.0.item == vec![self.program.name] { // This is the root module, so the value will be in either locals or globals self.locals .get_key_value(&qid.item) @@ -1832,7 +1938,7 @@ impl SemanticAnalysis<'_> { .map(|(k, v)| Span::new(k.span(), v.clone())) }) .ok_or(InvalidAccessError::UndefinedVariable) - } else if qid.module == self.current_module.unwrap() { + } else if qid.module == self.current_module.clone().unwrap() { // This is a reference to a module-local declaration self.locals .get_key_value(&qid.item) @@ -1843,14 +1949,15 @@ impl SemanticAnalysis<'_> { // so we hardcode the type information here match qid.name() { symbols::Sum | symbols::Prod => { - // NOTE: We're using `usize::MAX` elements to indicate a vector of any size, but we - // 
should probably add this to the Type enum and handle it elsewhere. For the time - // being, functions are not implemented, so the only place this comes up is with these + // NOTE: We're using `usize::MAX` elements to indicate a vector of any size, but + // we should probably add this to the Type enum and handle + // it elsewhere. For the time being, functions are not + // implemented, so the only place this comes up is with these // list folding builtins let folder_ty = FunctionType::Function(vec![Type::Vector(usize::MAX)], Type::Felt); Ok(Span::new(qid.span(), BindingType::Function(folder_ty))) - } + }, name => unimplemented!("unsupported builtin: {}", name), } } else { @@ -1870,14 +1977,35 @@ impl SemanticAnalysis<'_> { ) }) }) + .or_else(|| { + imported_from.functions.get(qid.as_ref()).map(|f| { + Span::new( + f.span(), + BindingType::Function(FunctionType::Function( + f.param_types(), + f.return_type, + )), + ) + }) + }) .ok_or(InvalidAccessError::UndefinedVariable) } } + + /// Adds the given module to the module graph if it does not already exist, + /// returning the corresponding node index. + fn get_node_index_or_add(&mut self, qid: &QualifiedIdentifier) -> petgraph::graph::NodeIndex { + self.deps_nodes.get(qid).cloned().unwrap_or_else(|| { + let index = self.deps_graph.add_node(qid.clone()); + self.deps_nodes.insert(qid.clone(), index); + index + }) + } } fn segment_id_to_name(id: TraceSegmentId) -> Symbol { match id { - 0 => symbols::Main, + TraceSegmentId::Main => symbols::Main, _ => unimplemented!(), } } diff --git a/parser/src/symbols.rs b/parser/src/symbols.rs index 9ab10cc46..0b5022fc0 100644 --- a/parser/src/symbols.rs +++ b/parser/src/symbols.rs @@ -1,10 +1,5 @@ -use core::fmt; -use core::mem; -use core::ops::Deref; -use core::str; - -use std::collections::BTreeMap; -use std::sync::RwLock; +use core::{fmt, mem, ops::Deref, str}; +use std::{collections::BTreeMap, sync::RwLock}; lazy_static::lazy_static! 
{ static ref SYMBOL_TABLE: SymbolTable = SymbolTable::new(); @@ -23,12 +18,8 @@ pub mod predefined { /// The symbol `prod` pub const Prod: Symbol = Symbol::new(3); - pub(super) const __SYMBOLS: &[(Symbol, &str)] = &[ - (Main, "$main"), - (Builtin, "$builtin"), - (Sum, "sum"), - (Prod, "prod"), - ]; + pub(super) const __SYMBOLS: &[(Symbol, &str)] = + &[(Main, "$main"), (Builtin, "$builtin"), (Sum, "sum"), (Prod, "prod")]; } pub use self::predefined::*; @@ -38,9 +29,7 @@ struct SymbolTable { } impl SymbolTable { pub fn new() -> Self { - Self { - interner: RwLock::new(Interner::new()), - } + Self { interner: RwLock::new(Interner::new()) } } } unsafe impl Sync for SymbolTable {} diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index 54e3882cc..bcd425ba1 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -44,7 +44,7 @@ impl Pass for ConstantPropagation<'_> { ControlFlow::Break(err) => { self.diagnostics.emit(err.clone()); Err(err) - } + }, } } } @@ -65,7 +65,7 @@ impl<'a> ConstantPropagation<'a> { for (name, constant) in program.constants.iter() { assert_eq!( self.global - .insert(*name, Span::new(constant.span(), constant.value.clone())), + .insert(name.clone(), Span::new(constant.span(), constant.value.clone())), None ); } @@ -107,16 +107,16 @@ impl<'a> ConstantPropagation<'a> { /// When folding a `let`, one of the following can occur: /// - /// * The let-bound variable is non-constant, so the entire let must remain, but we - /// can constant-propagate as much of the bound expression and body as possible. - /// * The let-bound variable is constant, so once we have constant propagated the body, - /// the let is no longer needed, and one of the following happens: - /// * The `let` terminates with a constant expression, so the entire `let` is replaced - /// with that expression. 
- /// * The `let` terminates with a non-constant expression, or a constraint, so we inline - /// the let body into the containing block. In the non-constant expression case, we - /// replace the `let` with the last expression in the returned block, since in expression - /// position, we may not have a statement block to inline into. + /// * The let-bound variable is non-constant, so the entire let must remain, but we can + /// constant-propagate as much of the bound expression and body as possible. + /// * The let-bound variable is constant, so once we have constant propagated the body, the let + /// is no longer needed, and one of the following happens: + /// * The `let` terminates with a constant expression, so the entire `let` is replaced with + /// that expression. + /// * The `let` terminates with a non-constant expression, or a constraint, so we inline the + /// let body into the containing block. In the non-constant expression case, we replace the + /// `let` with the last expression in the returned block, since in expression position, we + /// may not have a statement block to inline into. 
fn try_fold_let_expr( &mut self, expr: &mut Let, @@ -135,14 +135,13 @@ impl<'a> ConstantPropagation<'a> { match expr.value { Expr::Const(ref value) => { self.local.insert(expr.name, value.clone()); - } + }, Expr::Range(ref range) => { let span = range.span(); let range = range.to_slice_range(); let vector = range.map(|i| i as u64).collect(); - self.local - .insert(expr.name, Span::new(span, ConstantExpr::Vector(vector))); - } + self.local.insert(expr.name, Span::new(span, ConstantExpr::Vector(vector))); + }, _ => unreachable!(), } } @@ -160,7 +159,7 @@ impl<'a> ConstantPropagation<'a> { match expr.body.last().unwrap() { Statement::Expr(Expr::Const(const_value)) => { Left(Some(Span::new(expr.span(), const_value.item.clone()))) - } + }, _ => Right(core::mem::take(&mut expr.body)), } } else { @@ -188,18 +187,21 @@ impl VisitMut for ConstantPropagation<'_> { // Expression is already folded ScalarExpr::Const(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => { ControlFlow::Continue(()) - } - // Need to check if this access is to a constant value, and transform to a constant if so + }, + // Need to check if this access is to a constant value, and transform to a constant if + // so ScalarExpr::SymbolAccess(sym) => { + self.visit_mut_access_type(&mut sym.access_type)?; + let constant_value = match sym.name { // Possibly a reference to a constant declaration ResolvableIdentifier::Resolved(ref qid) => { self.global.get(qid).cloned().map(|s| (s.span(), s.item)) - } + }, // Possibly a reference to a local bound to a constant ResolvableIdentifier::Local(ref id) => { self.local.get(id).cloned().map(|s| (s.span(), s.item)) - } + }, // Other identifiers cannot possibly be constant _ => None, }; @@ -208,26 +210,58 @@ impl VisitMut for ConstantPropagation<'_> { ConstantExpr::Scalar(value) => { assert_eq!(sym.access_type, AccessType::Default); *expr = ScalarExpr::Const(Span::new(span, value)); - } - ConstantExpr::Vector(value) => match sym.access_type { - 
AccessType::Index(idx) => { - *expr = ScalarExpr::Const(Span::new(span, value[idx])); - } + }, + ConstantExpr::Vector(value) => match sym.access_type.clone() { + AccessType::Index(idx) => match *idx { + ScalarExpr::Const(idx) => { + if idx.item >= value.len() as u64 { + self.diagnostics.diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access an index which is out of bounds") + .with_primary_label(span, "index out of bounds") + .emit(); + return ControlFlow::Break(SemanticAnalysisError::Invalid); + } + *expr = ScalarExpr::Const(Span::new( + span, + value[idx.item as usize], + )); + }, + _ => { + self.live.insert(*sym.name.as_ref()); + }, + }, // This access cannot be resolved here, so we need to record the fact // that there are still live uses of this binding _ => { self.live.insert(*sym.name.as_ref()); - } + }, }, - ConstantExpr::Matrix(value) => match sym.access_type { - AccessType::Matrix(row, col) => { - *expr = ScalarExpr::Const(Span::new(span, value[row][col])); - } + ConstantExpr::Matrix(value) => match sym.access_type.clone() { + AccessType::Matrix(row, col) => match (*row, *col) { + (ScalarExpr::Const(row), ScalarExpr::Const(col)) => { + if row.item >= value.len() as u64 + || col.item >= value[row.item as usize].len() as u64 + { + self.diagnostics.diagnostic(miden_diagnostics::Severity::Error) + .with_message("attempted to access an index which is out of bounds") + .with_primary_label(span, "index out of bounds") + .emit(); + return ControlFlow::Break(SemanticAnalysisError::Invalid); + } + *expr = ScalarExpr::Const(Span::new( + span, + value[row.item as usize][col.item as usize], + )); + }, + _ => { + self.live.insert(*sym.name.as_ref()); + }, + }, // This access cannot be resolved here, so we need to record the fact // that there are still live uses of this binding _ => { self.live.insert(*sym.name.as_ref()); - } + }, }, } } else { @@ -235,7 +269,7 @@ impl VisitMut for ConstantPropagation<'_> { 
self.live.insert(*sym.name.as_ref()); } ControlFlow::Continue(()) - } + }, // Fold constant expressions ScalarExpr::Binary(binary_expr) => { match self.try_fold_binary_expr(binary_expr) { @@ -244,7 +278,7 @@ impl VisitMut for ConstantPropagation<'_> { *expr = ScalarExpr::Const(folded); } ControlFlow::Continue(()) - } + }, Err(SemanticAnalysisError::InvalidExpr( InvalidExprError::NonConstantExponent(_), )) if self.in_list_comprehension => { @@ -253,10 +287,10 @@ impl VisitMut for ConstantPropagation<'_> { // The check for non-constant exponents in list comprehensions is done // during lowering from MIR to AIR, so we can safely silence it here. ControlFlow::Continue(()) - } + }, Err(err) => ControlFlow::Break(err), } - } + }, // While calls cannot be constant folded, arguments can be ScalarExpr::Call(call) => self.visit_mut_call(call), // This cannot be constant folded @@ -269,40 +303,40 @@ impl VisitMut for ConstantPropagation<'_> { match const_expr.item { ConstantExpr::Scalar(value) => { *expr = ScalarExpr::Const(Span::new(span, value)); - } + }, _ => { self.diagnostics.diagnostic(miden_diagnostics::Severity::Error) .with_message("invalid scalar expression") .with_primary_label(span, "expected scalar value, but this expression evaluates to an aggregate type") .emit(); return ControlFlow::Break(SemanticAnalysisError::Invalid); - } + }, } - } + }, Ok(Left(None)) => (), Ok(Right(mut block)) => match block.pop().unwrap() { Statement::Let(inner_expr) => { *let_expr.as_mut() = inner_expr; - } + }, Statement::Expr(inner_expr) => { match ScalarExpr::try_from(inner_expr) .map_err(SemanticAnalysisError::InvalidExpr) { Ok(scalar_expr) => { *expr = scalar_expr; - } + }, Err(err) => return ControlFlow::Break(err), } - } + }, Statement::Enforce(_) - | Statement::EnforceIf(_, _) + | Statement::EnforceIf(..) 
| Statement::EnforceAll(_) | Statement::BusEnforce(_) => unreachable!(), }, Err(err) => return ControlFlow::Break(err), } ControlFlow::Continue(()) - } + }, ScalarExpr::BusOperation(expr) => self.visit_mut_bus_operation(expr), } } @@ -316,15 +350,17 @@ impl VisitMut for ConstantPropagation<'_> { // // We deal with symbol accesses directly, as they may evaluate to an aggregate constant Expr::SymbolAccess(access) => { + self.visit_mut_access_type(&mut access.access_type)?; + let constant_value = match access.name { // Possibly a reference to a constant declaration ResolvableIdentifier::Resolved(ref qid) => { self.global.get(qid).cloned().map(|s| (s.span(), s.item)) - } + }, // Possibly a reference to a local bound to a constant ResolvableIdentifier::Local(ref id) => { self.local.get(id).cloned().map(|s| (s.span(), s.item)) - } + }, // Other identifiers cannot possibly be constant _ => None, }; @@ -333,20 +369,27 @@ impl VisitMut for ConstantPropagation<'_> { cexpr @ ConstantExpr::Scalar(_) => { assert_eq!(access.access_type, AccessType::Default); *expr = Expr::Const(Span::new(span, cexpr)); - } + }, ConstantExpr::Vector(value) => match access.access_type.clone() { AccessType::Default => { *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(value))); - } + }, AccessType::Slice(range) => { let range = range.to_slice_range(); let vector = value[range].to_vec(); *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(vector))); - } - AccessType::Index(idx) => { - *expr = - Expr::Const(Span::new(span, ConstantExpr::Scalar(value[idx]))); - } + }, + AccessType::Index(idx) => match *idx { + ScalarExpr::Const(idx) => { + *expr = Expr::Const(Span::new( + span, + ConstantExpr::Scalar(value[idx.item as usize]), + )); + }, + _ => { + self.live.insert(*access.name.as_ref()); + }, + }, ref ty => panic!( "invalid constant reference, expected scalar access, got {ty:?}", ), @@ -354,24 +397,36 @@ impl VisitMut for ConstantPropagation<'_> { ConstantExpr::Matrix(value) => match 
access.access_type.clone() { AccessType::Default => { *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(value))); - } + }, AccessType::Slice(range) => { let range = range.to_slice_range(); let matrix = value[range].to_vec(); *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(matrix))); - } - AccessType::Index(idx) => { - *expr = Expr::Const(Span::new( - span, - ConstantExpr::Vector(value[idx].clone()), - )); - } - AccessType::Matrix(row, col) => { - *expr = Expr::Const(Span::new( - span, - ConstantExpr::Scalar(value[row][col]), - )); - } + }, + AccessType::Index(idx) => match *idx { + ScalarExpr::Const(idx) => { + *expr = Expr::Const(Span::new( + span, + ConstantExpr::Vector(value[idx.item as usize].clone()), + )); + }, + _ => { + self.live.insert(*access.name.as_ref()); + }, + }, + AccessType::Matrix(row, col) => match (*row, *col) { + (ScalarExpr::Const(row), ScalarExpr::Const(col)) => { + *expr = Expr::Const(Span::new( + span, + ConstantExpr::Scalar( + value[row.item as usize][col.item as usize], + ), + )); + }, + _ => { + self.live.insert(*access.name.as_ref()); + }, + }, }, } } else { @@ -379,7 +434,7 @@ impl VisitMut for ConstantPropagation<'_> { self.live.insert(*access.name.as_ref()); } ControlFlow::Continue(()) - } + }, Expr::Call(call) if call.is_builtin() => { self.visit_mut_call(call)?; match call.callee.as_ref().name() { @@ -396,17 +451,17 @@ impl VisitMut for ConstantPropagation<'_> { }; *expr = Expr::Const(Span::new(span, ConstantExpr::Scalar(folded))); - } + }, invalid => { panic!("bad argument to list folding builtin: {invalid:#?}") - } + }, } } - } + }, invalid => unimplemented!("unknown builtin function: {invalid}"), } ControlFlow::Continue(()) - } + }, Expr::Call(call) => self.visit_mut_call(call), Expr::Binary(binary_expr) => match self.try_fold_binary_expr(binary_expr) { Ok(maybe_folded) => { @@ -417,7 +472,7 @@ impl VisitMut for ConstantPropagation<'_> { )); } ControlFlow::Continue(()) - } + }, 
Err(SemanticAnalysisError::InvalidExpr(InvalidExprError::NonConstantExponent( _, ))) if self.in_list_comprehension => { @@ -426,7 +481,7 @@ impl VisitMut for ConstantPropagation<'_> { // The check for non-constant exponents in list comprehensions is done // during lowering from MIR to AIR, so we can safely silence it here. ControlFlow::Continue(()) - } + }, Err(err) => ControlFlow::Break(err), }, // Ranges are constant @@ -455,21 +510,17 @@ impl VisitMut for ConstantPropagation<'_> { vector .iter() .map(|expr| match expr { - Expr::Const(Span { - item: ConstantExpr::Scalar(v), - .. - }) => *v, + Expr::Const(Span { item: ConstantExpr::Scalar(v), .. }) => *v, _ => unreachable!(), }) .collect(), ), - Type::Matrix(_, _) => ConstantExpr::Matrix( + Type::Matrix(..) => ConstantExpr::Matrix( vector .iter() .map(|expr| match expr { Expr::Const(Span { - item: ConstantExpr::Vector(vs), - .. + item: ConstantExpr::Vector(vs), .. }) => vs.clone(), _ => unreachable!(), }) @@ -480,7 +531,7 @@ impl VisitMut for ConstantPropagation<'_> { *expr = Expr::Const(Span::new(span, new_expr)); } ControlFlow::Continue(()) - } + }, // Visit matrix elements, and promote the matrix to `Expr::Const` if possible Expr::Matrix(matrix) => { let mut is_constant = true; @@ -507,7 +558,7 @@ impl VisitMut for ConstantPropagation<'_> { *expr = Expr::Const(Span::new(span, matrix)); } ControlFlow::Continue(()) - } + }, // Visit list comprehensions and convert to constant if possible Expr::ListComprehension(lc) => { let old_in_lc = core::mem::replace(&mut self.in_list_comprehension, true); @@ -534,14 +585,8 @@ impl VisitMut for ConstantPropagation<'_> { // All iterables must be the same length, so determine the number of // steps based on the length of the first iterable let max_len = match &lc.iterables[0] { - Expr::Const(Span { - item: ConstantExpr::Vector(elems), - .. - }) => elems.len(), - Expr::Const(Span { - item: ConstantExpr::Matrix(rows), - .. 
- }) => rows.len(), + Expr::Const(Span { item: ConstantExpr::Vector(elems), .. }) => elems.len(), + Expr::Const(Span { item: ConstantExpr::Matrix(rows), .. }) => rows.len(), Expr::Const(_) => panic!("expected iterable constant, got scalar"), Expr::Range(range) => range.to_slice_range().len(), _ => unreachable!( @@ -557,26 +602,20 @@ impl VisitMut for ConstantPropagation<'_> { { let span = iterable.span(); match iterable { - Expr::Const(Span { - item: ConstantExpr::Vector(elems), - .. - }) => { + Expr::Const(Span { item: ConstantExpr::Vector(elems), .. }) => { let value = ConstantExpr::Scalar(elems[step]); self.local.insert(binding, Span::new(span, value)); - } - Expr::Const(Span { - item: ConstantExpr::Matrix(elems), - .. - }) => { + }, + Expr::Const(Span { item: ConstantExpr::Matrix(elems), .. }) => { let value = ConstantExpr::Vector(elems[step].clone()); self.local.insert(binding, Span::new(span, value)); - } + }, Expr::Range(range) => { let range = range.to_slice_range(); assert!(range.end > range.start + step); let value = ConstantExpr::Scalar((range.start + step) as u64); self.local.insert(binding, Span::new(span, value)); - } + }, _ => unreachable!( "expected iterable constant or range, got {:#?}", iterable @@ -588,16 +627,18 @@ impl VisitMut for ConstantPropagation<'_> { self.visit_mut_scalar_expr(&mut selector)?; match selector { ScalarExpr::Const(selected) => { - // If the selector returns false on this iteration, go to the next step + // If the selector returns false on this iteration, go to the next + // step if *selected == 0 { continue; } - } + }, // The selector cannot be evaluated, bail out early _ => { self.in_list_comprehension = old_in_lc; + self.local.exit(); return ControlFlow::Continue(()); - } + }, } } @@ -610,6 +651,7 @@ impl VisitMut for ConstantPropagation<'_> { folded.push(folded_body.item); } else { self.in_list_comprehension = old_in_lc; + self.local.exit(); return ControlFlow::Continue(()); } } @@ -621,29 +663,29 @@ impl VisitMut for 
ConstantPropagation<'_> { *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(folded))); self.in_list_comprehension = old_in_lc; ControlFlow::Continue(()) - } + }, Expr::Let(let_expr) => { match self.try_fold_let_expr(let_expr) { Ok(Left(Some(const_expr))) => { *expr = Expr::Const(Span::new(span, const_expr.item)); - } + }, Ok(Left(None)) => (), Ok(Right(mut block)) => match block.pop().unwrap() { Statement::Let(inner_expr) => { *let_expr.as_mut() = inner_expr; - } + }, Statement::Expr(inner_expr) => { *expr = inner_expr; - } + }, Statement::Enforce(_) - | Statement::EnforceIf(_, _) + | Statement::EnforceIf(..) | Statement::EnforceAll(_) | Statement::BusEnforce(_) => unreachable!(), }, Err(err) => return ControlFlow::Break(err), } ControlFlow::Continue(()) - } + }, Expr::BusOperation(expr) => self.visit_mut_bus_operation(expr), Expr::Null(_) | Expr::Unconstrained(_) => ControlFlow::Continue(()), } @@ -670,32 +712,40 @@ impl VisitMut for ConstantPropagation<'_> { match self.try_fold_let_expr(expr) { Ok(Left(Some(const_expr))) => { buffer.push(Statement::Expr(Expr::Const(const_expr))); - } + }, Ok(Left(None)) => (), Ok(Right(mut block)) => { buffer.append(&mut block); - } + }, Err(err) => return ControlFlow::Break(err), } - } + }, Statement::Enforce(expr) => { self.visit_mut_enforce(expr)?; - } + }, Statement::EnforceAll(expr) => { self.in_constraint_comprehension = true; self.visit_mut_list_comprehension(expr)?; self.in_constraint_comprehension = false; - } + }, Statement::Expr(expr) => { self.visit_mut_expr(expr)?; - } + }, Statement::BusEnforce(expr) => { self.in_constraint_comprehension = true; self.visit_mut_list_comprehension(expr)?; self.in_constraint_comprehension = false; - } + }, // This statement type is only present in the AST after inlining - Statement::EnforceIf(_, _) => unreachable!(), + Statement::EnforceIf(match_expr) => { + self.in_constraint_comprehension = true; + for match_arm in match_expr.match_arms.iter_mut() { + // Visit the selector and 
expression of the match arm + self.visit_mut_scalar_expr(&mut match_arm.condition)?; + self.visit_mut_scalar_expr(&mut match_arm.expr)?; + } + self.in_constraint_comprehension = false; + }, } // If we have a non-empty buffer, then we are collapsing a let into the current block, diff --git a/parser/src/transforms/inlining.rs b/parser/src/transforms/inlining.rs deleted file mode 100644 index 78b94d547..000000000 --- a/parser/src/transforms/inlining.rs +++ /dev/null @@ -1,1920 +0,0 @@ -use std::{ - collections::{BTreeMap, HashMap, HashSet, VecDeque}, - ops::ControlFlow, - vec, -}; - -use air_pass::Pass; -use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; - -use crate::{ - ast::{visit::VisitMut, *}, - sema::{BindingType, LexicalScope, SemanticAnalysisError}, - symbols, -}; - -use super::constant_propagation; - -/// This pass performs the following transformations on a [Program]: -/// -/// * Monomorphizing and inlining evaluators/functions at their call sites -/// * Unrolling constraint comprehensions into a sequence of scalar constraints -/// * Unrolling list comprehensions into a tree of `let` statements which end in -/// a vector expression (the implicit result of the tree). Each iteration of the -/// unrolled comprehension is reified as a value and bound to a variable so that -/// other transformations may refer to it directly. -/// * Rewriting aliases of top-level declarations to refer to those declarations directly -/// * Removing let-bound variables which are unused, which is also used to clean up -/// after the aliasing rewrite mentioned above. -/// -/// The trickiest transformation comes with inlining the body of evaluators at their -/// call sites, as evaluator parameter lists can arbitrarily destructure/regroup columns -/// provided as arguments for each trace segment. 
This means that columns can be passed -/// in a variety of configurations as arguments, and the patterns expressed in the evaluator -/// parameter list can arbitrarily reconfigure them for use in the evaluator body. -/// -/// For example, let's say you call an evaluator `foo` with three columns, passed as individual -/// bindings, like so: `foo([a, b, c])`. Let's further assume that the evaluator signature -/// is defined as `ev foo([x[2], y])`. While you might expect that this would be an error, -/// and that the caller would need to provide the columns in the same configuration, that -/// is not the case. Instead, `a` and `b` are implicitly re-bound as a vector of trace column -/// bindings for use in the function body. There is further no requirement that `a` and `b` -/// are consecutive bindings either, as long as they are from the same trace segment. During -/// compilation however, accesses to individual elements of the vector will be rewritten to use -/// the correct binding in the caller after inlining, e.g. an access like `x[1]` becomes `b`. -/// -/// This pass accomplishes three goals: -/// -/// * Remove all function abstractions from the program -/// * Remove all comprehensions from the program -/// * Inline all constraints into the integrity and boundary constraints sections -/// * Make all references to top-level declarations concrete -/// -/// When done, it should be impossible for there to be any invalid trace column references. -/// -/// It is expected that the provided [Program] has already been run through semantic analysis -/// and constant propagation, so a number of assumptions are made with regard to what syntax can -/// be observed at this stage of compilation (e.g. no references to constant declarations, no -/// undefined variables, expressions are well-typed, etc.). 
-pub struct Inlining<'a> { - // This may be unused for now, but it's helpful to assume its needed in case we want it in the future - #[allow(unused)] - diagnostics: &'a DiagnosticsHandler, - /// The name of the root module - root: Identifier, - /// The global trace segment configuration - trace: Vec, - /// The public_inputs declaration - public_inputs: BTreeMap, - /// All local/global bindings in scope - bindings: LexicalScope, - /// The values of all let-bound variables in scope - let_bound: LexicalScope, - /// All items which must be referenced fully-qualified, namely periodic columns at this point - imported: HashMap, - /// All evaluator functions in the program - evaluators: HashMap, - /// All pure functions in the program - functions: HashMap, - /// A set of identifiers for which accesses should be rewritten. - /// - /// When an identifier is in this set, it means it is a local alias for a trace column, - /// and should be rewritten based on the current `BindingType` associated with the alias - /// identifier in `bindings`. - rewrites: HashSet, - /// The call stack during expansion of a function call. - /// - /// Each time we begin to expand a call, we check if it is already present on the call - /// stack, and if so, raise a diagnostic due to infinite recursion. If not, the callee - /// is pushed on the stack while we expand its body. When we finish expanding the body - /// of the callee, we pop it off this stack, and proceed as usual. 
- call_stack: Vec, - in_comprehension_constraint: bool, - next_ident_lc: usize, - next_ident: usize, -} -impl Pass for Inlining<'_> { - type Input<'a> = Program; - type Output<'a> = Program; - type Error = SemanticAnalysisError; - - fn run<'a>(&mut self, mut program: Self::Input<'a>) -> Result, Self::Error> { - self.root = program.name; - self.evaluators = program - .evaluators - .iter() - .map(|(k, v)| (*k, v.clone())) - .collect(); - - self.functions = program - .functions - .iter() - .map(|(k, v)| (*k, v.clone())) - .collect(); - - // We'll be referencing the trace configuration during inlining, so keep a copy of it - self.trace.clone_from(&program.trace_columns); - // And the public inputs - self.public_inputs.clone_from(&program.public_inputs); - - // Add all of the local bindings visible in the root module, except for - // constants and periodic columns, which by this point have been rewritten - // to use fully-qualified names (or in the case of constants, have been - // eliminated entirely) - // - // Trace first.. - for segment in program.trace_columns.iter() { - self.bindings.insert( - segment.name, - BindingType::TraceColumn(TraceBinding { - span: segment.name.span(), - segment: segment.id, - name: Some(segment.name), - offset: 0, - size: segment.size, - ty: Type::Vector(segment.size), - }), - ); - for binding in segment.bindings.iter().copied() { - self.bindings.insert( - binding.name.unwrap(), - BindingType::TraceColumn(TraceBinding { - span: segment.name.span(), - segment: segment.id, - name: binding.name, - offset: binding.offset, - size: binding.size, - ty: binding.ty, - }), - ); - } - } - // Public inputs.. - for input in program.public_inputs.values() { - self.bindings.insert( - input.name(), - BindingType::PublicInput(Type::Vector(input.size())), - ); - } - // For periodic columns, we register the imported item, but do not add any to the local bindings. 
- for (name, periodic) in program.periodic_columns.iter() { - let binding_ty = BindingType::PeriodicColumn(periodic.values.len()); - self.imported.insert(*name, binding_ty); - } - - // The root of the inlining process is the integrity_constraints and - // boundary_constraints blocks. Function calls in inlined functions are - // inlined at the same time as the parent - self.expand_boundary_constraints(&mut program.boundary_constraints)?; - self.expand_integrity_constraints(&mut program.integrity_constraints)?; - - Ok(program) - } -} -impl<'a> Inlining<'a> { - pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { - Self { - diagnostics, - root: Identifier::new(SourceSpan::UNKNOWN, crate::symbols::Main), - trace: vec![], - public_inputs: Default::default(), - bindings: Default::default(), - let_bound: Default::default(), - imported: Default::default(), - evaluators: Default::default(), - functions: Default::default(), - rewrites: Default::default(), - in_comprehension_constraint: false, - call_stack: vec![], - next_ident_lc: 0, - next_ident: 0, - } - } - - /// Generate a new variable - /// - /// This is only used when expanding list comprehensions, so we use a special prefix for - /// these generated identifiers to make it clear what they were expanded from. 
- fn get_next_ident_lc(&mut self, span: SourceSpan) -> Identifier { - let id = self.next_ident_lc; - self.next_ident_lc += 1; - Identifier::new(span, crate::Symbol::intern(format!("%lc{id}"))) - } - - fn get_next_ident(&mut self, span: SourceSpan) -> Identifier { - let id = self.next_ident; - self.next_ident += 1; - Identifier::new(span, crate::Symbol::intern(format!("%{id}"))) - } - - /// Inline/expand all of the statements in the `boundary_constraints` section - fn expand_boundary_constraints( - &mut self, - body: &mut Vec, - ) -> Result<(), SemanticAnalysisError> { - // Save the current bindings set, as we're entering a new lexical scope - self.bindings.enter(); - // Visit all of the statements, check variable usage, and track referenced imports - self.expand_statement_block(body)?; - // Restore the original lexical scope - self.bindings.exit(); - - Ok(()) - } - - /// Inline/expand all of the statements in the `integrity_constraints` section - fn expand_integrity_constraints( - &mut self, - body: &mut Vec, - ) -> Result<(), SemanticAnalysisError> { - // Save the current bindings set, as we're entering a new lexical scope - self.bindings.enter(); - // Visit all of the statements, check variable usage, and track referenced imports - self.expand_statement_block(body)?; - // Restore the original lexical scope - self.bindings.exit(); - - Ok(()) - } - - /// Expand a block of statements by visiting each statement front-to-back - fn expand_statement_block( - &mut self, - statements: &mut Vec, - ) -> Result<(), SemanticAnalysisError> { - // This conversion is free, and gives us a natural way to treat the block as a queue - let mut buffer: VecDeque = core::mem::take(statements).into(); - // Visit each statement, appending the resulting expansion to the original vector - while let Some(statement) = buffer.pop_front() { - let mut expanded = self.expand_statement(statement)?; - if expanded.is_empty() { - continue; - } - statements.append(&mut expanded); - } - - Ok(()) - } - 
- /// Expand a single statement into one or more statements which are fully-expanded - fn expand_statement( - &mut self, - statement: Statement, - ) -> Result, SemanticAnalysisError> { - match statement { - // Expanding a let requires special treatment, as let-bound values may be inlined as a block - // of statements, which requires us to rewrite the `let` into a `let` tree - Statement::Let(expr) => self.expand_let(expr), - // A call to an evaluator function is expanded by inlining the function itself at the call site - Statement::Enforce(ScalarExpr::Call(call)) => self.expand_evaluator_callsite(call), - // Constraints are inlined by expanding the constraint expression - Statement::Enforce(expr) => self.expand_constraint(expr), - // Constraint comprehensions are inlined by unrolling the comprehension into a sequence of constraints - Statement::EnforceAll(expr) => { - let in_cc = core::mem::replace(&mut self.in_comprehension_constraint, true); - let result = self.expand_comprehension(expr); - self.in_comprehension_constraint = in_cc; - result - } - // Conditional constraints are expanded like regular constraints, except the selector is applied - // to all constraints in the expansion. 
- Statement::EnforceIf(expr, mut selector) => { - let mut statements = match expr { - ScalarExpr::Call(call) => self.expand_evaluator_callsite(call)?, - expr => self.expand_constraint(expr)?, - }; - self.rewrite_scalar_expr(&mut selector)?; - // We need to make sure the selector is applied to all constraints in the expansion - for statement in statements.iter_mut() { - let mut visitor = ApplyConstraintSelector { - selector: &selector, - }; - if let ControlFlow::Break(err) = visitor.visit_mut_statement(statement) { - return Err(err); - } - } - Ok(statements) - } - // Expresssions containing function calls require expansion via inlining, otherwise - // all other expression types are introduced during inlining and are thus already expanded, - // but we must still visit them to apply rewrites. - Statement::Expr(expr) => match self.expand_expr(expr)? { - Expr::Let(let_expr) => Ok(vec![Statement::Let(*let_expr)]), - expr => Ok(vec![Statement::Expr(expr)]), - }, - Statement::BusEnforce(_) => { - self.diagnostics - .diagnostic(Severity::Error) - .with_message("buses are not implemented for this Pipeline") - .emit(); - Err(SemanticAnalysisError::Invalid) - } - } - } - - fn expand_expr(&mut self, expr: Expr) -> Result { - match expr { - Expr::Vector(mut elements) => { - let elems = Vec::with_capacity(elements.len()); - for elem in core::mem::replace(&mut elements.item, elems) { - elements.push(self.expand_expr(elem)?); - } - Ok(Expr::Vector(elements)) - } - Expr::Matrix(mut rows) => { - for row in rows.iter_mut() { - let cols = Vec::with_capacity(row.len()); - for col in core::mem::replace(row, cols) { - row.push(self.expand_scalar_expr(col)?); - } - } - Ok(Expr::Matrix(rows)) - } - Expr::Binary(expr) => self.expand_binary_expr(expr), - Expr::Call(expr) => self.expand_call(expr), - Expr::ListComprehension(expr) => { - let mut block = self.expand_comprehension(expr)?; - assert_eq!(block.len(), 1); - 
Expr::try_from(block.pop().unwrap()).map_err(SemanticAnalysisError::InvalidExpr) - } - Expr::Let(expr) => { - let mut block = self.expand_let(*expr)?; - assert_eq!(block.len(), 1); - Expr::try_from(block.pop().unwrap()).map_err(SemanticAnalysisError::InvalidExpr) - } - expr @ (Expr::Const(_) | Expr::Range(_) | Expr::SymbolAccess(_)) => Ok(expr), - Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => { - self.diagnostics - .diagnostic(Severity::Error) - .with_message("buses are not implemented for this Pipeline") - .emit(); - Err(SemanticAnalysisError::Invalid) - } - } - } - - /// Let expressions are expanded using the following rules: - /// - /// * The let-bound expression is expanded first. If it expands to a statement block and - /// not an expression, the block is inlined in place of the let being expanded, and the - /// rest of the expansion takes place at the end of the block; replacing the last statement - /// in the block. If the last statement in the block was an expression, it is treated as - /// the let-bound value. If the last statement in the block was another `let` however, then - /// we recursively walk down the let tree until we reach the bottom, which must always be - /// an expression statement. - /// - /// * The body is expanded in-place after the previous step has been completed. - /// - /// * If a let-bound variable is an alias for a declaration, we replace all uses - /// of the variable with direct references to the declaration, making the let-bound - /// variable dead - /// - /// * If a let-bound variable is dead (i.e. 
has no references), then the let is elided, - /// by replacing it with the result of expanding its body - fn expand_let(&mut self, expr: Let) -> Result, SemanticAnalysisError> { - let span = expr.span(); - let name = expr.name; - let body = expr.body; - - // Visit the let-bound expression first, since it determines how the rest of the process goes - let value = match expr.value { - // When expanding a call in this context, we're expecting a single - // statement of either `Expr` or `Let` type, as calls to pure functions - // can never contain constraints. - Expr::Call(call) => self.expand_call(call)?, - // Same as above, but for list comprehensions. - // - // The rules for expansion are the same. - Expr::ListComprehension(lc) => { - let mut expanded = self.expand_comprehension(lc)?; - match expanded.pop().unwrap() { - Statement::Let(let_expr) => Expr::Let(Box::new(let_expr)), - Statement::Expr(expr) => expr, - Statement::Enforce(_) - | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) - | Statement::BusEnforce(_) => unreachable!(), - } - } - // The operands of a binary expression can contain function calls, so we must ensure - // that we expand the operands as needed, and then proceed with expanding the let. - Expr::Binary(expr) => self.expand_binary_expr(expr)?, - // Other expressions we visit just to expand rewrites - mut expr => { - self.rewrite_expr(&mut expr)?; - expr - } - }; - - let expr = Let { - span, - name, - value, - body, - }; - - self.expand_let_tree(expr) - } - - /// This is only expected to be called on a let tree which is guaranteed to only have - /// simple values as let-bound expressions, i.e. the `value` of the `Let` requires no - /// expansion or rewrites. You should use `expand_let` in general. 
- fn expand_let_tree(&mut self, mut expr: Let) -> Result, SemanticAnalysisError> { - // Start new lexical scope for the body - self.bindings.enter(); - self.let_bound.enter(); - let prev_rewrites = self.rewrites.clone(); - - // Register the binding - let binding_ty = self.expr_binding_type(&expr.value).unwrap(); - - // If this let is a vector of trace column bindings, then we can - // elide the let, and rewrite all uses of the let-bound variable - // to the respective elements of the vector - let inline_body = binding_ty.is_trace_binding(); - if inline_body { - self.rewrites.insert(expr.name); - } - self.bindings.insert(expr.name, binding_ty); - self.let_bound.insert(expr.name, expr.value.clone()); - - // Visit the let body - self.expand_statement_block(&mut expr.body)?; - - // Restore the original lexical scope - self.bindings.exit(); - self.let_bound.exit(); - self.rewrites = prev_rewrites; - - // If we're inlining the body, return the body block as the result; - // otherwise re-wrap the `let` as the sole statement in the resulting block - if inline_body { - Ok(expr.body) - } else { - Ok(vec![Statement::Let(expr)]) - } - } - - /// Expand a call to a pure function (including builtin list folding functions) - fn expand_call(&mut self, mut call: Call) -> Result { - if call.is_builtin() { - match call.callee.as_ref().name() { - symbols::Sum => { - assert_eq!(call.args.len(), 1); - self.expand_fold(BinaryOp::Add, call.args.pop().unwrap()) - } - symbols::Prod => { - assert_eq!(call.args.len(), 1); - self.expand_fold(BinaryOp::Mul, call.args.pop().unwrap()) - } - other => unimplemented!("unhandled builtin: {}", other), - } - } else { - self.expand_function_callsite(call) - } - } - - fn expand_scalar_expr( - &mut self, - expr: ScalarExpr, - ) -> Result { - match expr { - ScalarExpr::Binary(expr) if expr.has_block_like_expansion() => { - self.expand_binary_expr(expr).and_then(|expr| { - ScalarExpr::try_from(expr).map_err(SemanticAnalysisError::InvalidExpr) - }) - } - 
ScalarExpr::Call(lhs) => self.expand_call(lhs).and_then(|expr| { - ScalarExpr::try_from(expr).map_err(SemanticAnalysisError::InvalidExpr) - }), - mut expr => { - self.rewrite_scalar_expr(&mut expr)?; - Ok(expr) - } - } - } - - fn expand_binary_expr(&mut self, expr: BinaryExpr) -> Result { - let span = expr.span(); - let op = expr.op; - let lhs = self.expand_scalar_expr(*expr.lhs)?; - let rhs = self.expand_scalar_expr(*expr.rhs)?; - - Ok(Expr::Binary(BinaryExpr { - span, - op, - lhs: Box::new(lhs), - rhs: Box::new(rhs), - })) - } - - /// Expand a list folding operation (e.g. sum/prod) over an expression of aggregate type into an equivalent expression tree - fn expand_fold(&mut self, op: BinaryOp, list: Expr) -> Result { - let span = list.span(); - match list { - Expr::Vector(mut elems) => self.expand_vector_fold(span, op, &mut elems), - Expr::ListComprehension(lc) => { - // Expand the comprehension, but ensure we don't treat it like a comprehension constraint - let in_cc = core::mem::replace(&mut self.in_comprehension_constraint, false); - let mut expanded = self.expand_comprehension(lc)?; - self.in_comprehension_constraint = in_cc; - // Apply the fold to the expanded comprehension in the bottom of the let tree - with_let_result(self, &mut expanded, |inliner, value| { - match value { - // The result value of expanding a comprehension _must_ be a vector - Expr::Vector(elems) => { - // We're going to replace the vector binding with the fold - let folded = inliner.expand_vector_fold(span, op, elems)?; - *value = folded; - Ok(None) - } - _ => unreachable!(), - } - })?; - match expanded.pop().unwrap() { - Statement::Expr(expr) => Ok(expr), - Statement::Let(expr) => Ok(Expr::Let(Box::new(expr))), - Statement::Enforce(_) - | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) - | Statement::BusEnforce(_) => unreachable!(), - } - } - Expr::SymbolAccess(ref access) => { - match self.let_bound.get(access.name.as_ref()).cloned() { - Some(expr) => self.expand_fold(op, 
expr), - None => match self.access_binding_type(access) { - Ok(BindingType::TraceColumn(tb)) => { - let mut vector = vec![]; - for i in 0..tb.size { - vector.push(Expr::SymbolAccess( - access.access(AccessType::Index(i)).unwrap(), - )); - } - self.expand_vector_fold(span, op, &mut vector) - } - Ok(_) | Err(_) => unimplemented!(), - }, - } - } - // Constant propagation will have already folded calls to list-folding builtins - // with constant arguments, so we should panic if we ever see one here - Expr::Const(_) => panic!("expected constant to have been folded"), - // All other invalid expressions should have been caught by now - invalid => panic!("invalid argument to list folding builtin: {invalid:#?}"), - } - } - - /// Expand a list folding operation (e.g. sum/prod) over a vector into an equivalent expression tree - fn expand_vector_fold( - &mut self, - span: SourceSpan, - op: BinaryOp, - vector: &mut Vec, - ) -> Result { - // To expand this fold, we simply produce a nested sequence of BinaryExpr - let mut elems = vector.drain(..); - let mut acc = elems.next().unwrap(); - self.rewrite_expr(&mut acc)?; - let mut acc: ScalarExpr = acc.try_into().map_err(SemanticAnalysisError::InvalidExpr)?; - for mut elem in elems { - self.rewrite_expr(&mut elem)?; - let elem: ScalarExpr = elem.try_into().expect("invalid scalar expr"); - let new_acc = ScalarExpr::Binary(BinaryExpr::new(span, op, acc, elem)); - acc = new_acc; - } - acc.try_into().map_err(SemanticAnalysisError::InvalidExpr) - } - - fn expand_constraint( - &mut self, - constraint: ScalarExpr, - ) -> Result, SemanticAnalysisError> { - // The constraint itself must be an equality at this point, as evaluator - // calls are handled separately in `expand_statement` - match constraint { - ScalarExpr::Binary(BinaryExpr { - op: BinaryOp::Eq, - lhs, - rhs, - span, - }) => { - let lhs = self.expand_scalar_expr(*lhs)?; - let rhs = self.expand_scalar_expr(*rhs)?; - - Ok(vec![Statement::Enforce(ScalarExpr::Binary(BinaryExpr { - 
span, - op: BinaryOp::Eq, - lhs: Box::new(lhs), - rhs: Box::new(rhs), - }))]) - } - invalid => unreachable!("unexpected constraint node: {:#?}", invalid), - } - } - - /// This function rewrites expressions which contain accesses for which rewrites have been registered. - fn rewrite_expr(&mut self, expr: &mut Expr) -> Result<(), SemanticAnalysisError> { - match expr { - Expr::Const(_) | Expr::Range(_) => return Ok(()), - Expr::Vector(elems) => { - for elem in elems.iter_mut() { - self.rewrite_expr(elem)?; - } - } - Expr::Matrix(rows) => { - for row in rows.iter_mut() { - for col in row.iter_mut() { - self.rewrite_scalar_expr(col)?; - } - } - } - Expr::Binary(binary_expr) => { - self.rewrite_scalar_expr(binary_expr.lhs.as_mut())?; - self.rewrite_scalar_expr(binary_expr.rhs.as_mut())?; - } - Expr::SymbolAccess(access) => { - if let Some(rewrite) = self.get_trace_access_rewrite(access) { - *access = rewrite; - } - } - Expr::Call(call) => { - for arg in call.args.iter_mut() { - self.rewrite_expr(arg)?; - } - } - // Comprehension rewrites happen when they are expanded, but we do visit the iterables now - Expr::ListComprehension(lc) => { - for expr in lc.iterables.iter_mut() { - self.rewrite_expr(expr)?; - } - } - Expr::Let(let_expr) => { - let mut next = Some(let_expr.as_mut()); - while let Some(next_let) = next.take() { - self.rewrite_expr(&mut next_let.value)?; - match next_let.body.last_mut().unwrap() { - Statement::Let(inner) => { - next = Some(inner); - } - Statement::Expr(expr) => { - self.rewrite_expr(expr)?; - } - Statement::Enforce(_) - | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) - | Statement::BusEnforce(_) => unreachable!(), - } - } - } - Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => { - self.diagnostics - .diagnostic(Severity::Error) - .with_message("buses are not implemented for this Pipeline") - .emit(); - return Err(SemanticAnalysisError::Invalid); - } - } - Ok(()) - } - - /// This function rewrites scalar expressions 
which contain accesses for which rewrites have been registered. - fn rewrite_scalar_expr(&mut self, expr: &mut ScalarExpr) -> Result<(), SemanticAnalysisError> { - match expr { - ScalarExpr::Const(_) => Ok(()), - ScalarExpr::SymbolAccess(access) - | ScalarExpr::BoundedSymbolAccess(BoundedSymbolAccess { column: access, .. }) => { - if let Some(rewrite) = self.get_trace_access_rewrite(access) { - *access = rewrite; - } - Ok(()) - } - ScalarExpr::Binary(BinaryExpr { op, lhs, rhs, .. }) => { - self.rewrite_scalar_expr(lhs.as_mut())?; - self.rewrite_scalar_expr(rhs.as_mut())?; - match op { - BinaryOp::Exp if !rhs.is_constant() => Err(SemanticAnalysisError::InvalidExpr( - InvalidExprError::NonConstantExponent(rhs.span()), - )), - _ => Ok(()), - } - } - ScalarExpr::Call(expr) => { - for arg in expr.args.iter_mut() { - self.rewrite_expr(arg)?; - } - Ok(()) - } - ScalarExpr::Let(let_expr) => { - let mut next = Some(let_expr.as_mut()); - while let Some(next_let) = next.take() { - self.rewrite_expr(&mut next_let.value)?; - match next_let.body.last_mut().unwrap() { - Statement::Let(inner) => { - next = Some(inner); - } - Statement::Expr(expr) => { - self.rewrite_expr(expr)?; - } - Statement::Enforce(_) - | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) - | Statement::BusEnforce(_) => unreachable!(), - } - } - Ok(()) - } - ScalarExpr::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => { - self.diagnostics - .diagnostic(Severity::Error) - .with_message("buses are not implemented for this Pipeline") - .emit(); - Err(SemanticAnalysisError::Invalid) - } - } - } - - /// This function expands a comprehension into a sequence of statements. - /// - /// This is done using abstract interpretation. By this point in the compilation process, - /// all iterables should have been typed and have known static sizes. Some iterables may even - /// be constant, such as in the case of ranges. 
Because of this, we are able to "unroll" the - /// comprehension, evaluating the effective value of all iterable bindings at each iteration, - /// and rewriting the comprehension body accordingly. - /// - /// Depending on whether this is a standard list comprehension, or a constraint comprehension, - /// the expansion is, respectively: - /// - /// * A tree of let statements (using generated variables), where each let binds the value of a - /// single iteration of the comprehension. The body of the final let, and thus the effective - /// value of the entire tree, is a vector containing all of the bindings in the evaluation - /// order of the comprehension. - /// * A flat list of constraint statements - fn expand_comprehension( - &mut self, - mut expr: ListComprehension, - ) -> Result, SemanticAnalysisError> { - // Lift any function calls in iterable position out of the comprehension, - // binding the result of those calls via `let`. Rewrite the iterable as - // a symbol access to the newly-bound variable. - // - // NOTE: The actual expansion of the lifted iterables occurs after we expand - // the comprehension, so that we can place the expanded comprehension in the - // body of the final let - let mut lifted_bindings = vec![]; - let mut lifted = vec![]; - for param in expr.iterables.iter_mut() { - if !matches!(param, Expr::Call(_)) { - continue; - } - - let span = param.span(); - let name = self.get_next_ident(span); - let ty = match param { - Expr::Call(Call { callee, .. 
}) => { - let callee = callee - .resolved() - .expect("callee should have been resolved by now"); - self.functions[&callee].return_type - } - _ => unsafe { core::hint::unreachable_unchecked() }, - }; - let param = core::mem::replace( - param, - Expr::SymbolAccess(SymbolAccess { - span, - name: ResolvableIdentifier::Local(name), - access_type: AccessType::Default, - offset: 0, - ty: Some(ty), - }), - ); - match param { - Expr::Call(call) => { - lifted_bindings.push((name, BindingType::Local(ty))); - lifted.push((name, call)); - } - _ => unsafe { core::hint::unreachable_unchecked() }, - } - } - - // Get the number of iterations in this comprehension - let Type::Vector(num_iterations) = expr.ty.unwrap() else { - panic!("invalid comprehension type"); - }; - - // Step the iterables for each iteration, giving each it's own lexical scope - let mut statement_groups = vec![]; - for i in 0..num_iterations { - self.bindings.enter(); - // Ensure any lifted iterables are in scope for the expansion of this iteration - for (name, binding_ty) in lifted_bindings.iter() { - self.bindings.insert(*name, binding_ty.clone()); - } - let expansion = self.expand_comprehension_iteration(&expr, i)?; - // An expansion can be empty if a constraint selector with a constant selector expression - // evaluates to false (allowing us to elide the constraint for that iteration entirely). - if !expansion.is_empty() { - statement_groups.push(expansion); - } - self.bindings.exit(); - } - - // At this point, we have one or more statement groups, representing the expansions - // of each iteration of the comprehension. Additionally, we may have a set of lifted - // iterables which we need to bind (and expand) "around" the expansion of the comprehension - // itself. - // - // In short, we must take this list of statement groups, and flatten/treeify it. Once - // a let binding is introduced into scope, all subsequent statements must occur in the body - // of that let, forming a tree. 
Consecutive statements which introduce no new bindings do - // not require any nesting, resulting in the groups containing those statements being flattened. - // - // Lastly, whether this is a list or constraint comprehension determines if we will also be - // constructing a vector from the values produced by each iteration, and returning it as the - // result of the comprehension itself. - let span = expr.span(); - if self.in_comprehension_constraint { - Ok(statement_groups.into_iter().flatten().collect()) - } else { - // For list comprehensions, we must emit a let tree that binds each iteration, - // and ensure that the expansion of the iteration itself is properly nested so - // that the lexical scope of all bound variables is correct. This is more complex - // than the constraint comprehension case, as we must emit a single expression - // representing the entire expansion of the comprehension as an aggregate, whereas - // constraints produce no results. - - // Generate a new variable name for each element in the comprehension - let symbols = statement_groups - .iter() - .map(|_| self.get_next_ident_lc(span)) - .collect::>(); - // Generate the list of elements for the vector which is to be the result of the let-tree - let vars = statement_groups - .iter() - .zip(symbols.iter().copied()) - .map(|(group, name)| { - // The type of these statements must be known by now - let ty = match group.last().unwrap() { - Statement::Expr(value) => value.ty(), - Statement::Let(nested) => nested.ty(), - stmt => unreachable!( - "unexpected statement type in comprehension body: {}", - stmt.display(0) - ), - }; - Expr::SymbolAccess(SymbolAccess { - span, - name: ResolvableIdentifier::Local(name), - access_type: AccessType::Default, - offset: 0, - ty, - }) - }) - .collect(); - // Construct the let tree by visiting the statements bottom-up - let acc = vec![Statement::Expr(Expr::Vector(Span::new(span, vars)))]; - let expanded = statement_groups.into_iter().zip(symbols).try_rfold( - 
acc, - |acc, (mut group, name)| { - match group.pop().unwrap() { - // If the current statement is an expression, it represents the value of this - // iteration of the comprehension, and we must generate a let to bind it, using - // the accumulator expression as the body - Statement::Expr(expr) => { - group.push(Statement::Let(Let::new(span, name, expr, acc))); - } - // If the current statement is a `let`-tree, we need to generate a new `let` at - // the bottom of the tree, which binds the result expression as the value of the - // generated `let`, and uses the accumulator as the body - Statement::Let(mut wrapper) => { - with_let_result(self, &mut wrapper.body, move |_, value| { - let value = core::mem::replace( - value, - Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), - ); - Ok(Some(Statement::Let(Let::new(span, name, value, acc)))) - })?; - group.push(Statement::Let(wrapper)); - } - _ => unreachable!(), - } - Ok::<_, SemanticAnalysisError>(group) - }, - )?; - // Lastly, construct the let tree for the lifted iterables, placing the expanded - // comprehension at the bottom of that tree. - lifted.into_iter().try_rfold(expanded, |acc, (name, call)| { - let span = call.span(); - match self.expand_call(call)? { - Expr::Let(mut wrapper) => { - with_let_result(self, &mut wrapper.body, move |_, value| { - let value = core::mem::replace( - value, - Expr::Const(Span::new(span, ConstantExpr::Scalar(0))), - ); - Ok(Some(Statement::Let(Let::new(span, name, value, acc)))) - })?; - Ok(vec![Statement::Let(*wrapper)]) - } - expr => Ok(vec![Statement::Let(Let::new(span, name, expr, acc))]), - } - }) - } - } - - fn expand_comprehension_iteration( - &mut self, - lc: &ListComprehension, - index: usize, - ) -> Result, SemanticAnalysisError> { - // Register each iterable binding and its abstract value. 
- // - // The abstract value is either a constant (in which case it is concrete, not abstract), or - // an expression which represents accessing the iterable at the index corresponding to the - // current iteration. - let mut bound_values = HashMap::::default(); - for (iterable, binding) in lc.iterables.iter().zip(lc.bindings.iter().copied()) { - let abstract_value = match iterable { - // If the iterable is constant, the value of it's corresponding binding is also constant - Expr::Const(constant) => { - let span = constant.span(); - let value = match constant.item { - ConstantExpr::Vector(ref elems) => ConstantExpr::Scalar(elems[index]), - ConstantExpr::Matrix(ref rows) => ConstantExpr::Vector(rows[index].clone()), - // An iterable may never be a scalar value, this will be caught by semantic analysis - ConstantExpr::Scalar(_) => unreachable!(), - }; - let binding_ty = BindingType::Constant(value.ty()); - self.bindings.insert(binding, binding_ty); - Expr::Const(Span::new(span, value)) - } - // Ranges are constant, so same rules as above apply here - Expr::Range(range) => { - let span = range.span(); - let range = range.to_slice_range(); - let binding_ty = BindingType::Constant(Type::Felt); - self.bindings.insert(binding, binding_ty); - Expr::Const(Span::new( - span, - ConstantExpr::Scalar((range.start + index) as u64), - )) - } - // If the iterable was a vector, the abstract value is whatever expression is at - // the corresponding index of the vector. - Expr::Vector(elems) => { - let abstract_value = elems[index].clone(); - let binding_ty = self.expr_binding_type(&abstract_value).unwrap(); - self.bindings.insert(binding, binding_ty); - abstract_value - } - // If the iterable was a matrix, the abstract value is a vector of expressions - // representing the current row of the matrix. We calculate the binding type of - // each element in that vector so that accesses into the vector are well typed. 
- Expr::Matrix(rows) => { - let row: Vec = rows[index] - .iter() - .cloned() - .map(|se| se.try_into().unwrap()) - .collect(); - let mut tys = vec![]; - for elem in row.iter() { - tys.push(self.expr_binding_type(elem).unwrap()); - } - let binding_ty = BindingType::Vector(tys); - self.bindings.insert(binding, binding_ty); - Expr::Vector(Span::new(rows.span(), row)) - } - // If the iterable was a variable/access, then we must first index into that - // access, and then rewrite it, if applicable. - Expr::SymbolAccess(access) => { - // The access here must be of aggregate type, so index into it for the current iteration - let mut current_access = access.access(AccessType::Index(index)).unwrap(); - // Rewrite the resulting access if we have a rewrite for the underlying symbol - if let Some(rewrite) = self.get_trace_access_rewrite(¤t_access) { - current_access = rewrite; - } - let binding_ty = self.access_binding_type(¤t_access).unwrap(); - self.bindings.insert(binding, binding_ty); - Expr::SymbolAccess(current_access) - } - // Binary expressions are scalar, so cannot be used as iterables, and we don't - // (currently) support nested comprehensions, so it is never possible to observe - // these expression types here. Calls should have been lifted prior to expansion. - Expr::Call(_) - | Expr::Binary(_) - | Expr::ListComprehension(_) - | Expr::Let(_) - | Expr::BusOperation(_) - | Expr::Null(_) - | Expr::Unconstrained(_) => { - unreachable!() - } - }; - bound_values.insert(binding, abstract_value); - } - - // Clone the comprehension body for this iteration, so we don't modify the original - let mut body = lc.body.as_ref().clone(); - - // Rewrite all references to the iterable bindings in the comprehension body - let mut visitor = RewriteIterableBindingsVisitor { - values: &bound_values, - }; - if let ControlFlow::Break(err) = visitor.visit_mut_scalar_expr(&mut body) { - return Err(err); - } - - // Next, handle comprehension filters/selectors as follows: - // - // 1. 
Selectors are evaluated in the same context as the body, so we must visit iterable references in the same way. - // 2. If a selector has a constant value, we can elide the selector for this iteration. Furthermore, in situations where - // the selector is known false, we can elide the expansion of this iteration entirely. - // - // Since the selector is the last piece we need to construct the Statement corresponding to the expansion of - // this iteration, we do that now before proceeding to the next step. - let statement = if let Some(mut selector) = lc.selector.clone() { - assert!( - self.in_comprehension_constraint, - "selectors are not permitted in list comprehensions" - ); - // #1 - if let ControlFlow::Break(err) = visitor.visit_mut_scalar_expr(&mut selector) { - return Err(err); - } - // #2 - match selector { - // If the selector value is zero, or false, we can elide the expansion entirely - ScalarExpr::Const(value) if value.item == 0 => return Ok(vec![]), - // If the selector value is non-zero, or true, we can elide just the selector - ScalarExpr::Const(_) => Statement::Enforce(body), - // We have a selector that requires evaluation at runtime, we need to emit a conditional scalar constraint - other => Statement::EnforceIf(body, other), - } - } else if self.in_comprehension_constraint { - Statement::Enforce(body) - } else { - Statement::Expr(body.try_into().unwrap()) - }; - - // Next, although we've rewritten the comprehension body corresponding to this iteration, we - // haven't yet performed inlining on it. We do that now, while all of the bindings are - // in scope with the proper values. The result of that expansion is what we emit as the result - // for this iteration. - self.expand_statement(statement) - } - - /// This function handles inlining evaluator function calls. 
- /// - /// At this point, semantic analysis has verified that the call arguments are valid, in - /// that the number of trace columns passed matches the number of columns expected by the - /// function parameters. However, the number and type of bindings are permitted to be - /// different, as long as the vectors are the same size when expanded - in effect, re-grouping - /// the trace columns at the call boundary. - fn expand_evaluator_callsite( - &mut self, - call: Call, - ) -> Result, SemanticAnalysisError> { - // The callee is guaranteed to be resolved and exist at this point - let callee = call - .callee - .resolved() - .expect("callee should have been resolved by now"); - // We clone the evaluator here as we will be modifying the body during the - // inlining process, and we must not modify the original - let mut evaluator = self.evaluators.get(&callee).unwrap().clone(); - - // This will be the initial set of bindings visible within the evaluator body - // - // This is distinct from `self.bindings` at this point, because the evaluator doesn't - // inherit the caller's scope, it has an entirely new one. - let mut eval_bindings = LexicalScope::default(); - - // Add all referenced (and thus imported) items from the evaluator module - // - // NOTE: This will include constants, periodic columns, and other functions - for (qid, binding_ty) in self.imported.iter() { - if qid.module == callee.module { - eval_bindings.insert(*qid.as_ref(), binding_ty.clone()); - } - } - - // Add trace columns, and other root declarations to the set of - // bindings visible in the evaluator body, _if_ the evaluator is defined in the - // root module. 
- let is_evaluator_in_root = callee.module == self.root; - if is_evaluator_in_root { - for segment in self.trace.iter() { - eval_bindings.insert( - segment.name, - BindingType::TraceColumn(TraceBinding { - span: segment.name.span(), - segment: segment.id, - name: Some(segment.name), - offset: 0, - size: segment.size, - ty: Type::Vector(segment.size), - }), - ); - for binding in segment.bindings.iter().copied() { - eval_bindings.insert( - binding.name.unwrap(), - BindingType::TraceColumn(TraceBinding { - span: segment.name.span(), - segment: segment.id, - name: binding.name, - offset: binding.offset, - size: binding.size, - ty: binding.ty, - }), - ); - } - } - - for input in self.public_inputs.values() { - eval_bindings.insert( - input.name(), - BindingType::PublicInput(Type::Vector(input.size())), - ); - } - } - - // Match call arguments to function parameters, populating the set of rewrites - // which should be performed on the inlined function body. - // - // NOTE: We create a new nested scope for the parameters in order to avoid conflicting - // with the root declarations - eval_bindings.enter(); - self.populate_evaluator_rewrites( - &mut eval_bindings, - call.args.as_slice(), - evaluator.params.as_slice(), - ); - - // While we're inlining the body, use the set of evaluator bindings we built above - let prev_bindings = core::mem::replace(&mut self.bindings, eval_bindings); - - // Expand the evaluator body into a block of statements - self.expand_statement_block(&mut evaluator.body)?; - - // Restore the caller's bindings before we leave - self.bindings = prev_bindings; - - Ok(evaluator.body) - } - - /// This function handles inlining pure function calls, which must produce an expression - fn expand_function_callsite(&mut self, call: Call) -> Result { - self.bindings.enter(); - // The callee is guaranteed to be resolved and exist at this point - let callee = call - .callee - .resolved() - .expect("callee should have been resolved by now"); - - if 
self.call_stack.contains(&callee) { - let ifd = self - .diagnostics - .diagnostic(Severity::Error) - .with_message("invalid recursive function call") - .with_primary_label(call.span, "recursion occurs due to this function call"); - self.call_stack - .iter() - .rev() - .fold(ifd, |ifd, caller| { - ifd.with_secondary_label(caller.span(), "which was called from") - }) - .emit(); - return Err(SemanticAnalysisError::Invalid); - } else { - self.call_stack.push(callee); - } - - // We clone the function here as we will be modifying the body during the - // inlining process, and we must not modify the original - let mut function = self.functions.get(&callee).unwrap().clone(); - - // This will be the initial set of bindings visible within the function body - // - // This is distinct from `self.bindings` at this point, because the function doesn't - // inherit the caller's scope, it has an entirely new one. - let mut function_bindings = LexicalScope::default(); - - // Add all referenced (and thus imported) items from the function module - // - // NOTE: This will include constants, periodic columns, and other functions - for (qid, binding_ty) in self.imported.iter() { - if qid.module == callee.module { - function_bindings.insert(*qid.as_ref(), binding_ty.clone()); - } - } - - // Add trace columns, and other root declarations to the set of - // bindings visible in the function body, _if_ the function is defined in the - // root module. 
- let is_function_in_root = callee.module == self.root; - if is_function_in_root { - for segment in self.trace.iter() { - function_bindings.insert( - segment.name, - BindingType::TraceColumn(TraceBinding { - span: segment.name.span(), - segment: segment.id, - name: Some(segment.name), - offset: 0, - size: segment.size, - ty: Type::Vector(segment.size), - }), - ); - for binding in segment.bindings.iter().copied() { - function_bindings.insert( - binding.name.unwrap(), - BindingType::TraceColumn(TraceBinding { - span: segment.name.span(), - segment: segment.id, - name: binding.name, - offset: binding.offset, - size: binding.size, - ty: binding.ty, - }), - ); - } - } - - for input in self.public_inputs.values() { - function_bindings.insert( - input.name(), - BindingType::PublicInput(Type::Vector(input.size())), - ); - } - } - - // Match call arguments to function parameters, populating the set of rewrites - // which should be performed on the inlined function body. - // - // NOTE: We create a new nested scope for the parameters in order to avoid conflicting - // with the root declarations - function_bindings.enter(); - self.populate_function_rewrites( - &mut function_bindings, - call.args.as_slice(), - function.params.as_slice(), - ); - - // While we're inlining the body, use the set of function bindings we built above - let prev_bindings = core::mem::replace(&mut self.bindings, function_bindings); - - // Expand the function body into a block of statements - self.expand_statement_block(&mut function.body)?; - - // Restore the caller's bindings before we leave - self.bindings = prev_bindings; - - // We're done expanding this call, so remove it from the call stack - self.call_stack.pop(); - - match function.body.pop().unwrap() { - Statement::Expr(expr) => Ok(expr), - Statement::Let(expr) => Ok(Expr::Let(Box::new(expr))), - Statement::Enforce(_) - | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) - | Statement::BusEnforce(_) => { - panic!("unexpected constraint in 
function body") - } - } - } - - /// Populate the set of access rewrites, as well as the initial set of bindings to use when inlining an evaluator function. - /// - /// This is done by resolving the arguments provided by the call to the evaluator, with the parameter list of the evaluator itself. - fn populate_evaluator_rewrites( - &mut self, - eval_bindings: &mut LexicalScope, - args: &[Expr], - params: &[TraceSegment], - ) { - // Reset the rewrites set - self.rewrites.clear(); - - // Each argument corresponds to a function parameter, each of which represents a single trace segment - for (arg, segment) in args.iter().zip(params.iter()) { - match arg { - // A variable was passed as an argument for this segment - // - // Arguments by now must have been validated by semantic analysis, and specifically - // in this case, the number of columns in the variable and the number expected by the - // parameter we're binding must be the same. However, a variable may represent a single - // column, a contiguous slice of columns, or a vector of such variables which may be - // non-contiguous. - Expr::SymbolAccess(access) => { - // We use a `BindingType` to track the state of the current input binding being processed. - // - // The initial state is given by the binding type of the access itself, but as we destructure - // the binding according to the parameter binding pattern, we may pop off columns, in which - // case the binding type here gets updated with the remaining columns - let mut binding_ty = Some(self.access_binding_type(access).unwrap()); - // We visit each binding in the trace segment represented by the parameter pattern, - // consuming columns from the input argument until all bindings are matched up. - for binding in segment.bindings.iter() { - // Trace binding declarations are never anonymous, i.e. 
always have a name - let binding_name = binding.name.unwrap(); - // We can safely assume that there is a binding type available here, - // otherwise the semantic analysis pass missed something - let bt = binding_ty.take().unwrap(); - // Split out the needed columns from the input binding - // - // We can safely assume we were able to obtain all of the needed columns, - // as the semantic analyzer should have caught mismatches. Note, however, - // that these columns may have been gathered from multiple bindings in the caller - let (matched, rest) = bt.split_columns(binding.size).unwrap(); - self.rewrites.insert(binding_name); - eval_bindings.insert(binding_name, matched); - // Update `binding_ty` with whatever remains of the input - binding_ty = rest; - } - } - // An empty vector means there are no bindings for this segment - Expr::Const(Span { - item: ConstantExpr::Vector(items), - .. - }) if items.is_empty() => { - continue; - } - // A vector of bindings was passed as an argument for this segment - // - // This is by far the most complicated scenario to handle when matching up arguments - // to parameters, as we can get them in a variety of combinations: - // - // 1. An exact match in the number and size of bindings in both the input vector and the - // segment represented by the current parameter - // 2. The same number of elements in the vector as bindings in the segment, but the elements - // have different sizes, implicitly regrouping columns between caller/callee - // 3. More elements in the vector than bindings in the segment, typically because the function - // parameter groups together columns passed individually in the caller - // 4. 
Fewer elements in the vector than bindings in the segment, typically because the function - // parameter destructures an input into multiple bindings - Expr::Vector(inputs) => { - // The index of the input we're currently extracting columns from - let mut index = 0; - // A `BindingType` representing the current trace binding we're extracting columns from, - // can be either of TraceColumn or Vector type - let mut binding_ty = None; - // We drive the matching process by consuming input columns for each segment binding in turn - 'next_binding: for binding in segment.bindings.iter() { - let binding_name = binding.name.unwrap(); - let mut needed = binding.size; - - // When there are insufficient columns for the current parameter binding in the current - // input, we must construct a vector of trace bindings to use as the binding type of - // the current parameter binding when we have all of the needed columns. This is because - // the input columns may come from different trace bindings in the caller, so we can't - // use a single trace binding to represent them. - let mut set = vec![]; - - // We may need to consume multiple input elements to fulfill the needed columns of - // the current parameter binding - we advance this loop whenever we have exhausted - // an input and need to move on to the next one. We may enter this loop with the - // same input index across multiple parameter bindings when the input element is - // larger than the parameter binding, in which case we have split the input and - // stored the remainder in `binding_ty`. - loop { - let input = &inputs[index]; - // The input expression must have been a symbol access, as matrices of columns - // aren't a thing, and there is no other expression type which can produce trace - // bindings. 
- let Expr::SymbolAccess(access) = input else { - panic!("unexpected element in trace column vector: {input:#?}") - }; - // Unless we have leftover input, initialize `binding_ty` with the binding type of this input - let bt = binding_ty - .take() - .unwrap_or_else(|| self.access_binding_type(access).unwrap()); - match bt.split_columns(needed) { - Ok((matched, rest)) => { - let eval_binding = match matched { - BindingType::TraceColumn(matched) => { - if !set.is_empty() { - // We've obtained all the remaining columns from the current input element, - // possibly with leftovers in the input. However, because we've started - // constructing a vector binding, we must ensure the matched binding is - // expanded into individual columns - for offset in 0..matched.size { - set.push(BindingType::TraceColumn( - TraceBinding { - offset: matched.offset + offset, - size: 1, - ..matched - }, - )); - } - BindingType::Vector(set) - } else { - // The input element perfectly matched the current binding - BindingType::TraceColumn(matched) - } - } - BindingType::Vector(mut matched) => { - if set.is_empty() { - // The input binding was a vector, and had the same number, or - // more, of columns expected by the parameter binding, but may contain - // non-contiguous bindings, so we are unable to use the symbol of - // the access when rewriting accesses to this parameter - BindingType::Vector(matched) - } else { - // Same as above, but we need to append the matched bindings to - // the set we've already started building - set.append(&mut matched); - BindingType::Vector(set) - } - } - _ => unreachable!(), - }; - // This binding has been fulfilled, move to the next one - self.rewrites.insert(binding_name); - eval_bindings.insert(binding_name, eval_binding); - binding_ty = rest; - // If we have no more columns remaining in this input, advance - // to the next input starting with the next binding - if binding_ty.is_none() { - index += 1; - } - continue 'next_binding; - } - 
Err(BindingType::TraceColumn(partial)) => { - // The input binding wasn't big enough for the parameter, so we must - // start constructing a vector of bindings since the next input is - // unlikely to be contiguous with the current input - for offset in 0..partial.size { - set.push(BindingType::TraceColumn(TraceBinding { - offset: partial.offset + offset, - size: 1, - ..partial - })); - } - needed -= partial.size; - index += 1; - } - Err(BindingType::Vector(mut partial)) => { - // Same as above, but we got a vector instead - set.append(&mut partial); - needed -= partial.len(); - index += 1; - } - Err(_) => unreachable!(), - } - } - } - } - // This should not be possible at this point, but would be an invalid evaluator call, - // only trace columns are permitted - expr => unreachable!("{:#?}", expr), - } - } - } - - fn populate_function_rewrites( - &mut self, - function_bindings: &mut LexicalScope, - args: &[Expr], - params: &[(Identifier, Type)], - ) { - // Reset the rewrites set - self.rewrites.clear(); - - for (arg, (param_name, param_ty)) in args.iter().zip(params.iter()) { - // We can safely assume that there is a binding type available here, - // otherwise the semantic analysis pass missed something - let binding_ty = self.expr_binding_type(arg).unwrap(); - debug_assert_eq!(binding_ty.ty(), Some(*param_ty), "unexpected type mismatch"); - self.rewrites.insert(*param_name); - function_bindings.insert(*param_name, binding_ty); - } - } - - /// Returns a new [SymbolAccess] which should be used in place of `access` in the current scope. - /// - /// This function should only be called on accesses which have a trace column/param [BindingType], - /// but it will simply return `None` for other types, so it is safe to call on all accesses. 
- fn get_trace_access_rewrite(&self, access: &SymbolAccess) -> Option { - if self.rewrites.contains(access.name.as_ref()) { - // If we have a rewrite for this access, then the bindings map will - // have an accurate trace binding for us; rewrite this access to be - // relative to that trace binding - match self.access_binding_type(access).unwrap() { - BindingType::TraceColumn(tb) => { - let original_binding = self.trace[tb.segment] - .bindings - .iter() - .find(|b| b.name == tb.name) - .unwrap(); - let (access_type, ty) = if original_binding.size == 1 { - (AccessType::Default, Type::Felt) - } else if tb.size == 1 { - ( - AccessType::Index(tb.offset - original_binding.offset), - Type::Felt, - ) - } else { - let start = tb.offset - original_binding.offset; - ( - AccessType::Slice(RangeExpr::from(start..(start + tb.size))), - Type::Vector(tb.size), - ) - }; - Some(SymbolAccess { - span: access.span(), - name: ResolvableIdentifier::Local(tb.name.unwrap()), - access_type, - offset: access.offset, - ty: Some(ty), - }) - } - // We only have a rewrite when the binding type is TraceColumn - invalid => panic!( - "unexpected trace access binding type, expected column(s), got: {:#?}", - &invalid - ), - } - } else { - None - } - } - - fn expr_binding_type(&self, expr: &Expr) -> Result { - let mut bindings = self.bindings.clone(); - eval_expr_binding_type(expr, &mut bindings, &self.imported) - } - - /// Returns the effective [BindingType] of the value produced by the given access - fn access_binding_type(&self, expr: &SymbolAccess) -> Result { - eval_access_binding_type(expr, &self.bindings, &self.imported) - } -} - -/// Returns the effective [BindingType] of the given expression -fn eval_expr_binding_type( - expr: &Expr, - bindings: &mut LexicalScope, - imported: &HashMap, -) -> Result { - match expr { - Expr::Const(constant) => Ok(BindingType::Local(constant.ty())), - Expr::Range(range) => Ok(BindingType::Local(Type::Vector( - range.to_slice_range().len(), - ))), - 
Expr::Vector(elems) => match elems[0].ty() { - None | Some(Type::Felt) => { - let mut binding_tys = Vec::with_capacity(elems.len()); - for elem in elems.iter() { - binding_tys.push(eval_expr_binding_type(elem, bindings, imported)?); - } - Ok(BindingType::Vector(binding_tys)) - } - Some(Type::Vector(cols)) => { - let rows = elems.len(); - Ok(BindingType::Local(Type::Matrix(rows, cols))) - } - Some(_) => unreachable!(), - }, - Expr::Matrix(expr) => { - let rows = expr.len(); - let columns = expr[0].len(); - Ok(BindingType::Local(Type::Matrix(rows, columns))) - } - Expr::SymbolAccess(access) => eval_access_binding_type(access, bindings, imported), - Expr::Call(Call { ty: None, .. }) => Err(InvalidAccessError::InvalidBinding), - Expr::Call(Call { ty: Some(ty), .. }) => Ok(BindingType::Local(*ty)), - Expr::Binary(_) => Ok(BindingType::Local(Type::Felt)), - Expr::ListComprehension(lc) => { - // The types of all iterables must be the same, so the type of - // the comprehension is given by the type of the iterables. 
We - // just pick the first iterable to tell us the type - eval_expr_binding_type(&lc.iterables[0], bindings, imported) - } - Expr::Let(let_expr) => eval_let_binding_ty(let_expr, bindings, imported), - Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => { - unimplemented!("buses are not implemented for this Pipeline") - } - } -} - -/// Returns the effective [BindingType] of the value produced by the given access -fn eval_access_binding_type( - expr: &SymbolAccess, - bindings: &LexicalScope, - imported: &HashMap, -) -> Result { - let binding_ty = bindings - .get(expr.name.as_ref()) - .or_else(|| match expr.name { - ResolvableIdentifier::Resolved(qid) => imported.get(&qid), - _ => None, - }) - .ok_or(InvalidAccessError::UndefinedVariable) - .clone()?; - binding_ty.access(expr.access_type.clone()) -} - -fn eval_let_binding_ty( - let_expr: &Let, - bindings: &mut LexicalScope, - imported: &HashMap, -) -> Result { - let variable_ty = eval_expr_binding_type(&let_expr.value, bindings, imported)?; - bindings.enter(); - bindings.insert(let_expr.name, variable_ty); - let binding_ty = match let_expr.body.last().unwrap() { - Statement::Let(inner_let) => eval_let_binding_ty(inner_let, bindings, imported)?, - Statement::Expr(expr) => eval_expr_binding_type(expr, bindings, imported)?, - Statement::Enforce(_) - | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) - | Statement::BusEnforce(_) => { - unreachable!() - } - }; - bindings.exit(); - Ok(binding_ty) -} - -/// This visitor is used to rewrite uses of iterable bindings within a comprehension body, -/// including expansion of constant accesses. -struct RewriteIterableBindingsVisitor<'a> { - /// This map contains the set of symbols to be rewritten, and the abstract values which - /// should replace them in the comprehension body. 
- values: &'a HashMap, -} -impl RewriteIterableBindingsVisitor<'_> { - fn rewrite_scalar_access( - &mut self, - access: SymbolAccess, - ) -> ControlFlow> { - let result = match self.values.get(access.name.as_ref()) { - Some(Expr::Const(constant)) => { - let span = constant.span(); - match constant.item { - ConstantExpr::Scalar(value) => { - assert_eq!(access.access_type, AccessType::Default); - Some(ScalarExpr::Const(Span::new(span, value))) - } - ConstantExpr::Vector(ref elems) => match access.access_type { - AccessType::Index(idx) => { - Some(ScalarExpr::Const(Span::new(span, elems[idx]))) - } - invalid => panic!( - "expected vector to be reduced to scalar by access, got {invalid:#?}" - ), - }, - ConstantExpr::Matrix(ref rows) => match access.access_type { - AccessType::Matrix(row, col) => { - Some(ScalarExpr::Const(Span::new(span, rows[row][col]))) - } - invalid => panic!( - "expected matrix to be reduced to scalar by access, got {invalid:#?}", - ), - }, - } - } - Some(Expr::Range(range)) => { - let span = range.span(); - let range = range.to_slice_range(); - match access.access_type { - AccessType::Index(idx) => Some(ScalarExpr::Const(Span::new( - span, - (range.start + idx) as u64, - ))), - invalid => { - panic!("expected range to be reduced to scalar by access, got {invalid:#?}",) - } - } - } - Some(Expr::Vector(elems)) => { - match access.access_type { - AccessType::Index(idx) => Some(elems[idx].clone().try_into().unwrap()), - // This implies that the vector contains an element which is vector-like, - // if the value at `idx` is not, this is an invalid access - AccessType::Matrix(idx, nested_idx) => match &elems[idx] { - Expr::SymbolAccess(saccess) => { - let access = saccess.access(AccessType::Index(nested_idx)).unwrap(); - self.rewrite_scalar_access(access)? 
- } - invalid => panic!( - "expected vector-like value at {}[{idx}], got: {invalid:#?}", - access.name.as_ref(), - ), - }, - invalid => panic!( - "expected vector to be reduced to scalar by access, got {invalid:#?}" - ), - } - } - Some(Expr::Matrix(elems)) => match access.access_type { - AccessType::Matrix(row, col) => Some(elems[row][col].clone()), - invalid => { - panic!("expected matrix to be reduced to scalar by access, got {invalid:#?}") - } - }, - Some(Expr::SymbolAccess(symbol_access)) => { - let mut new_access = symbol_access.access(access.access_type).unwrap(); - new_access.offset = access.offset; - Some(ScalarExpr::SymbolAccess(new_access)) - } - // These types of expressions will never be observed in this context, as they are - // not valid iterable expressions (except calls, but those are lifted prior to rewrite - // so that their use in this context is always a symbol access). - Some( - Expr::Call(_) - | Expr::Binary(_) - | Expr::ListComprehension(_) - | Expr::Let(_) - | Expr::BusOperation(_) - | Expr::Null(_) - | Expr::Unconstrained(_), - ) => { - unreachable!() - } - None => None, - }; - ControlFlow::Continue(result) - } -} -impl VisitMut for RewriteIterableBindingsVisitor<'_> { - fn visit_mut_scalar_expr( - &mut self, - expr: &mut ScalarExpr, - ) -> ControlFlow { - match expr { - // Nothing to do with constants - ScalarExpr::Const(_) => ControlFlow::Continue(()), - // If we observe an access, try to rewrite it as an iterable binding, if it is - // not a candidate for rewrite, leave it alone. - // - // NOTE: We handle BoundedSymbolAccess here even though comprehension constraints are not - // permitted in boundary_constraints currently. That is handled elsewhere, we just need to - // make sure the symbols themselves are rewritten properly here. - ScalarExpr::SymbolAccess(access) - | ScalarExpr::BoundedSymbolAccess(BoundedSymbolAccess { column: access, .. }) => { - if let Some(replacement) = self.rewrite_scalar_access(access.clone())? 
{ - *expr = replacement; - return ControlFlow::Continue(()); - } - ControlFlow::Continue(()) - } - // We need to visit both operands of a binary expression - but while we're here, - // check to see if resolving the operands reduces to a constant expression that - // can be folded. - ScalarExpr::Binary(binary_expr) => { - self.visit_mut_binary_expr(binary_expr)?; - match constant_propagation::try_fold_binary_expr(binary_expr) { - Ok(Some(folded)) => { - *expr = ScalarExpr::Const(folded); - ControlFlow::Continue(()) - } - Ok(None) => ControlFlow::Continue(()), - Err(err) => ControlFlow::Break(SemanticAnalysisError::InvalidExpr(err)), - } - } - // If we observe a call here, just rewrite the arguments, inlining happens elsewhere - ScalarExpr::Call(call) => { - for arg in call.args.iter_mut() { - self.visit_mut_expr(arg)?; - } - ControlFlow::Continue(()) - } - // We rewrite comprehension bodies before they are expanded, so it should never be - // the case that we encounter a let here, as they can only be introduced in scalar - // expression position as a result of inlining/expansion - ScalarExpr::Let(_) => unreachable!(), - ScalarExpr::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => { - ControlFlow::Break(SemanticAnalysisError::Invalid) - } - } - } -} - -/// This visitor is used to apply a selector expression to all constraints in a block -/// -/// For constraints which already have a selector, this rewrites those selectors to be the -/// logical AND of the original selector and the selector being applied. 
-struct ApplyConstraintSelector<'a> { - selector: &'a ScalarExpr, -} -impl VisitMut for ApplyConstraintSelector<'_> { - fn visit_mut_statement( - &mut self, - statement: &mut Statement, - ) -> ControlFlow { - match statement { - Statement::Let(expr) => self.visit_mut_let(expr), - Statement::Enforce(expr) => { - let expr = - core::mem::replace(expr, ScalarExpr::Const(Span::new(SourceSpan::UNKNOWN, 0))); - *statement = Statement::EnforceIf(expr, self.selector.clone()); - ControlFlow::Continue(()) - } - Statement::EnforceIf(_, selector) => { - // Combine the selectors - let lhs = core::mem::replace( - selector, - ScalarExpr::Const(Span::new(SourceSpan::UNKNOWN, 0)), - ); - let rhs = self.selector.clone(); - *selector = ScalarExpr::Binary(BinaryExpr::new( - self.selector.span(), - BinaryOp::Mul, - lhs, - rhs, - )); - ControlFlow::Continue(()) - } - Statement::EnforceAll(_) => unreachable!(), - Statement::Expr(_) => ControlFlow::Continue(()), - Statement::BusEnforce(_) => ControlFlow::Break(SemanticAnalysisError::Invalid), - } - } -} - -/// This helper function is used to perform a mutation/replacement based on the expression -/// representing the effective value of a `let`-tree. -/// -/// In particular, this function traverses the tree until it reaches the final `let` body -/// and the last `Expr` in that body. When it does, it invokes `callback` with a mutable -/// reference to that `Expr`. The callback may choose to simply mutate the `Expr`, or it -/// may return a new `Statement` which will be used to replace the `Statement` which -/// contained the `Expr` given to the callback. -/// -/// This is used when expanding calls and list comprehensions, where the expanded form -/// of these is potentially a `let` tree, and we desire to place additional statements -/// in the bottom-most block, or perform some transformation on the expression which acts -/// as the result of the tree. 
-fn with_let_result( - inliner: &mut Inlining, - entry: &mut Vec, - callback: F, -) -> Result<(), SemanticAnalysisError> -where - F: FnOnce(&mut Inlining, &mut Expr) -> Result, SemanticAnalysisError>, -{ - // Preserve the original lexical scope to be restored on exit - let prev = inliner.bindings.clone(); - - // SAFETY: We must use a raw pointer here because the Rust compiler is not able to - // see that we only ever use the mutable reference once, and that the reference - // is never aliased. - // - // Both of these guarantees are in fact upheld here however, as each iteration of the loop - // is either the last iteration (when we use the mutable reference to mutate the end of the - // bottom-most block), or a traversal to the last child of the current let expression. - // We never alias the mutable reference, and in fact immediately convert back to a mutable - // reference inside the loop to ensure that within the loop body we have some degree of - // compiler-assisted checking of that invariant. - let mut current_block = Some(entry as *mut Vec); - while let Some(parent_block) = current_block.take() { - // SAFETY: We convert the pointer back to a mutable reference here before - // we do anything else to ensure the usual aliasing rules are enforced. - // - // It is further guaranteed that this reference is never improperly aliased - // across iterations, as each iteration is visiting a child of the previous - // iteration's node, i.e. what we're doing here is equivalent to holding a - // mutable reference and using it to mutate a field in a deeply nested struct. - let parent_block = unsafe { &mut *parent_block }; - // A block is guaranteed to always have at least one statement here - match parent_block.last_mut().unwrap() { - // When we hit a block whose last statement is an expression, which - // must also be the bottom-most block of this tree. This expression - // is the effective value of the `let` tree. 
We will replace this - // node if the callback we were given returns a new `Statement`. In - // either case, we're done once we've handled the callback result. - Statement::Expr(value) => match callback(inliner, value) { - Ok(Some(replacement)) => { - parent_block.pop(); - parent_block.push(replacement); - break; - } - Ok(None) => break, - Err(err) => { - inliner.bindings = prev; - return Err(err); - } - }, - // We've traversed down a level in the let-tree, but there are more to go. - // Set up the next iteration to visit the next block down in the tree. - Statement::Let(let_expr) => { - // Register this binding - let binding_ty = inliner.expr_binding_type(&let_expr.value).unwrap(); - inliner.bindings.insert(let_expr.name, binding_ty); - // Set up the next iteration - current_block = Some(&mut let_expr.body as *mut Vec); - continue; - } - // No other statements types are possible here - _ => unreachable!(), - } - } - - // Restore the original lexical scope - inliner.bindings = prev; - - Ok(()) -} diff --git a/parser/src/transforms/mod.rs b/parser/src/transforms/mod.rs index 6d237b9a1..53d9b90d1 100644 --- a/parser/src/transforms/mod.rs +++ b/parser/src/transforms/mod.rs @@ -1,5 +1,3 @@ mod constant_propagation; -mod inlining; pub use self::constant_propagation::ConstantPropagation; -pub use self::inlining::Inlining; diff --git a/pass/Cargo.toml b/pass/Cargo.toml index ce81446e2..f8cc41331 100644 --- a/pass/Cargo.toml +++ b/pass/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "air-pass" -version = "0.4.0" +version = "0.5.0" description = "Reusable compiler pass infrastructure for the AirScript compiler" authors.workspace = true license.workspace = true diff --git a/pass/src/lib.rs b/pass/src/lib.rs index 2cdab39ed..eaf2d7601 100644 --- a/pass/src/lib.rs +++ b/pass/src/lib.rs @@ -2,11 +2,14 @@ /// This trait represents anything that can be run as a pass. 
/// -/// Passes operate on an input value, and return either the same type, or a new type, depending on the nature of the pass. +/// Passes operate on an input value, and return either the same type, or a new type, depending on +/// the nature of the pass. /// -/// Implementations may represent a single pass, or an arbitrary number of passes that will be run as a single unit. +/// Implementations may represent a single pass, or an arbitrary number of passes that will be run +/// as a single unit. /// -/// Functions are valid implementations of `Pass` as long as their signature is `fn(I) -> Result`. +/// Functions are valid implementations of `Pass` as long as their signature is `fn(I) -> +/// Result`. pub trait Pass { type Input<'a>; type Output<'a>; diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 000000000..eef0263e4 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,5 @@ +[toolchain] +channel = "1.90" +components = ["rustfmt", "rust-src", "clippy"] +targets = ["wasm32-unknown-unknown"] +profile = "minimal" diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 000000000..248799dde --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,18 @@ +array_width = 80 +attr_fn_like_width = 80 +chain_width = 80 +comment_width = 100 +condense_wildcard_suffixes = true +edition = "2024" +fn_call_width = 80 +group_imports = "StdExternalCrate" +imports_granularity = "Crate" +match_block_trailing_comma = true +newline_style = "Unix" +single_line_if_else_max_width = 60 +single_line_let_else_max_width = 60 +struct_lit_width = 40 +struct_variant_width = 40 +use_field_init_shorthand = true +use_try_shorthand = true +wrap_comments = true diff --git a/scripts/check-changelog.sh b/scripts/check-changelog.sh new file mode 100755 index 000000000..f7207ae3e --- /dev/null +++ b/scripts/check-changelog.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -uo pipefail + +CHANGELOG_FILE="CHANGELOG.md" + +if [ "${NO_CHANGELOG_LABEL}" = "true" ]; then + # 'no changelog' set, 
so finish successfully + echo "\"no changelog\" label has been set" + exit 0 +else + # a changelog check is required + # fail if the diff is empty + if git diff --exit-code "origin/${BASE_REF}" -- "${CHANGELOG_FILE}"; then + >&2 echo "Changes should come with an entry in the \"CHANGELOG.md\" file. This behavior +can be overridden by using the \"no changelog\" label, which is used for changes +that are trivial / explicitly stated not to require a changelog entry." + exit 1 + fi + + echo "The ${CHANGELOG_FILE} file has been updated." +fi diff --git a/scripts/check-msrv.sh b/scripts/check-msrv.sh new file mode 100755 index 000000000..0bde2955f --- /dev/null +++ b/scripts/check-msrv.sh @@ -0,0 +1,153 @@ +#!/bin/bash +set -e +set -o pipefail + +# Enhanced MSRV checking script for workspace repository +# Checks MSRV for each workspace member and provides helpful error messages + +# ---- utilities -------------------------------------------------------------- + +check_command() { + if ! command -v "$1" >/dev/null 2>&1; then + echo "ERROR: Required command '$1' is not installed or not in PATH" + exit 1 + fi +} + +# Check required commands +check_command "cargo" +check_command "jq" +check_command "rustup" +check_command "sed" +check_command "grep" +check_command "awk" + +# Portable in-place sed (GNU/macOS); usage: sed_i 's/foo/bar/' file +# shellcheck disable=SC2329 # used quoted +sed_i() { + if sed --version >/dev/null 2>&1; then + sed -i "$@" + else + sed -i '' "$@" + fi +} + +# ---- repo root -------------------------------------------------------------- + +# Get the directory where this script is located and change to the parent directory +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$DIR/.." + +echo "Checking MSRV for workspace members..." 
+ +# ---- metadata -------------------------------------------------------------- + +metadata_json="$(cargo metadata --no-deps --format-version 1)" +workspace_root="$(printf '%s' "$metadata_json" | jq -r '.workspace_root')" + +failed_packages="" + +# Iterate actual workspace packages with manifest paths and (maybe) rust_version +# Fields per line (TSV): id name manifest_path rust_version_or_empty +while IFS=$'\t' read -r pkg_id package_name manifest_path rust_version; do + # Derive package directory (avoid external dirname for portability) + package_dir="${manifest_path%/*}" + if [[ -z "$package_dir" || "$package_dir" == "$manifest_path" ]]; then + package_dir="." + fi + + echo "Checking $package_name ($pkg_id) in $package_dir" + + if [[ ! -f "$package_dir/Cargo.toml" ]]; then + echo "WARNING: No Cargo.toml found in $package_dir, skipping..." + continue + fi + + # Prefer cargo metadata's effective rust_version if present + current_msrv="$rust_version" + if [[ -z "$current_msrv" ]]; then + # If the crate inherits: rust-version.workspace = true + if grep -Eq '^\s*rust-version\.workspace\s*=\s*true\b' "$package_dir/Cargo.toml"; then + # Read from workspace root [workspace.package] + current_msrv="$(grep -Eo '^\s*rust-version\s*=\s*"[^"]+"' "$workspace_root/Cargo.toml" | head -n1 | sed -E 's/.*"([^"]+)".*/\1/')" + if [[ -n "$current_msrv" ]]; then + echo " Using workspace MSRV: $current_msrv" + fi + fi + fi + + if [[ -z "$current_msrv" ]]; then + echo "WARNING: No rust-version found (package or workspace) for $package_name" + continue + fi + + echo " Current MSRV: $current_msrv" + + # Try to verify the MSRV + if ! cargo msrv verify --manifest-path "$package_dir/Cargo.toml" >/dev/null 2>&1; then + echo "ERROR: MSRV check failed for $package_name" + failed_packages="$failed_packages $package_name" + + echo "Searching for correct MSRV for $package_name..." 
+ + # Determine the currently-installed stable toolchain version (e.g., "1.81.0") + latest_stable="$(rustup run stable rustc --version 2>/dev/null | awk '{print $2}')" + if [[ -z "$latest_stable" ]]; then latest_stable="1.81.0"; fi + + # Search for the actual MSRV starting from the current one + if actual_msrv=$(cargo msrv find \ + --manifest-path "$package_dir/Cargo.toml" \ + --min "$current_msrv" \ + --max "$latest_stable" \ + --output-format minimal 2>/dev/null); then + echo " Found actual MSRV: $actual_msrv" + echo "" + echo "ERROR SUMMARY for $package_name:" + echo " Package: $package_name" + echo " Directory: $package_dir" + echo " Current (incorrect) MSRV: $current_msrv" + echo " Correct MSRV: $actual_msrv" + echo "" + echo "TO FIX:" + echo " Update rust-version in $package_dir/Cargo.toml from \"$current_msrv\" to \"$actual_msrv\"" + echo "" + echo " Or run this command (portable in-place edit):" + echo " sed_i 's/^\\s*rust-version\\s*=\\s*\"$current_msrv\"/rust-version = \"$actual_msrv\"/' \"$package_dir/Cargo.toml\"" + else + echo " Could not determine correct MSRV automatically" + echo "" + echo "ERROR SUMMARY for $package_name:" + echo " Package: $package_name" + echo " Directory: $package_dir" + echo " Current (incorrect) MSRV: $current_msrv" + echo " Could not automatically determine correct MSRV" + echo "" + echo "TO FIX:" + echo " Run manually: cargo msrv find --manifest-path \"$package_dir/Cargo.toml\"" + fi + echo "-------------------------------------------------------------------------------" + else + echo "OK: MSRV check passed for $package_name" + fi + echo "" + +done < <( + printf '%s' "$metadata_json" \ + | jq -r '. as $m + | $m.workspace_members[] + | . 
as $id + | ($m.packages[] | select(.id == $id) + | [ .id, .name, .manifest_path, (.rust_version // "") ] | @tsv)' +) + +if [[ -n "$failed_packages" ]]; then + echo "MSRV CHECK FAILED" + echo "" + echo "The following packages have incorrect MSRV settings:$failed_packages" + echo "" + echo "Please fix the rust-version fields in the affected Cargo.toml files as shown above." + exit 1 +else + echo "ALL WORKSPACE MEMBERS PASSED MSRV CHECKS!" + exit 0 +fi \ No newline at end of file