Skip to content

Commit de8cd65

Browse files
committed
Define insert & union in terms of union_with & insert_with
Interestingly enough, this completely obliterate the performance we gained from the scott-encoding on union_with. Somehow, re-using union_with with a different combining function prevents some optimizations/specializations?
1 parent 51bc926 commit de8cd65

File tree

2 files changed

+26
-132
lines changed

2 files changed

+26
-132
lines changed

lib/aiken/collection/dict.ak

Lines changed: 21 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ pub fn from_pairs(self: Pairs<ByteArray, value>) -> Dict<key, value> {
161161
fn do_from_pairs(xs: Pairs<ByteArray, value>) -> Pairs<ByteArray, value> {
162162
when xs is {
163163
[] -> []
164-
[Pair(k, v), ..rest] -> do_insert(do_from_pairs(rest), k, v)
164+
[Pair(k, v), ..rest] ->
165+
do_insert_with(do_from_pairs(rest), k, v, union.keep_left())
165166
}
166167
}
167168

@@ -681,27 +682,7 @@ pub fn insert(
681682
key k: ByteArray,
682683
value v: value,
683684
) -> Dict<key, value> {
684-
Dict { inner: do_insert(self.inner, k, v) }
685-
}
686-
687-
fn do_insert(
688-
self: Pairs<ByteArray, value>,
689-
key k: ByteArray,
690-
value v: value,
691-
) -> Pairs<ByteArray, value> {
692-
when self is {
693-
[] -> [Pair(k, v)]
694-
[Pair(k2, v2), ..rest] ->
695-
if builtin.less_than_bytearray(k, k2) {
696-
[Pair(k, v), ..self]
697-
} else {
698-
if k == k2 {
699-
[Pair(k, v), ..rest]
700-
} else {
701-
[Pair(k2, v2), ..do_insert(rest, k, v)]
702-
}
703-
}
704-
}
685+
insert_with(self, k, v, union.keep_right())
705686
}
706687

707688
test insert_1() {
@@ -729,37 +710,25 @@ test insert_2() {
729710
/// to the merge function, and the new value is passed as the third argument.
730711
///
731712
/// ```aiken
732-
/// let sum =
733-
/// fn (_k, a, b) { Some(a + b) }
713+
/// use aiken/collection/dict/union
734714
///
735715
/// let result =
736716
/// dict.empty
737-
/// |> dict.insert_with(key: "a", value: 1, with: sum)
738-
/// |> dict.insert_with(key: "b", value: 2, with: sum)
739-
/// |> dict.insert_with(key: "a", value: 3, with: sum)
717+
/// |> dict.insert_with(key: "a", value: 1, with: union.sum)
718+
/// |> dict.insert_with(key: "b", value: 2, with: union.sum)
719+
/// |> dict.insert_with(key: "a", value: 3, with: union.sum)
740720
/// |> dict.to_pairs()
741721
///
742722
/// result == [Pair("a", 4), Pair("b", 2)]
743723
/// ```
744724
pub fn insert_with(
745-
self: Dict<key, value>,
746-
key k: ByteArray,
747-
value v: value,
748-
with: fn(ByteArray, value, value) -> Option<value>,
749-
) -> Dict<key, value> {
750-
Dict {
751-
inner: do_insert_with(self.inner, k, v, fn(k, v1, v2) { with(k, v2, v1) }),
752-
}
753-
}
754-
755-
pub fn insert_with_alt(
756725
self: Dict<key, value>,
757726
key k: ByteArray,
758727
value v: value,
759728
with: UnionStrategy<ByteArray, value>,
760729
) -> Dict<key, value> {
761730
Dict {
762-
inner: do_insert_with_alt(
731+
inner: do_insert_with(
763732
self.inner,
764733
k,
765734
v,
@@ -769,49 +738,33 @@ pub fn insert_with_alt(
769738
}
770739

771740
test insert_with_1() {
772-
let sum =
773-
fn(_k, a, b) { Some(a + b) }
774-
775-
let result =
776-
empty
777-
|> insert_with(key: "foo", value: 1, with: sum)
778-
|> insert_with(key: "bar", value: 2, with: sum)
779-
|> to_pairs()
780-
781-
result == [Pair("bar", 2), Pair("foo", 1)]
782-
}
783-
784-
test insert_with_alt_1() {
785741
let result =
786742
empty
787-
|> insert_with_alt(key: "foo", value: 1, with: union.sum())
788-
|> insert_with_alt(key: "bar", value: 2, with: union.sum())
743+
|> insert_with(key: "foo", value: 1, with: union.sum())
744+
|> insert_with(key: "bar", value: 2, with: union.sum())
789745
|> to_pairs()
790746

791747
result == [Pair("bar", 2), Pair("foo", 1)]
792748
}
793749

794750
test insert_with_2() {
795-
let sum =
796-
fn(_k, a, b) { Some(a + b) }
797-
798751
let result =
799752
empty
800-
|> insert_with(key: "foo", value: 1, with: sum)
801-
|> insert_with(key: "bar", value: 2, with: sum)
802-
|> insert_with(key: "foo", value: 3, with: sum)
753+
|> insert_with(key: "foo", value: 1, with: union.sum())
754+
|> insert_with(key: "bar", value: 2, with: union.sum())
755+
|> insert_with(key: "foo", value: 3, with: union.sum())
803756
|> to_pairs()
804757

805758
result == [Pair("bar", 2), Pair("foo", 4)]
806759
}
807760

808761
test insert_with_3() {
809762
let with =
810-
fn(k, a, _b) {
763+
fn(k, a, _b, keep, discard) {
811764
if k == "foo" {
812-
Some(a)
765+
keep(a)
813766
} else {
814-
None
767+
discard()
815768
}
816769
}
817770

@@ -970,17 +923,7 @@ pub fn union(
970923
left: Dict<key, value>,
971924
right: Dict<key, value>,
972925
) -> Dict<key, value> {
973-
Dict { inner: do_union(left.inner, right.inner) }
974-
}
975-
976-
fn do_union(
977-
left: Pairs<ByteArray, value>,
978-
right: Pairs<ByteArray, value>,
979-
) -> Pairs<ByteArray, value> {
980-
when left is {
981-
[] -> right
982-
[Pair(k, v), ..rest] -> do_union(rest, do_insert(right, k, v))
983-
}
926+
Dict(do_union_with(left.inner, right.inner, union.keep_left()))
984927
}
985928

986929
test union_1() {
@@ -1036,70 +979,26 @@ test union_4() {
1036979
/// result == [Pair("a", 250), Pair("b", 200), Pair("c", 300)]
1037980
/// ```
1038981
pub fn union_with(
1039-
left: Dict<key, value>,
1040-
right: Dict<key, value>,
1041-
with: fn(ByteArray, value, value) -> Option<value>,
1042-
) -> Dict<key, value> {
1043-
Dict { inner: do_union_with(left.inner, right.inner, with) }
1044-
}
1045-
1046-
pub fn union_with_alt(
1047982
left: Dict<key, value>,
1048983
right: Dict<key, value>,
1049984
with: UnionStrategy<ByteArray, value>,
1050985
) -> Dict<key, value> {
1051-
Dict { inner: do_union_with_alt(left.inner, right.inner, with) }
986+
Dict { inner: do_union_with(left.inner, right.inner, with) }
1052987
}
1053988

1054989
fn do_union_with(
1055-
left: Pairs<ByteArray, value>,
1056-
right: Pairs<ByteArray, value>,
1057-
with: fn(ByteArray, value, value) -> Option<value>,
1058-
) -> Pairs<ByteArray, value> {
1059-
when left is {
1060-
[] -> right
1061-
[Pair(k, v), ..rest] ->
1062-
do_union_with(rest, do_insert_with(right, k, v, with), with)
1063-
}
1064-
}
1065-
1066-
fn do_union_with_alt(
1067990
left: Pairs<ByteArray, value>,
1068991
right: Pairs<ByteArray, value>,
1069992
with: UnionStrategy<ByteArray, value>,
1070993
) -> Pairs<ByteArray, value> {
1071994
when left is {
1072995
[] -> right
1073996
[Pair(k, v), ..rest] ->
1074-
do_union_with_alt(rest, do_insert_with_alt(right, k, v, with), with)
997+
do_union_with(rest, do_insert_with(right, k, v, with), with)
1075998
}
1076999
}
10771000

10781001
fn do_insert_with(
1079-
self: Pairs<ByteArray, value>,
1080-
key k: ByteArray,
1081-
value v: value,
1082-
with: fn(ByteArray, value, value) -> Option<value>,
1083-
) -> Pairs<ByteArray, value> {
1084-
when self is {
1085-
[] -> [Pair(k, v)]
1086-
[Pair(k2, v2), ..rest] ->
1087-
if builtin.less_than_bytearray(k, k2) {
1088-
[Pair(k, v), ..self]
1089-
} else {
1090-
if k == k2 {
1091-
when with(k, v, v2) is {
1092-
Some(combined) -> [Pair(k, combined), ..rest]
1093-
None -> rest
1094-
}
1095-
} else {
1096-
[Pair(k2, v2), ..do_insert_with(rest, k, v, with)]
1097-
}
1098-
}
1099-
}
1100-
}
1101-
1102-
fn do_insert_with_alt(
11031002
self: Pairs<ByteArray, value>,
11041003
key k: ByteArray,
11051004
value v: value,
@@ -1120,7 +1019,7 @@ fn do_insert_with_alt(
11201019
fn() { rest },
11211020
)
11221021
} else {
1123-
[Pair(k2, v2), ..do_insert_with_alt(rest, k, v, with)]
1022+
[Pair(k2, v2), ..do_insert_with(rest, k, v, with)]
11241023
}
11251024
}
11261025
}
@@ -1136,7 +1035,7 @@ test union_with_1() {
11361035
|> insert(bar, 42)
11371036
|> insert(foo, 1337)
11381037

1139-
let result = union_with(left, right, with: fn(_, l, r) { Some(l + r) })
1038+
let result = union_with(left, right, with: union.sum())
11401039

11411040
result == from_pairs([Pair(foo, 1351), Pair(bar, 42)])
11421041
}

lib/cardano/assets.ak

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ pub fn from_asset_list(xs: Pairs<PolicyId, Pairs<AssetName, Int>>) -> Value {
7878
expect Pair(p, [_, ..] as x) = inner
7979
x
8080
|> from_ascending_pairs_with(fn(v) { v != 0 })
81-
|> dict.insert_with_alt(acc, p, _, union.expect_no_duplicate())
81+
|> dict.insert_with(acc, p, _, union.expect_no_duplicate())
8282
},
8383
)
8484
|> Value
@@ -551,12 +551,7 @@ pub fn add(
551551
let helper =
552552
fn(_, left, _right, keep, discard) {
553553
let inner_result =
554-
dict.insert_with_alt(
555-
left,
556-
asset_name,
557-
quantity,
558-
union.sum_if_non_zero(),
559-
)
554+
dict.insert_with(left, asset_name, quantity, union.sum_if_non_zero())
560555

561556
if dict.is_empty(inner_result) {
562557
discard()
@@ -566,7 +561,7 @@ pub fn add(
566561
}
567562

568563
Value(
569-
dict.insert_with_alt(
564+
dict.insert_with(
570565
self.inner,
571566
policy_id,
572567
dict.from_ascending_pairs([Pair(asset_name, quantity)]),
@@ -618,11 +613,11 @@ test add_5() {
618613
/// Combine two `Value` together.
619614
pub fn merge(left v0: Value, right v1: Value) -> Value {
620615
Value(
621-
dict.union_with_alt(
616+
dict.union_with(
622617
v0.inner,
623618
v1.inner,
624619
fn(_, a0, a1, keep, discard) {
625-
let result = dict.union_with_alt(a0, a1, union.sum_if_non_zero())
620+
let result = dict.union_with(a0, a1, union.sum_if_non_zero())
626621
if dict.is_empty(result) {
627622
discard()
628623
} else {

0 commit comments

Comments
 (0)