Add metrax_example colab notebook #104
base: main
Conversation
This is great, thanks Jiwon! Left a bunch of comments but they're mainly optional suggestions.
"source": [
"Please connect to `Metrax (go/metrax)` colab runtime.\n",
"\n",
"If you don't see `Metrax (go/metrax)` from the dropdown menu, please run `/google/bin/releases/colaboratory/public/tools/authorize_colab` on your gLinux workstation or cloudtop and try again."
Remove from external version
done.
"The core `metrax` API is functional and stateless, making it a natural fit for JAX. It works by creating immutable `Metric` state objects that can be merged.\n",
"\n",
"Each `metrax` metric inherits the CLU [`metric`](http://shortn/_e70RtO7j36) class and provides the following APIs:\n",
"\n",
One idea (feel free to ignore): it might be useful to describe the lifecycle of a CLU metric so it's easier for users to understand the list of methods below. Something like:

The usual pattern of using a CLU metric is to call `Metric.empty()` once to create a metric object, then call `metric.merge(Metric.from_model_output(y_true, y_pred))` for each subsequent batch of outputs, then finally call `metric.compute()` to get the final result.
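That lifecycle can be sketched with a minimal hand-rolled running-mean metric; the class below is illustrative and stands in for the real CLU `Metric` base class, it is not the actual API:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Mean:
    """Illustrative stand-in for a CLU-style metric: immutable, mergeable state."""
    total: float
    count: int

    @classmethod
    def empty(cls):
        # Identity element: merging it with any state is a no-op.
        return cls(total=0.0, count=0)

    @classmethod
    def from_model_output(cls, values):
        # Build a state object from one batch of model outputs.
        return cls(total=float(sum(values)), count=len(values))

    def merge(self, other):
        # Combine two states; the operation is associative and commutative.
        return Mean(self.total + other.total, self.count + other.count)

    def compute(self):
        # Reduce the accumulated state to the final metric value.
        return self.total / self.count


# Usual pattern: start empty, merge one state per batch, compute at the end.
metric = Mean.empty()
for batch in [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]:
    metric = metric.merge(Mean.from_model_output(batch))
print(metric.compute())  # 3.5
```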
done.
"print(\"--- Method 1: Full-Batch Calculation (on all 32 samples) ---\")\n",
"full_batch_results = {}\n",
"for name, MetricClass in metrics_to_compute.items():\n",
" # Conditionally add sample_weights for supported metrics.\n",
Do you think it would make sense to split this cell into an initial simpler example without sample weights, then an example with sample weights? I just worry that there's a lot of logic in this cell that might obscure the basic usage of the API.
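For reference, the simpler first cell could start from something this small; the data and plain-Python precision math below are illustrative stand-ins for the actual metrax call:

```python
# Hypothetical "simpler first cell": binary precision on one full batch,
# no sample weights, no conditional kwargs. Data is made up for illustration.
y_true = [1, 0, 1, 1, 0, 1, 0, 0]
y_pred = [1, 1, 1, 0, 0, 1, 0, 1]  # already-thresholded predictions

true_positives = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
predicted_positives = sum(p == 1 for p in y_pred)
precision = true_positives / predicted_positives
print(precision)  # 3 TP out of 5 predicted positives -> 0.6
```

A second cell could then layer `sample_weights` on top of the same data, so the weighting logic is the only new thing the reader has to absorb.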
"for name in metrics_to_compute.keys():\n",
" assert np.allclose(full_batch_results[name], iterative_results[name])\n",
"\n",
"print(\"✅ Success! Both methods produce identical results.\")"
I wonder if verifying that batch and iterative results match is relevant to end users? It seems like more of a detail that's important to the implementers of the library; as long as it's tested, I'm not sure end users will be worrying about doing this check.
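To make that point concrete: the check passes by construction, because the metric state is just additive sufficient statistics. A plain-Python sketch with illustrative precision counts (not the metrax internals):

```python
# The metric state here is (true positives, predicted positives). Addition is
# associative, so accumulating per batch yields the same state as one
# full-batch pass; the equality is a property of the design, not of the data.
batches = [([1, 0, 1], [1, 1, 1]), ([1, 0, 0], [0, 0, 1])]  # (y_true, y_pred)

# Full-batch pass.
all_true = [t for ts, _ in batches for t in ts]
all_pred = [p for _, ps in batches for p in ps]
full_tp = sum(t and p for t, p in zip(all_true, all_pred))
full_pp = sum(all_pred)

# Iterative pass, one batch at a time.
tp = pp = 0
for ts, ps in batches:
    tp += sum(t and p for t, p in zip(ts, ps))
    pp += sum(ps)

assert (tp, pp) == (full_tp, full_pp)
print(tp / pp)  # 0.5 either way
```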
" update_kwargs['sample_weights'] = sample_weights\n",
" if name in metrics_with_threshold:\n",
" update_kwargs['threshold'] = 0.5\n",
" metric_obj.update(**update_kwargs)\n",
There's a lot of logic and kwarg updating here; I wonder if it'd be easier for users to understand if it didn't automate as much? I.e. just calling `Precision.update()` directly instead of computing the kwargs. It may result in around the same number of LOC and be more readable for newcomers.
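A toy before/after of that suggestion, with a stub `update()` standing in for the real `metric_obj.update` (all names here are illustrative, not the notebook's actual variables):

```python
def update(metric_name, y_true, y_pred, **kwargs):
    # Stub for metric_obj.update(); it just records what it was called with,
    # so the two styles below can be compared.
    return (metric_name, sorted(kwargs))


# Automated style: build kwargs conditionally, then make one generic call.
metrics_with_weights = {"precision"}
metrics_with_threshold = {"precision"}
name = "precision"
update_kwargs = {}
if name in metrics_with_weights:
    update_kwargs["sample_weights"] = [1.0, 1.0]
if name in metrics_with_threshold:
    update_kwargs["threshold"] = 0.5
automated = update(name, [1, 0], [0.9, 0.2], **update_kwargs)

# Explicit style: one direct call per metric, no kwarg bookkeeping.
explicit = update("precision", [1, 0], [0.9, 0.2],
                  sample_weights=[1.0, 1.0], threshold=0.5)

assert automated == explicit  # same call; the second reads top-to-bottom
```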
" print(f\"{name}: {full_batch_results_nnx[name]}\")\n",
"\n",
"\n",
"# --- Method 2: Iterative Updating by Batch (nnx) ---\n",
(Optional) it might be worth splitting these two methods into separate cells for clarity (feel free to ignore)
" print(f\"{name}: {iterative_results_nnx[name]}\")\n",
"\n",
"\n",
"# --- Verification ---\n",
IMO this might be able to be removed depending on how you feel about how relevant this is to end users
"\n",
"### Method 2: The `jit` and `Mesh` Approach (Advanced Parallelism)\n",
"\n",
"For more advanced control over distributed computation, JAX provides an explicit sharding mechanism using the `jax.sharding` API. This **SPMD (Single Program, Multiple Data)** approach is more powerful and flexible than `pmap` and is the standard for large-scale models.\n",
I thought Method 1 was also SPMD?
"id": "C3YWS1_x19DJ"
},
"source": [
"## 🧠 Advanced Use: Multi-Host Environments\n",
I like the shoutout here, for anything further I think it makes sense for this to be in its own Colab
"\n",
"# --- 1. Metric Calculation Functions ---\n",
"\n",
"# Method 1: pmap (Simple Data Parallelism)\n",
This is great! Just confirming: is pmap still recommended? It seems like maybe shard_map is the new API for manual parallelism[1], though for this intro guide I think it could make sense to only discuss `jit()`.
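Whichever API the guide settles on (`pmap`, `shard_map`, or `jit` with sharding annotations), the metric-side pattern is the same: each device computes a partial metric state on its shard, and the states are merged. A device-free plain-Python sketch of that reduction, with the actual sharding machinery elided:

```python
# Each "device" computes a partial (sum, count) state on its shard of the
# batch; the states are then merged exactly as in single-device use. This is
# the reduction that pmap/shard_map/jit would perform, minus the real devices.
def partial_state(shard):
    return (sum(shard), len(shard))


def merge(a, b):
    return (a[0] + b[0], a[1] + b[1])


shards = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # one per "device"
states = [partial_state(s) for s in shards]  # the embarrassingly parallel step

total, count = states[0]
for s in states[1:]:  # the all-reduce step
    total, count = merge((total, count), s)
print(total / count)  # 4.5
```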