
Consider adding ".support" property to distributions #86

@gvcallen


I'm not sure whether this is out of scope, but it would be great to be able to query the support of simple distributions (normal, uniform, etc.). In my downstream library Parax I currently have the following code:

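import jax
import jax.numpy as jnp
from distreqx import distributions as dists

# AbstractConstraint, RealLine, Positive, Negative, GreaterThan, LessThan,
# Interval, Leafwise and Transformed are Parax constraint types.

def infer_distribution_constraint(dist: dists.AbstractDistribution) -> AbstractConstraint: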
    """
    Infers the physical support of a distreqx distribution and returns 
    the corresponding constraint mapping.

    Resolution Strategy:
    1. Recursive unwrapping of meta-distributions (Joint, Transformed).
    2. Exact type matching for common `distreqx` distributions.
    3. Fallback to `icdf(0.0)` and `icdf(1.0)` evaluation.
    4. Last resort: Unconstrained `RealLine`.
    """
    # Recursive early return for Joint and Transformed wrappers
    if hasattr(dists, 'Joint') and isinstance(dist, dists.Joint):
        sub_distributions = dist.distributions 
        constraints_tree = jax.tree.map(
            infer_distribution_constraint, 
            sub_distributions,
            is_leaf=lambda x: isinstance(x, dists.AbstractDistribution)
        )
        return Leafwise(tree=constraints_tree)
    elif isinstance(dist, dists.Transformed):
        base_constraint = infer_distribution_constraint(dist.distribution)
        return Transformed(
            constraint=base_constraint, 
            bijector=dist.bijector
        )

    # Hard-coded supports for common distributions
    if isinstance(dist, (dists.Normal, dists.Logistic)):
        return RealLine(shape=dist.event_shape)
        
    elif isinstance(dist, (dists.MultivariateNormalDiag, 
                           dists.MultivariateNormalFullCovariance, 
                           dists.MultivariateNormalTri)):
        return RealLine(shape=dist.event_shape)

    elif isinstance(dist, dists.Gamma) or (
        hasattr(dists, 'LogNormal') and isinstance(dist, dists.LogNormal)
    ):
        return Positive(shape=dist.event_shape)

    elif isinstance(dist, dists.Beta):
        return Interval(
            lower=jnp.zeros(dist.event_shape), 
            upper=jnp.ones(dist.event_shape)
        )

    elif isinstance(dist, dists.Uniform):
        return Interval(lower=dist.low, upper=dist.high)

    # Fallback: evaluate the quantile function at both ends of [0, 1].
    try:
        lower_bound = dist.icdf(0.0)
        upper_bound = dist.icdf(1.0)
        
        is_lower_bounded = not jnp.all(jnp.isneginf(lower_bound))
        is_upper_bounded = not jnp.all(jnp.isposinf(upper_bound))

        if is_lower_bounded and not is_upper_bounded:
            if jnp.all(lower_bound == 0.0):
                return Positive(shape=dist.event_shape, dtype=lower_bound.dtype)
            return GreaterThan(lower=lower_bound)

        elif not is_lower_bounded and is_upper_bounded:
            if jnp.all(upper_bound == 0.0):
                return Negative(shape=dist.event_shape, dtype=upper_bound.dtype)
            return LessThan(upper=upper_bound)

        elif is_lower_bounded and is_upper_bounded:
            return Interval(lower=lower_bound, upper=upper_bound)

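    # icdf may be unimplemented, or the bounds may be traced values (e.g. under
    # jax.jit), in which case the Python `not` above raises a TypeError subclass.
    except (NotImplementedError, AttributeError, ValueError, TypeError):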
        pass

    return RealLine(shape=dist.event_shape)

While this works, it's of course not the cleanest solution. It would be great if I could instead read the support directly from the distribution itself, or if distreqx provided a helper utility that did this for me.
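One way the request could look, sketched in plain Python under loud assumptions: `Support`, `support_of`, `Exp`, and the toy `Distribution`/`Normal`/`Gamma`/`Uniform`/`Transformed` classes below are hypothetical stand-ins, not distreqx's actual API. Each distribution reports its own support through a `support` property, a transformed distribution maps its base support's endpoints through the bijector (valid only for monotone scalar bijectors), and a small helper falls back to the unconstrained real line when no support is defined.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Support:
    """Hypothetical support descriptor: an interval [lower, upper] on the real line."""

    lower: float = -math.inf
    upper: float = math.inf

    def contains(self, x: float) -> bool:
        return self.lower <= x <= self.upper


class Distribution:
    """Toy stand-in for a distribution base class (not distreqx's)."""

    @property
    def support(self) -> Support:
        # Subclasses that know their support override this.
        raise NotImplementedError


class Normal(Distribution):
    @property
    def support(self) -> Support:
        return Support()  # the whole real line


class Gamma(Distribution):
    @property
    def support(self) -> Support:
        return Support(lower=0.0)  # positive half-line; open/closed ends not modelled


class Uniform(Distribution):
    def __init__(self, low: float, high: float):
        self.low, self.high = low, high

    @property
    def support(self) -> Support:
        return Support(self.low, self.high)  # depends on the parameters


class Exp:
    """Hypothetical scalar bijector y = exp(x)."""

    def forward(self, x: float) -> float:
        return math.exp(x)  # exp(-inf) == 0.0 and exp(inf) == inf


class Transformed(Distribution):
    def __init__(self, distribution: Distribution, bijector):
        self.distribution, self.bijector = distribution, bijector

    @property
    def support(self) -> Support:
        # Assumes a monotone scalar bijector: the image of an interval is the
        # interval between its mapped endpoints (min/max handles decreasing maps).
        base = self.distribution.support
        a = self.bijector.forward(base.lower)
        b = self.bijector.forward(base.upper)
        return Support(min(a, b), max(a, b))


def support_of(dist) -> Support:
    """Hypothetical helper: prefer dist.support, else assume the real line."""
    try:
        return dist.support
    except (NotImplementedError, AttributeError):
        return Support()


# Usage: a log-normal built as exp(Normal) gets its support by composition.
lognormal = Transformed(Normal(), Exp())
print(support_of(lognormal))  # Support(lower=0.0, upper=inf)
```

Keeping `support` on each class puts the knowledge next to the parameters it depends on (as with `Uniform`) and lets wrappers such as `Transformed` compose it instead of relying on an external `isinstance` chain. A real implementation would still need to handle batch/event shapes and non-monotone or multivariate bijectors.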
