Consider a toy model y = Ax + noise where x is sampled from a truncated laplace random variables.
Truncated Laplace distribution can be made using the code: https://github.com/tensorflow/probability/issues/1135#issuecomment-713716395
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import inline as inline_lib
from tensorflow_probability.python.distributions import uniform as uniform_lib
from tensorflow_probability.python.distributions import distribution
# from tensorflow_probability.python.distributions import reparameterization
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
import tensorflow_probability.python.bijectors as tfb
from tensorflow_probability.python.internal import prefer_static as ps
class Truncated(distribution.Distribution):
"""Truncates an underlying distribution to be bounded by low, high.
The truncated distribution is bounded between `low` and `high` (the pdf is 0
outside these bounds and renormalized).
Samples from this distribution are differentiable with respect to the
underlying distribution's parameters as well as the bounds, `low` and `high`,
i.e., this implementation is fully reparameterized.
For more details, see [here](
https://en.wikipedia.org/wiki/Truncated_normal_distribution).
The distribution is implemented as a
`tfp.distributions.TransformedDistribution` in which a uniform distribution on
`[underlying.cdf(low), underlying.cdf(high)]`is transformed by the bijection
defined by the `underlying.quantile`/`underlying.cdf` pair, and the
log-determinant is determined using automatic differentiation of
`underlying.cdf`.
If using autodiff of the quantile would be more efficient for your scenario,
please file an issue on github or email `tfprobability@tensorflow.org`.
"""
def __init__(self, underlying, low, high,
validate_args=False, allow_nan_stats=False, name=None):
"""Creates a truncated distribution from underlying distribution and bounds.
Args:
distribution: A `tfp.distributions.Distribution` which implements
both `cdf` and `quantile` (implying a scalar `event_shape`).
low: The lower bound for truncation. Must broadcast with both
`underlying.batch_shape` and `high`.
high: The upper bound for truncation. Must broadcast with both
`underlying.batch_shape` and `low`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked at run-time.
allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value '`NaN`' to
indicate the result is undefined. When `False`, an exception is raised
if one or more of the statistic's batch members are undefined.
name: Optional name for ops created by this instance. Defaults to
`f"Truncated{distribution.name}"`.
"""
with tf.name_scope(name or 'Truncated{}'.format(underlying.name)) as name:
dtype = dtype_util.common_dtype([underlying, low, high], tf.float32)
print(dtype)
self._distribution = underlying
self._low = tensor_util.convert_nonref_to_tensor(low, dtype=dtype)
self._high = tensor_util.convert_nonref_to_tensor(high, dtype=dtype)
def ildj(y):
return tf.math.log(tf.math.abs(
tfp.math.value_and_gradient(underlying.cdf, y)[1]))
self._bijector = tfb.Inline(
forward_fn=underlying.quantile,
inverse_fn=underlying.cdf,
inverse_log_det_jacobian_fn=ildj,
forward_min_event_ndims=0,
is_increasing=lambda: True)
super(Truncated, self).__init__(
dtype=dtype,
reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=name)
@property
def distribution(self):
return self._distribution
@property
def low(self):
return self._low
@property
def high(self):
return self._high
def _batch_shape(self):
return tf.broadcast_static_shape(
self.distribution.batch_shape,
tf.broadcast_static_shape(self.low.shape, self.high.shape))
def _batch_shape_tensor(self, low=None, high=None):
return ps.broadcast_shape(
self.distribution.batch_shape_tensor(),
ps.broadcast_shape(ps.shape(self.low if low is None else low),
ps.shape(self.high if high is None else high)))
def _event_shape(self):
return ()
def _make_transformed_uniform(self):
low = tf.convert_to_tensor(self.low)
high = tf.convert_to_tensor(self.high)
low_cdf = self.distribution.cdf(low)
high_cdf = tf.broadcast_to(
self.distribution.cdf(high),
self._batch_shape_tensor(low=low, high=high))
return self._bijector(tfd.Uniform(low_cdf, high_cdf))
def _sample_n(self, n, seed=None):
return self._make_transformed_uniform().sample(n, seed=seed)
def _log_prob(self, x):
return self._make_transformed_uniform().log_prob(x)
def _cdf(self, x):
return self._make_transformed_uniform().cdf(x)
def _quantile(self, q):
return self._make_transformed_uniform().quantile(q)
Model:
import tensorflow_probability as tfp
import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from functools import partial
import numpy as np
A = tf.random.normal(
[10,10], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None)
noise_std = tf.random.normal([1])
x1 = tfd.Sample(Truncated(tfd.Laplace(55, 10), 10, 100), sample_shape=[10]).sample()
y = tf.linalg.matvec(A, x1) + noise_std
model = tfd.JointDistributionSequentialAutoBatched([
tfd.Sample(tfd.Normal(loc=0., scale=1.),1),
tfd.Sample(Truncated(tfd.Laplace(55.0, 10), 10, 100), sample_shape=[10]),
lambda x_rv, sigma : tfd.Normal(loc=tf.linalg.matvec(A, x_rv) + sigma, scale=1.0)
])
def target_log_prob_fn(sigma, x_rv):
return model.log_prob([sigma, x_rv, y[tf.newaxis, ...]])
Running the following code for num_samples > 1
results in AssertionError: in user code: ... AssertionError: ('Sign', '_class')
num_samples = 2
sample = model.sample(num_samples)
model.log_prob(sample)
The same error is returned when the model is run using the NUTS sampler and multiple chains. With single chain the sampling works fine.
Full stacktrace: https://pastebin.com/SRWC3ajS
from Tensorflow probability: AssertionError: ('Sign', '_class')
No comments:
Post a Comment