Skip to content

Commit c533311

Browse files
committed
use specialized indices for sum tags
1 parent 2a3a01a commit c533311

File tree

16 files changed

+496
-69
lines changed

16 files changed

+496
-69
lines changed

crates/bindings-macro/src/sats.rs

+18-6
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ pub(crate) fn derive_satstype(ty: &SatsType<'_>) -> TokenStream {
143143
let name = &ty.ident;
144144
let krate = &ty.krate;
145145

146-
let mut add_filterable_value = false;
146+
let mut add_impls_for_plain_enum = false;
147147
let typ = match &ty.data {
148148
SatsTypeData::Product(fields) => {
149149
let fields = fields.iter().map(|field| {
@@ -169,7 +169,7 @@ pub(crate) fn derive_satstype(ty: &SatsType<'_>) -> TokenStream {
169169
SatsTypeData::Sum(variants) => {
170170
// To allow an enum, with all-unit variants, as an index key type,
171171
// implement `FilterableValue` (for a reference to the enum) and `DirectIndexKey` for it.
172-
add_filterable_value = variants.iter().all(|var| var.ty.is_none());
172+
add_impls_for_plain_enum = variants.iter().all(|var| var.ty.is_none());
173173

174174
let unit = syn::Type::Tuple(syn::TypeTuple {
175175
paren_token: Default::default(),
@@ -209,28 +209,40 @@ pub(crate) fn derive_satstype(ty: &SatsType<'_>) -> TokenStream {
209209
}
210210
let (_, typeid_ty_generics, _) = typeid_generics.split_for_impl();
211211

212-
let impl_filterable_value = if add_filterable_value {
212+
let impl_plain_enum_extras = if add_impls_for_plain_enum {
213213
// These will mostly be empty as lifetime and type parameters must be constrained
214214
// but const parameters don't require constraining.
215215
let mut generics = ty.generics.clone();
216216
add_type_bounds(&mut generics, &quote!(#krate::FilterableValue));
217217
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
218-
219218
// As we don't have access to other `derive` attributes,
220219
// we don't know if `Copy` was derived,
221220
// so we won't impl for the owned type.
222-
Some(quote! {
221+
let filterable_impl = quote! {
223222
#[automatically_derived]
224223
impl #impl_generics #krate::FilterableValue for &#name #ty_generics #where_clause {
225224
type Column = #name #ty_generics;
226225
}
226+
};
227+
228+
let mut generics = ty.generics.clone();
229+
add_type_bounds(&mut generics, &quote!(#krate::DirectIndexKey));
230+
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
231+
let dik_impl = quote! {
232+
#[automatically_derived]
233+
impl #impl_generics #krate::DirectIndexKey for #name #ty_generics #where_clause {}
234+
};
235+
236+
Some(quote! {
237+
#filterable_impl
238+
#dik_impl
227239
})
228240
} else {
229241
None
230242
};
231243

232244
quote! {
233-
#impl_filterable_value
245+
#impl_plain_enum_extras
234246

235247
#[automatically_derived]
236248
impl #impl_generics #krate::SpacetimeType for #name #ty_generics #where_clause {

crates/bindings-macro/src/table.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ impl ValidatedIndex<'_> {
381381
let col_ty = col.ty;
382382
let typeck = quote_spanned!(col_ty.span()=>
383383
const _: () = {
384-
spacetimedb::rt::assert_column_type_valid_for_direct_index::<#col_ty>();
384+
spacetimedb::spacetimedb_lib::assert_column_type_valid_for_direct_index::<#col_ty>();
385385
};
386386
);
387387
(slice::from_ref(col), Some(typeck))

crates/bindings/src/rt.rs

-20
Original file line numberDiff line numberDiff line change
@@ -346,26 +346,6 @@ pub fn register_table<T: Table>() {
346346
})
347347
}
348348

349-
mod sealed_direct_index {
350-
pub trait Sealed {}
351-
}
352-
#[diagnostic::on_unimplemented(
353-
message = "column type must be a one of: `u8`, `u16`, `u32`, or `u64`",
354-
label = "should be `u8`, `u16`, `u32`, or `u64`, not `{Self}`"
355-
)]
356-
pub trait DirectIndexKey: sealed_direct_index::Sealed {}
357-
impl sealed_direct_index::Sealed for u8 {}
358-
impl DirectIndexKey for u8 {}
359-
impl sealed_direct_index::Sealed for u16 {}
360-
impl DirectIndexKey for u16 {}
361-
impl sealed_direct_index::Sealed for u32 {}
362-
impl DirectIndexKey for u32 {}
363-
impl sealed_direct_index::Sealed for u64 {}
364-
impl DirectIndexKey for u64 {}
365-
366-
/// Assert that `T` is a valid column to use direct index on.
367-
pub const fn assert_column_type_valid_for_direct_index<T: DirectIndexKey>() {}
368-
369349
impl From<IndexAlgo<'_>> for RawIndexAlgorithm {
370350
fn from(algo: IndexAlgo<'_>) -> RawIndexAlgorithm {
371351
match algo {

crates/lib/src/direct_index_key.rs

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#[diagnostic::on_unimplemented(
2+
message = "column type must be one of: `u8`, `u16`, `u32`, `u64`, or plain `enum`",
3+
label = "should be `u8`, `u16`, `u32`, `u64`, or plain `enum`, not `{Self}`"
4+
)]
5+
pub trait DirectIndexKey {}
6+
impl DirectIndexKey for u8 {}
7+
impl DirectIndexKey for u16 {}
8+
impl DirectIndexKey for u32 {}
9+
impl DirectIndexKey for u64 {}
10+
11+
/// Assert that `T` is a valid column to use direct index on.
12+
pub const fn assert_column_type_valid_for_direct_index<T: DirectIndexKey>() {}

crates/lib/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::collections::{btree_map, BTreeMap};
88

99
pub mod connection_id;
1010
pub mod db;
11+
mod direct_index_key;
1112
pub mod error;
1213
mod filterable_value;
1314
pub mod identity;
@@ -27,6 +28,7 @@ pub mod type_value {
2728
}
2829

2930
pub use connection_id::ConnectionId;
31+
pub use direct_index_key::{assert_column_type_valid_for_direct_index, DirectIndexKey};
3032
pub use filterable_value::{FilterableValue, IndexScanRangeBoundsTerminator, TermBound};
3133
pub use identity::Identity;
3234
pub use scheduler::ScheduleAt;

crates/sats/src/algebraic_value.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -178,15 +178,13 @@ impl AlgebraicValue {
178178

179179
/// Returns an [`AlgebraicValue`] representing a sum value with `tag` and `value`.
180180
pub fn sum(tag: u8, value: Self) -> Self {
181-
let value = Box::new(value);
182-
Self::Sum(SumValue { tag, value })
181+
Self::Sum(SumValue::new(tag, value))
183182
}
184183

185184
/// Returns an [`AlgebraicValue`] representing a sum value with `tag` and empty [AlgebraicValue::product], that is
186185
/// valid for simple enums without payload.
187186
pub fn enum_simple(tag: u8) -> Self {
188-
let value = Box::new(AlgebraicValue::product(vec![]));
189-
Self::Sum(SumValue { tag, value })
187+
Self::Sum(SumValue::new_simple(tag))
190188
}
191189

192190
/// Returns an [`AlgebraicValue`] representing a product value with the given `elements`.

crates/sats/src/convert.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::sum_value::SumTag;
12
use crate::{i256, u256};
23
use crate::{AlgebraicType, AlgebraicValue, ProductType, ProductValue};
34
use spacetimedb_primitives::{ColId, ConstraintId, IndexId, ScheduleId, SequenceId, TableId};
@@ -25,6 +26,12 @@ impl From<AlgebraicType> for ProductType {
2526
}
2627
}
2728

29+
impl From<()> for AlgebraicValue {
30+
fn from((): ()) -> Self {
31+
AlgebraicValue::unit()
32+
}
33+
}
34+
2835
macro_rules! built_in_into {
2936
($native:ty, $kind:ident) => {
3037
impl From<$native> for AlgebraicValue {
@@ -45,6 +52,7 @@ built_in_into!(&str, String);
4552
built_in_into!(String, String);
4653
built_in_into!(&[u8], Bytes);
4754
built_in_into!(Box<[u8]>, Bytes);
55+
built_in_into!(SumTag, Sum);
4856

4957
macro_rules! system_id {
5058
($name:ident) => {

crates/sats/src/sum_value.rs

+24
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,27 @@ pub struct SumValue {
1414
impl crate::Value for SumValue {
1515
type Type = SumType;
1616
}
17+
18+
impl SumValue {
19+
/// Returns a new `SumValue` with the given `tag` and `value`.
20+
pub fn new(tag: u8, value: impl Into<AlgebraicValue>) -> Self {
21+
let value = Box::from(value.into());
22+
Self { tag, value }
23+
}
24+
25+
/// Returns a new `SumValue` with the given `tag` and unit value.
26+
pub fn new_simple(tag: u8) -> Self {
27+
Self::new(tag, ())
28+
}
29+
}
30+
31+
/// The tag of a `SumValue`.
32+
/// Can be used to read out the tag of a sum value without reading the payload.
33+
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
34+
pub struct SumTag(pub u8);
35+
36+
impl From<SumTag> for SumValue {
37+
fn from(SumTag(tag): SumTag) -> Self {
38+
SumValue::new_simple(tag)
39+
}
40+
}

crates/schema/src/def/validate/v9.rs

+11-4
Original file line numberDiff line numberDiff line change
@@ -625,9 +625,16 @@ impl TableValidator<'_, '_> {
625625
let field = &self.product_type.elements[column.idx()];
626626
let ty = &field.algebraic_type;
627627
use AlgebraicType::*;
628-
if let U8 | U16 | U32 | U64 = ty {
629-
} else {
630-
return Err(ValidationError::DirectIndexOnNonUnsignedInt {
628+
let is_bad_type = match ty {
629+
U8 | U16 | U32 | U64 => false,
630+
Ref(r) => self.module_validator.typespace[*r]
631+
.as_sum()
632+
.is_none_or(|s| !s.is_simple_enum()),
633+
Sum(sum) if sum.is_simple_enum() => false,
634+
_ => true,
635+
};
636+
if is_bad_type {
637+
return Err(ValidationError::DirectIndexOnBadType {
631638
index: name.clone(),
632639
column: field.name.clone().unwrap_or_else(|| column.idx().to_string().into()),
633640
ty: ty.clone().into(),
@@ -1440,7 +1447,7 @@ mod tests {
14401447
.finish();
14411448
let result: Result<ModuleDef> = builder.finish().try_into();
14421449

1443-
expect_error_matching!(result, ValidationError::DirectIndexOnNonUnsignedInt { index, .. } => {
1450+
expect_error_matching!(result, ValidationError::DirectIndexOnBadType { index, .. } => {
14441451
&index[..] == "Bananas_b_idx_direct"
14451452
});
14461453
}

crates/schema/src/error.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ pub enum ValidationError {
6262
#[error("No index found to support unique constraint `{constraint}` for columns `{columns:?}`")]
6363
UniqueConstraintWithoutIndex { constraint: Box<str>, columns: ColSet },
6464
#[error("Direct index does not support type `{ty}` in column `{column}` in index `{index}`")]
65-
DirectIndexOnNonUnsignedInt {
65+
DirectIndexOnBadType {
6666
index: RawIdentifier,
6767
column: RawIdentifier,
6868
ty: PrettyAlgebraicType,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Seeds for failure cases proptest has generated in the past. It is
2+
# automatically read and these particular cases re-run before any
3+
# novel cases are generated.
4+
#
5+
# It is recommended to check this file in to source control so that
6+
# everyone who runs the test benefits from these saved cases.
7+
cc 471733dcdf80f2858344c15b6297070a312d2051ecf0eb56eaf9bf38cd51aab4 # shrinks to start = 38, end = 0
8+
cc 98241ddfa4ab59c91a264cbbd364619c1d02a912406c4402cd01fb0c10a72e10 # shrinks to start = 75, end = 0
9+
cc b058bb0beeaf3bb4f97854dff2bf835e2ea4eaffb1eb3c13c78652018b02b957 # shrinks to start = 130, end = 0, key = 0
10+
cc dc8757b81c7b7cffe9e4a9d43b75cd5814d5fb0667a0c1f401fd5f402fde3d97 # shrinks to start = 37, end = 37, key = 0

crates/table/src/memory_usage.rs

+12
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ impl MemoryUsage for f64 {}
3939
impl MemoryUsage for spacetimedb_sats::F32 {}
4040
impl MemoryUsage for spacetimedb_sats::F64 {}
4141

42+
impl<T: MemoryUsage + ?Sized> MemoryUsage for &T {
43+
fn heap_usage(&self) -> usize {
44+
(*self).heap_usage()
45+
}
46+
}
47+
4248
impl<T: MemoryUsage + ?Sized> MemoryUsage for Box<T> {
4349
fn heap_usage(&self) -> usize {
4450
mem::size_of_val::<T>(self) + T::heap_usage(self)
@@ -65,6 +71,12 @@ impl<T: MemoryUsage> MemoryUsage for [T] {
6571
}
6672
}
6773

74+
impl<T: MemoryUsage, const N: usize> MemoryUsage for [T; N] {
75+
fn heap_usage(&self) -> usize {
76+
self.iter().map(T::heap_usage).sum()
77+
}
78+
}
79+
6880
impl MemoryUsage for str {}
6981

7082
impl<T: MemoryUsage> MemoryUsage for Option<T> {

crates/table/src/read_column.rs

+42-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ use crate::{
1111
};
1212
use spacetimedb_sats::{
1313
algebraic_value::{ser::ValueSerializer, Packed},
14-
i256, u256, AlgebraicType, AlgebraicValue, ArrayValue, ProductType, ProductValue, SumValue,
14+
i256,
15+
sum_value::SumTag,
16+
u256, AlgebraicType, AlgebraicValue, ArrayValue, ProductType, ProductValue, SumValue,
1517
};
1618
use std::{cell::Cell, mem};
1719
use thiserror::Error;
@@ -339,6 +341,29 @@ impl_read_column_via_from! {
339341
i256 => Box<i256>;
340342
}
341343

344+
/// SAFETY: `is_compatible_type` only returns true for sum types,
345+
/// and any sum value stores the tag first in BFLATN.
346+
unsafe impl ReadColumn for SumTag {
347+
fn is_compatible_type(ty: &AlgebraicTypeLayout) -> bool {
348+
matches!(ty, AlgebraicTypeLayout::Sum(_))
349+
}
350+
351+
unsafe fn unchecked_read_column(row_ref: RowRef<'_>, layout: &ProductTypeElementLayout) -> Self {
352+
debug_assert!(Self::is_compatible_type(&layout.ty));
353+
354+
let (page, offset) = row_ref.page_and_offset();
355+
let col_offset = offset + PageOffset(layout.offset);
356+
357+
let data = page.get_row_data(col_offset, Size(1));
358+
let data: Result<[u8; 1], _> = data.try_into();
359+
// SAFETY: `<[u8; 1] as TryFrom<&[u8]>>::try_from` succeeds if and only if the slice's length is `1`.
360+
// We used `1` as both the length of the slice and the array, so we know them to be equal.
361+
let [data] = unsafe { data.unwrap_unchecked() };
362+
363+
Self(data)
364+
}
365+
}
366+
342367
#[cfg(test)]
343368
mod test {
344369
use super::*;
@@ -512,5 +537,21 @@ mod test {
512537

513538
// Use a long string which will hit the blob store.
514539
read_column_long_string { AlgebraicType::String => Box<str> = "long string. ".repeat(2048).into() };
540+
541+
read_sum_value_plain { AlgebraicType::simple_enum(["a", "b"].into_iter()) => SumValue = SumValue::new_simple(1) };
542+
read_sum_tag_plain { AlgebraicType::simple_enum(["a", "b"].into_iter()) => SumTag = SumTag(1) };
543+
}
544+
545+
#[test]
546+
fn read_sum_tag_from_sum_with_payload() {
547+
let algebraic_type = AlgebraicType::sum([("a", AlgebraicType::U8), ("b", AlgebraicType::U16)]);
548+
549+
let mut blob_store = HashMapBlobStore::default();
550+
let mut table = table(ProductType::from([algebraic_type]));
551+
552+
let val = SumValue::new(1, 42u16);
553+
let (_, row_ref) = table.insert(&mut blob_store, &product![val.clone()]).unwrap();
554+
555+
assert_eq!(val.tag, row_ref.read_col::<SumTag>(0).unwrap().0);
515556
}
516557
}

0 commit comments

Comments
 (0)