Thursday, 3 December 2020

Tensorflow probability: AssertionError: ('Sign', '_class')

Consider a toy model y = Ax + noise where x is sampled from a truncated laplace random variables.

Truncated Laplace distribution can be made using the code:

import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors import inline as inline_lib
from tensorflow_probability.python.distributions import uniform as uniform_lib
from tensorflow_probability.python.distributions import distribution
# from tensorflow_probability.python.distributions import reparameterization
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
import tensorflow_probability.python.bijectors as tfb
from tensorflow_probability.python.internal import prefer_static as ps

class Truncated(distribution.Distribution):
  """Truncates an underlying distribution to be bounded by low, high.

  The truncated distribution is bounded between `low` and `high` (the pdf is 0
  outside these bounds and renormalized).

  Samples from this distribution are differentiable with respect to the
  underlying distribution's parameters as well as the bounds, `low` and `high`,
  i.e., this implementation is fully reparameterized.

  For more details, see [here](

  The distribution is implemented as a
  `tfp.distributions.TransformedDistribution` in which a uniform distribution on
  `[underlying.cdf(low), underlying.cdf(high)]`is transformed by the bijection
  defined by the `underlying.quantile`/`underlying.cdf` pair, and the
  log-determinant is determined using automatic differentiation of

  If using autodiff of the quantile would be more efficient for your scenario,
  please file an issue on github or email ``.

  def __init__(self, underlying, low, high,
               validate_args=False, allow_nan_stats=False, name=None):
    """Creates a truncated distribution from underlying distribution and bounds.

      distribution: A `tfp.distributions.Distribution` which implements
        both `cdf` and `quantile` (implying a scalar `event_shape`).
      low: The lower bound for truncation. Must broadcast with both
        `underlying.batch_shape` and `high`.
      high: The upper bound for truncation. Must broadcast with both
        `underlying.batch_shape` and `low`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked at run-time.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value '`NaN`' to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Optional name for ops created by this instance. Defaults to
    with tf.name_scope(name or 'Truncated{}'.format( as name:
      dtype = dtype_util.common_dtype([underlying, low, high], tf.float32)
      self._distribution = underlying
      self._low = tensor_util.convert_nonref_to_tensor(low, dtype=dtype)
      self._high = tensor_util.convert_nonref_to_tensor(high, dtype=dtype)

      def ildj(y):
        return tf.math.log(tf.math.abs(
            tfp.math.value_and_gradient(underlying.cdf, y)[1]))

      self._bijector = tfb.Inline(
          is_increasing=lambda: True)

      super(Truncated, self).__init__(

  def distribution(self):
    return self._distribution

  def low(self):
    return self._low

  def high(self):
    return self._high

  def _batch_shape(self):
    return tf.broadcast_static_shape(
        tf.broadcast_static_shape(self.low.shape, self.high.shape))

  def _batch_shape_tensor(self, low=None, high=None):
    return ps.broadcast_shape(
        ps.broadcast_shape(ps.shape(self.low if low is None else low),
                           ps.shape(self.high if high is None else high)))

  def _event_shape(self):
    return ()

  def _make_transformed_uniform(self):
    low = tf.convert_to_tensor(self.low)
    high = tf.convert_to_tensor(self.high)
    low_cdf = self.distribution.cdf(low)
    high_cdf = tf.broadcast_to(
        self._batch_shape_tensor(low=low, high=high))
    return self._bijector(tfd.Uniform(low_cdf, high_cdf))

  def _sample_n(self, n, seed=None):
    return self._make_transformed_uniform().sample(n, seed=seed)

  def _log_prob(self, x):
    return self._make_transformed_uniform().log_prob(x)

  def _cdf(self, x):
    return self._make_transformed_uniform().cdf(x)

  def _quantile(self, q):
    return self._make_transformed_uniform().quantile(q)


import tensorflow_probability as tfp
import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from functools import partial
import numpy as np

A = tf.random.normal(
    [10,10], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None)
noise_std = tf.random.normal([1])

x1 = tfd.Sample(Truncated(tfd.Laplace(55, 10), 10, 100), sample_shape=[10]).sample() 
y = tf.linalg.matvec(A, x1) + noise_std

model = tfd.JointDistributionSequentialAutoBatched([
    tfd.Sample(tfd.Normal(loc=0., scale=1.),1),
    tfd.Sample(Truncated(tfd.Laplace(55.0, 10), 10, 100), sample_shape=[10]),
    lambda x_rv, sigma : tfd.Normal(loc=tf.linalg.matvec(A, x_rv) + sigma, scale=1.0)

def target_log_prob_fn(sigma, x_rv):
    return model.log_prob([sigma, x_rv, y[tf.newaxis, ...]])

Running the following code for num_samples > 1 results in AssertionError: in user code: ... AssertionError: ('Sign', '_class')

num_samples = 2
sample = model.sample(num_samples)

The same error is returned when the model is run using the NUTS sampler and multiple chains. With single chain the sampling works fine.

Full stacktrace:

from Tensorflow probability: AssertionError: ('Sign', '_class')

