David Hall authored
* actually make sum etc. return jnp.ndarray when None is provided
* no cover
* remove the coverage pragmas and just automatically exclude typing.overload
* fix signature for ReductionFunction
* cleanup typing.overloads
* refactor dot to support None
* make precommit stricter, please mypy
* use auto_sharded internally, undeprecate it b/c it has a point
* refactor dot to take a keyword argument for axis instead of it being a leading argument
* fix some deprecation warnings
* move dot to its own file
* add support for out_axes to `dot`
* add docs for the new dot
* move rearrange_to_fit_order to axis
* make "old-style" rearrange accept multiple ellipses
* add support for multiple ellipses to einops-style rearrange
* try to fix DTypeLike for older jax versions
* missed a spot
* remove old scalar thing from docs
* allow AxisSpec for Embed in Embedding
* copy in Equinox compile cache so Patrick isn't disappointed in me
* use named arrays for the where check
* let `named` work with anything asarray accepts
* catch some dumb oversights in calling sharding
* simplify handling of non-named arrays in binary ops
* move losses to their own file, add binary cross entropy loss
* less coverage noise
* autocoerce scalar named arrays to plain jax arrays
* don't reduce losses when reduction axis == ()
* on second thought, less magic in the Haliax API is probably better
* minimize use of jax internals
* make it so we don't scan scalars, which will be nicer
* older jax doesn't expose DTypeLike
* log the name that's being scanned
* parens seem more normal
* change default prevent_cse to False for Stacked
* Adds a new einsum syntax (#63)

  Adds a new einsum syntax to Haliax that has three modes: "ordered", "unordered", and "output only".

  * ordered looks like normal (einops-flavored) einsum.
  * unordered uses syntax similar to rearrange: specify the names of the axes involved in the operation, and all other axes are batched.
  * output only lets you specify only the expected output shape. Use judiciously.
* add a method for flattening all "batch" axes into a single batch axis, which is useful in a few situations
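To illustrate the "unordered" einsum mode described above, here is a minimal sketch that lowers a named-axis spec to an ordinary positional einsum subscript string, such as the one `jnp.einsum` accepts. The helper `unordered_to_positional` is hypothetical and is not Haliax's API. It also assumes three rules: named axes that are absent from the output are summed over, unnamed axes are batch axes kept in first-appearance order, and batch axes come before the named output axes.

```python
import string


def unordered_to_positional(operand_axes, involved, output_axes):
    """Hypothetical sketch, not Haliax's implementation.

    operand_axes: one tuple of axis names per operand, in array layout order.
    involved: axis names mentioned in the spec; those missing from
        output_axes are contracted (summed over).
    output_axes: named axes that appear in the output.
    Unnamed axes are treated as batch axes and passed through unchanged.
    """
    # Collect every axis name in first-appearance order.
    all_axes = []
    for axes in operand_axes:
        for name in axes:
            if name not in all_axes:
                all_axes.append(name)

    unknown = (set(involved) | set(output_axes)) - set(all_axes)
    if unknown:
        raise ValueError(f"axes not found in any operand: {sorted(unknown)}")

    # Assign one einsum letter per distinct axis name.
    letters = {name: string.ascii_letters[i] for i, name in enumerate(all_axes)}

    # Assumption: batch axes lead the output, followed by the named output axes.
    batch = [n for n in all_axes if n not in involved and n not in output_axes]
    out = batch + list(output_axes)

    lhs = ",".join("".join(letters[n] for n in axes) for axes in operand_axes)
    rhs = "".join(letters[n] for n in out)
    return f"{lhs}->{rhs}", tuple(out)


spec, out_axes = unordered_to_positional(
    [("batch", "embed"), ("embed", "hidden")],
    involved=("embed", "hidden"),
    output_axes=("hidden",),
)
print(spec, out_axes)
```

In this example, `embed` is named but absent from the output, so it is contracted. `batch` is never named, so it rides along as a batch axis, which matches the intent behind "others are batched".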