Skip to content

Commit

Permalink
WIP: implement ext1 ext2
Browse files Browse the repository at this point in the history
  • Loading branch information
nohzafk committed Jul 21, 2024
1 parent a0c8068 commit e94ec12
Show file tree
Hide file tree
Showing 8 changed files with 590 additions and 55 deletions.
1 change: 1 addition & 0 deletions gleam.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mat = ">= 1.0.0 and < 2.0.0"
qcheck_gleeunit_utils = ">= 0.1.0 and < 1.0.0"
birl = ">= 1.7.1 and < 2.0.0"
parallel_map = ">= 2.0.0 and < 3.0.0"
# parallel_map = { path = "../parallel_map" }
gleam_erlang = ">= 0.25.0 and < 1.0.0"

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ packages = [
{ name = "gleam_community_maths", version = "1.1.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_community_maths", source = "hex", outer_checksum = "6C4ED7BC7E7DF6977719B5F2CFE717EE8280D1CF6EA81D55FD9953758C7FD14E" },
{ name = "gleam_erlang", version = "0.25.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_erlang", source = "hex", outer_checksum = "054D571A7092D2A9727B3E5D183B7507DAB0DA41556EC9133606F09C15497373" },
{ name = "gleam_otp", version = "0.10.0", build_tools = ["gleam"], requirements = ["gleam_erlang", "gleam_stdlib"], otp_app = "gleam_otp", source = "hex", outer_checksum = "0B04FE915ACECE539B317F9652CAADBBC0F000184D586AAAF2D94C100945D72B" },
{ name = "gleam_stdlib", version = "0.38.0", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "663CF11861179AF415A625307447775C09404E752FF99A24E2057C835319F1BE" },
{ name = "gleam_stdlib", version = "0.39.0", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "2D7DE885A6EA7F1D5015D1698920C9BAF7241102836CE0C3837A4F160128A9C4" },
{ name = "gleeunit", version = "1.2.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleeunit", source = "hex", outer_checksum = "F7A7228925D3EE7D0813C922E062BFD6D7E9310F0BEE585D3A42F3307E3CFD13" },
{ name = "mat", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "mat", source = "hex", outer_checksum = "CCC6642918C2FB70CE69967EE649E8AF95180059423DE1ED2ED362A0ABDAF739" },
{ name = "parallel_map", version = "2.0.0", build_tools = ["gleam"], requirements = ["gleam_erlang", "gleam_otp", "gleam_stdlib"], otp_app = "parallel_map", source = "hex", outer_checksum = "653714A9FD63EACD1A9D0A6582A972B0EC109AE275CDDD2E99CFC3DFAFAB9225" },
Expand Down
10 changes: 5 additions & 5 deletions src/iris.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import gleam/list
import gleam/string
import iris_data
import malt0.{
type Block, Block, accuracy, grid_search, init_theta, l2_loss, model, relu,
sampling_obj, stack_blocks, tensor,
type Block, type Hyperparameters, Block, accuracy, grid_search, init_theta,
l2_loss, model, relu, sampling_obj, stack_blocks, tensor,
}

//*----------------------------------------
Expand Down Expand Up @@ -152,9 +152,9 @@ pub fn accurate_enough_iris_theta(theta) {
>=. 0.9
}

pub fn grid_search_iris_theta() {
fn(hp) {
{ hp |> malt0.naked_gradient_descent }(
pub fn grid_search_iris_theta(gs) {
fn(hp: Hyperparameters) {
{ hp |> gs }(
{ hp.batch_size |> sampling_obj }(
l2_loss(iris_classifier()),
iris_train_xs(),
Expand Down
6 changes: 2 additions & 4 deletions src/malt0.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,7 @@ pub fn grid_search(
batch_size |> list.map(dynamic.from),
]
|> cartesian_product
// use parallel_map.list_find_pmap
|> parallel_map.list_pmap(
fn(hypers) {
let assert [revs, alpha, batch_size] = hypers
Expand All @@ -1337,13 +1338,10 @@ pub fn grid_search(
}
},
parallel_map.WorkerAmount(16),
60 * 1000,
20 * 1000,
)
}

// how can i stop the parallel execution when I find a Ok value?
// under the hood of parallel_map.list_pmap it use erlang process

// fn tensor_to_list(t: Tensor) {
// case t {
// ScalarTensor(s) -> s.real |> dynamic.from
Expand Down
Loading

0 comments on commit e94ec12

Please sign in to comment.