Skip to content

Commit

Permalink
feat: generalize conv mem layout and ND (#935)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Feb 10, 2025
1 parent c19fa52 commit a7544f4
Show file tree
Hide file tree
Showing 17 changed files with 672 additions and 178 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ jobs:
locked: true
# - name: The Worm Mock
# run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: Large 1D Conv Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_7_expects -- --include-ignored
- name: MNIST Gan Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
- name: NanoGPT Mock
Expand Down
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions benches/accum_conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ impl Circuit<Fr> for MyCircuit {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
}),
)
.unwrap();
Expand Down
3 changes: 3 additions & 0 deletions examples/conv2d_mnist/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;


mod params;

const K: usize = 20;
Expand Down Expand Up @@ -208,6 +209,8 @@ where
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
};
let x = config
.layer_config
Expand Down
106 changes: 106 additions & 0 deletions examples/onnx/1d_conv/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
{
"input_data": [
[
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127,
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127
]
]
}
Binary file added examples/onnx/1d_conv/network.onnx
Binary file not shown.
19 changes: 13 additions & 6 deletions src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -57,11 +57,13 @@ pub enum HybridOp {
stride: Vec<usize>,
kernel_shape: Vec<usize>,
normalized: bool,
data_format: DataFormat,
},
MaxPool {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
pool_dims: Vec<usize>,
data_format: DataFormat,
},
ReduceMin {
axes: Vec<usize>,
Expand Down Expand Up @@ -154,20 +156,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
kernel_shape,
normalized,
normalized, data_format
} => format!(
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={})",
padding, stride, kernel_shape, normalized
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
padding, stride, kernel_shape, normalized, data_format
),
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
HybridOp::MaxPool {
padding,
stride,
pool_dims,
data_format,
} => format!(
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
padding, stride, pool_dims
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?}, data_format={:?})",
padding, stride, pool_dims, data_format
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
Expand Down Expand Up @@ -239,6 +242,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
normalized,
data_format,
} => layouts::sumpool(
config,
region,
Expand All @@ -247,6 +251,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
*normalized,
*data_format,
)?,
HybridOp::Recip {
input_scale,
Expand Down Expand Up @@ -287,13 +292,15 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
data_format,
} => layouts::max_pool(
config,
region,
values[..].try_into()?,
padding,
stride,
pool_dims,
*data_format,
)?,
HybridOp::ReduceMax { axes } => {
layouts::max_axes(config, region, values[..].try_into()?, axes)?
Expand Down
Loading

0 comments on commit a7544f4

Please sign in to comment.