Implementing these argument checks is not entirely straightforward, because the arguments we would like to check may be traced and so have no concrete values to inspect at trace time: jax-ml/jax#12785 (comment)
I think we can do it, but it probably belongs in a separate PR.
Originally posted by @dylanhmorris in #493 (comment)
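
To make the difficulty concrete, here is a minimal sketch of the tracing problem described above. It is not the project's code: `check_positive`, `scale`, and `naive_scale` are hypothetical names chosen for illustration, and skipping the check for traced values is only one possible policy, assumed here for simplicity.

```python
import jax
import jax.numpy as jnp


def check_positive(x, name="x"):
    """Hypothetical validator: raise if x is concrete and not strictly positive."""
    if isinstance(x, jax.core.Tracer):
        # Under jit/vmap/grad the value is abstract, so there is
        # nothing to compare yet. This sketch simply skips the check.
        return
    if not jnp.all(jnp.asarray(x) > 0):
        raise ValueError(f"{name} must be positive")


def scale(x, rate):
    """Hypothetical function whose `rate` argument we want to validate."""
    check_positive(rate, "rate")
    return x * rate


def naive_scale(x, rate):
    """Same function with a plain Python check, which breaks under tracing."""
    if rate <= 0:  # calls bool() on a tracer when jitted
        raise ValueError("rate must be positive")
    return x * rate
```

Called eagerly, `scale(2.0, -1.0)` raises `ValueError`, but `jax.jit(scale)(2.0, -1.0)` passes silently because `rate` is a tracer when the check runs. The plain Python check in `naive_scale` fails in the opposite way: `jax.jit(naive_scale)(2.0, 3.0)` raises a `jax.errors.ConcretizationTypeError` even though the input is valid.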