implemented d_sum
nohzafk committed Aug 25, 2024
1 parent c29397a commit ad6e181
Showing 2 changed files with 172 additions and 36 deletions.
src/flat_tensor.gleam: 39 changes (33 additions & 6 deletions)
@@ -359,9 +359,7 @@ pub fn lower_float2(
 
 // element-wise multiplication with broadcasting
 pub fn extend_rank1_numeric(f, m, shape_fn) {
-  fn(t: Tensor) -> Tensor {
-    flat_extend_rank1_numeric(f, m, shape_fn, t)
-  }
+  fn(t: Tensor) -> Tensor { flat_extend_rank1_numeric(f, m, shape_fn, t) }
 }
 
 pub fn flat_extend_rank1_numeric(
@@ -382,9 +380,7 @@ pub fn flat_extend_rank1_numeric(
 
   let v_out =
     list.range(0, size_out / stride_out - 1)
-    |> list.fold(
-      <<>>,
-      fn(acc, i) {
+    |> list.fold(<<>>, fn(acc, i) {
       f(t0.store, t0.offset + i * stride0 * 8, stride0, acc, i, stride_out)
     })
 
@@ -1407,6 +1403,13 @@ pub fn d_multiply_2_1_numeric(t, u) {
 //----------------------------
 // D-sum
 //----------------------------
+pub fn sum_1_numeric(t0_store, offset, stride0, v_out, _i_out, _stride_out) {
+  let assert Ok(slice) = t0_store |> bit_array.slice(offset, stride0 * 8)
+
+  let sum = float_bits_walker(fn(acc, i) { acc +. i }, slice, 0.0)
+  <<v_out:bits, sum:float>>
+}
+
 pub fn sum_1_gradient(g0, _t0_store, offset_t0, stride0, z_store, iz, _stride_z) {
   let z = get_float(z_store, iz)
 
@@ -1422,3 +1425,27 @@ pub fn sum_1_gradient(g0, _t0_store, offset_t0, stride0, z_store, iz, _stride_z)
   )
   bitarray_replace_slice(g0, offset_t0, stride0, new_slice)
 }
+
+pub fn refr(lst: List(a), n: Int) -> List(a) {
+  lst |> list.drop(n)
+}
+
+pub fn sum_shape(st: Shape) {
+  refr(st, 1)
+}
+
+pub fn sum_1() {
+  Prim1BitArrayFn(
+    numeric_fn: sum_1_numeric,
+    gradient_fn: sum_1_gradient,
+    shape_fn: sum_shape,
+  )
+}
+
+pub fn d_sum(t) {
+  { sum_1() |> ext1(1) }(t)
+}
+
+pub fn d_sum_numeric(t) {
+  extend_rank1_numeric(sum_1_numeric, 1, sum_shape)(t)
+}
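Taken together, the new pieces make d_sum a reduction over a tensor's innermost axis: sum_shape drops one dimension from the shape, sum_1_numeric folds each stride0-long slice of the flat store into a single float, and sum_1_gradient copies the upstream value back into every slot of the slice it came from. A minimal sketch of the resulting behaviour, reusing the to_tensor and tensor_should_equal helpers imported by the test module below (illustrative only, not part of this commit):

pub fn d_sum_innermost_axis_sketch_test() {
  // Shape [2, 3] reduces to shape [2]: each row collapses to its sum.
  [[3, 4, 5], [6, 7, 8]]
  |> dynamic.from
  |> to_tensor
  |> d_sum_numeric
  |> tensor_should_equal([12, 21] |> dynamic.from |> to_tensor)
}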
test/flat_tensor_test.gleam: 169 changes (139 additions & 30 deletions)
@@ -2,7 +2,7 @@ import flat_tensor.{
   type Differentiable, type Dual, type Shape, type Tensor, Dual, DualDiff,
   ListDiff, add_numeric, bitarray_replace_slice, bitarray_to_floats, build_store,
   build_tensor, d_add, d_divide, d_exp, d_expt, d_log, d_multiply,
-  d_multiply_2_1, d_multiply_2_1_numeric, d_sqr, d_sqrt, d_subtract,
+  d_multiply_2_1, d_multiply_2_1_numeric, d_sqr, d_sqrt, d_subtract, d_sum,
   equal_elements, extend_rank1_gradient, extend_rank1_numeric,
   extend_rank2_gradient, extend_rank2_numeric, extend_rank2_shapes,
   flat_extend_rank1_gradient, float_bits_walker, float_to_tensor,
@@ -937,39 +937,148 @@ pub fn d_multiply_2_1_test() {
       |> tensors_to_differentiable,
     )
   }
+  {
+    let a =
+      [[3, 4, 5, 6], [7, 8, 9, 10]]
+      |> dynamic.from
+      |> to_tensor
+    let b =
+      [[2, 3, 4, 5], [12, 13, 14, 15]]
+      |> dynamic.from
+      |> to_tensor
 
-  let a =
-    [[3, 4, 5, 6], [7, 8, 9, 10]]
-    |> dynamic.from
-    |> to_tensor
-  let b =
-    [[2, 3, 4, 5], [12, 13, 14, 15]]
-    |> dynamic.from
-    |> to_tensor
+    d_multiply_2_1_numeric(a, b)
+    |> tensor_should_equal(
+      [
+        [[6, 12, 20, 30], [14, 24, 36, 50]],
+        [[36, 52, 70, 90], [84, 104, 126, 150]],
+      ]
+      |> dynamic.from
+      |> to_tensor,
+    )
 
-  d_multiply_2_1_numeric(a, b)
-  |> tensor_should_equal(
-    [
-      [[6, 12, 20, 30], [14, 24, 36, 50]],
-      [[36, 52, 70, 90], [84, 104, 126, 150]],
-    ]
-    |> dynamic.from
-    |> to_tensor,
-  )
+    { d_multiply_2_1 |> check_theta_and_gradient2 }(
+      a,
+      b,
+      [
+        [[6, 12, 20, 30], [14, 24, 36, 50]],
+        [[36, 52, 70, 90], [84, 104, 126, 150]],
+      ]
+      |> dynamic.from
+      |> to_tensor,
+      [
+        [[14, 16, 18, 20], [14, 16, 18, 20]] |> dynamic.from |> to_tensor,
+        [[10, 12, 14, 16], [10, 12, 14, 16]] |> dynamic.from |> to_tensor,
+      ]
+      |> tensors_to_differentiable,
+    )
+  }
+}
 
+pub fn sum_test() {
+  { d_sum |> check_theta_and_gradient1 }(
+    [3, 4, 5] |> dynamic.from |> to_tensor,
+    float_to_tensor(12.0),
+    [[1, 1, 1] |> dynamic.from |> to_tensor]
+      |> tensors_to_differentiable,
+  )
 
-  { d_multiply_2_1 |> check_theta_and_gradient2 }(
-    a,
-    b,
-    [
-      [[6, 12, 20, 30], [14, 24, 36, 50]],
-      [[36, 52, 70, 90], [84, 104, 126, 150]],
-    ]
-    |> dynamic.from
-    |> to_tensor,
-    [
-      [[14, 16, 18, 20], [14, 16, 18, 20]] |> dynamic.from |> to_tensor,
-      [[10, 12, 14, 16], [10, 12, 14, 16]] |> dynamic.from |> to_tensor,
-    ]
-    |> tensors_to_differentiable,
-  )
+  {
+    let a =
+      [[3, 4, 5], [6, 7, 8]]
+      |> dynamic.from
+      |> to_tensor
+
+    d_sum(a |> to_dual).tensor
+    |> tensor_should_equal(
+      [12, 21]
+      |> dynamic.from
+      |> to_tensor,
+    )
 
+    check_gradients1(
+      fn(b) { d_multiply(b, b) |> d_sum },
+      a,
+      [
+        [[6, 8, 10], [12, 14, 16]]
+        |> dynamic.from
+        |> to_tensor,
+      ]
+      |> tensors_to_differentiable,
+    )
+  }
 
+  let dot_product = fn(a, b) { d_multiply_2_1(a, b) |> d_sum }
+  let sse = fn(a, b) { d_subtract(a, b) |> d_sqr |> d_sum }
+
+  {
+    let a =
+      [[3, 4, 5, 6], [7, 8, 9, 10]]
+      |> dynamic.from
+      |> to_tensor
+    let b = [2, 3, 4, 5] |> dynamic.from |> to_tensor
+
+    { d_sum |> check_theta_and_gradient1 }(
+      b,
+      float_to_tensor(14.0),
+      [[1, 1, 1, 1] |> dynamic.from |> to_tensor]
+        |> tensors_to_differentiable,
+    )
+
+    { dot_product |> check_theta_and_gradient2 }(
+      a,
+      b,
+      [68, 124] |> dynamic.from |> to_tensor,
+      [
+        [[2, 3, 4, 5], [2, 3, 4, 5]]
+          |> dynamic.from
+          |> to_tensor,
+        [10, 12, 14, 16] |> dynamic.from |> to_tensor,
+      ]
+        |> tensors_to_differentiable,
+    )
+
+    { sse |> check_theta_and_gradient2 }(
+      a,
+      b,
+      [4, 100] |> dynamic.from |> to_tensor,
+      [
+        [[2, 2, 2, 2], [10, 10, 10, 10]]
+          |> dynamic.from
+          |> to_tensor,
+        [-12, -12, -12, -12]
+          |> dynamic.from
+          |> to_tensor,
+      ]
+        |> tensors_to_differentiable,
+    )
+  }
+
+  {
+    let a =
+      [[3, 4, 5, 6], [7, 8, 9, 10]]
+      |> dynamic.from
+      |> to_tensor
+
+    let b =
+      [[2, 3, 4, 5], [12, 13, 14, 15]]
+      |> dynamic.from
+      |> to_tensor
+
+    { dot_product |> check_theta_and_gradient2 }(
+      a,
+      b,
+      [[68, 124], [248, 464]]
+        |> dynamic.from
+        |> to_tensor,
+      [
+        [[14, 16, 18, 20], [14, 16, 18, 20]]
+          |> dynamic.from
+          |> to_tensor,
+        [[10, 12, 14, 16], [10, 12, 14, 16]]
+          |> dynamic.from
+          |> to_tensor,
+      ]
+        |> tensors_to_differentiable,
+    )
+  }
 }
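As a cross-check on the new expected values (a worked example, not part of the commit): with a = [[3, 4, 5, 6], [7, 8, 9, 10]] and b = [2, 3, 4, 5], dot_product multiplies each row of a by b element-wise and d_sum reduces the row, so 3*2 + 4*3 + 5*4 + 6*5 = 68 and 7*2 + 8*3 + 9*4 + 10*5 = 124. For sse, the row differences a - b are [1, 1, 1, 1] and [5, 5, 5, 5], whose squares sum to 4 and 100; the gradient with respect to a is 2 * (a - b), and the gradient with respect to b accumulates -2 * (a - b) over both rows, giving [-12, -12, -12, -12]. The first row can be replayed directly on duals, again assuming only helpers this module already imports:

pub fn sse_single_row_sketch_test() {
  // sse is d_subtract |> d_sqr |> d_sum; one row by hand:
  // [3, 4, 5, 6] - [2, 3, 4, 5] == [1, 1, 1, 1], squares sum to 4.0.
  let a = [3, 4, 5, 6] |> dynamic.from |> to_tensor |> to_dual
  let b = [2, 3, 4, 5] |> dynamic.from |> to_tensor |> to_dual
  let result = d_subtract(a, b) |> d_sqr |> d_sum
  result.tensor
  |> tensor_should_equal(float_to_tensor(4.0))
}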
