I'm not sure whether this is out of scope, but it would be great to be able to query the support of simple distributions (normal, uniform, etc.). In my downstream library Parax I currently have the following code:
def infer_distribution_constraint(dist: dists.AbstractDistribution) -> AbstractConstraint:
    """
    Infers the physical support of a distreqx distribution and returns
    the corresponding constraint mapping.

    Resolution Strategy:
    1. Recursive unwrapping of meta-distributions (Joint, Transformed).
    2. Exact type matching for common `distreqx` distributions.
    3. Fallback to `icdf(0.0)` and `icdf(1.0)` evaluation.
    4. Last resort: Unconstrained `RealLine`.
    """
    # Recursive early return for Joint and Transformed wrappers
    if hasattr(dists, 'Joint') and isinstance(dist, dists.Joint):
        sub_distributions = dist.distributions
        constraints_tree = jax.tree.map(
            infer_distribution_constraint,
            sub_distributions,
            is_leaf=lambda x: isinstance(x, dists.AbstractDistribution)
        )
        return Leafwise(tree=constraints_tree)
    elif isinstance(dist, dists.Transformed):
        base_constraint = infer_distribution_constraint(dist.distribution)
        return Transformed(
            constraint=base_constraint,
            bijector=dist.bijector
        )

    # Hard-coded supports for common distributions
    if isinstance(dist, (dists.Normal, dists.Logistic)):
        return RealLine(shape=dist.event_shape)
    elif isinstance(dist, (dists.MultivariateNormalDiag,
                           dists.MultivariateNormalFullCovariance,
                           dists.MultivariateNormalTri)):
        return RealLine(shape=dist.event_shape)
    elif (
        hasattr(dists, 'LogNormal') and isinstance(dist, dists.LogNormal)
    ) or isinstance(dist, dists.Gamma):
        return Positive(shape=dist.event_shape)
    elif isinstance(dist, dists.Beta):
        return Interval(
            lower=jnp.zeros(dist.event_shape),
            upper=jnp.ones(dist.event_shape)
        )
    elif isinstance(dist, dists.Uniform):
        return Interval(lower=dist.low, upper=dist.high)

    try:
        lower_bound = dist.icdf(0.0)
        upper_bound = dist.icdf(1.0)
        is_lower_bounded = not jnp.all(jnp.isneginf(lower_bound))
        is_upper_bounded = not jnp.all(jnp.isposinf(upper_bound))
        if is_lower_bounded and not is_upper_bounded:
            if jnp.all(lower_bound == 0.0):
                return Positive(shape=dist.event_shape, dtype=lower_bound.dtype)
            return GreaterThan(lower=lower_bound)
        elif not is_lower_bounded and is_upper_bounded:
            if jnp.all(upper_bound == 0.0):
                return Negative(shape=dist.event_shape, dtype=upper_bound.dtype)
            return LessThan(upper=upper_bound)
        elif is_lower_bounded and is_upper_bounded:
            return Interval(lower=lower_bound, upper=upper_bound)
    except (NotImplementedError, AttributeError, ValueError, TypeError):
        pass

    return RealLine(shape=dist.event_shape)
While this works, it's of course not the cleanest solution. It would be great if I could instead extract the support from the distribution itself, or if there were a helper utility in distreqx that did this for me.
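To make the request a bit more concrete, here is a rough sketch of one possible shape for such an API. Everything in it is hypothetical: `Support`, its `kind` property, the `Toy*` classes, and `constraint_name` are illustrative stand-ins written in plain Python, not existing distreqx or Parax names, and real distributions would need array-valued bounds rather than scalar floats.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Support:
    """Hypothetical description of a scalar distribution's support."""
    lower: float = -math.inf
    upper: float = math.inf

    @property
    def kind(self) -> str:
        # Classify the support by which of its bounds are finite.
        lower_bounded = math.isfinite(self.lower)
        upper_bounded = math.isfinite(self.upper)
        if lower_bounded and upper_bounded:
            return "interval"
        if lower_bounded:
            return "greater_than"
        if upper_bounded:
            return "less_than"
        return "real_line"


@dataclass(frozen=True)
class ToyNormal:
    """Stand-in for a normal distribution (not a distreqx class)."""
    loc: float
    scale: float

    @property
    def support(self) -> Support:
        # Supported on the whole real line.
        return Support()


@dataclass(frozen=True)
class ToyUniform:
    """Stand-in for a uniform distribution (not a distreqx class)."""
    low: float
    high: float

    @property
    def support(self) -> Support:
        # The bounds come straight from the parameters.
        return Support(lower=self.low, upper=self.high)


@dataclass(frozen=True)
class ToyGamma:
    """Stand-in for a gamma distribution (not a distreqx class)."""
    concentration: float
    rate: float

    @property
    def support(self) -> Support:
        # Bounded below by zero, unbounded above.
        return Support(lower=0.0)


def constraint_name(dist) -> str:
    """Downstream code reads the support directly, with no isinstance chain."""
    return dist.support.kind
```

With something along these lines, each distribution would declare its own support, so a downstream library would not have to maintain a type-matching table or probe `icdf` at 0 and 1.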