Monday 2 August 2021

Why is Tensorflow tensor indexing failing in Normalizing Flow "prob" method

I am building a Normalizing Flow (concatenation of Distribution and chain of Bijectors) in Tensorflow. Here is the code for the chain of Bijectors:

class Flow( tfb.Bijector ):

    def __init__( self, theta, a, **kwargs ):
        tfb.Bijector.__init__( self, forward_min_event_ndims = 0, **kwargs )
        bijectors = [ tfb.Tanh() ]
        self.chain = tfb.Chain( bijectors = bijectors )

    def _forward( self, z ):
        return self.chain( z )

    def _inverse( self, x ):
        result = self.chain.inverse( x ) 
        return result

    def _forward_log_det_jacobian( self, z ):
        return self.chain._forward_log_det_jacobian( z, event_ndims = 2 )
    

Here's how I'm trying to test it, specifically, testing the prob method of the base distribution plus Flow:

Z = tf.convert_to_tensor( [ [ [ 0.1, 0.2 ], [ 0.3, 0.4 ], [ 0.5, 0.6 ] ], 
                            [ [ 0.8, 0.7 ], [ 0.6, 0.5 ], [ 0.4, 0.3 ] ],
                            [ [ 0.4, 0.7 ], [ 0.2, 0.1 ], [ 0.8, 0.0 ] ] ] )
print( "Z", Z )
nf = Flow( 1., 2. )  # ### theta, a 
bd = tfd.MultivariateNormalDiag( loc=[0.,0.], scale_diag=[1.,1.] )
td = tfd.TransformedDistribution( bd, nf )
td.log_prob( Z )

The last statement fails with the following stack trace:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-29-9f91e9e1871a> in <module>()
     24 bd = tfd.MultivariateNormalDiag( loc=[0.,0], scale_diag=[1.,1.] )
     25 td = tfd.TransformedDistribution( bd, nf )
---> 26 td.prob( Z )

12 frames
/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/distribution.py in prob(self, value, name, **kwargs)
   1322         values of type `self.dtype`.
   1323     """
-> 1324     return self._call_prob(value, name, **kwargs)
   1325 
   1326   def _call_unnormalized_log_prob(self, value, name, **kwargs):

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/distribution.py in _call_prob(self, value, name, **kwargs)
   1304     with self._name_and_control_scope(name, value, kwargs):
   1305       if hasattr(self, '_prob'):
-> 1306         return self._prob(value, **kwargs)
   1307       if hasattr(self, '_log_prob'):
   1308         return tf.exp(self._log_prob(value, **kwargs))

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/transformed_distribution.py in _prob(self, y, **kwargs)
    371         )
    372     ildj = self.bijector.inverse_log_det_jacobian(
--> 373         y, event_ndims=event_ndims, **bijector_kwargs)
    374     if self.bijector._is_injective:  # pylint: disable=protected-access
    375       base_prob = self.distribution.prob(x, **distribution_kwargs)

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/bijectors/bijector.py in inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs)
   1318       ValueError: if the value of `event_ndims` is not valid for this bijector.
   1319     """
-> 1320     return self._call_inverse_log_det_jacobian(y, event_ndims, name, **kwargs)
   1321 
   1322   def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/bijectors/bijector.py in _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs)
   1274               'is implemented. One or the other is required.')
   1275 
-> 1276         return self._reduce_jacobian_det_over_shape(ildj, reduce_shape)
   1277 
   1278   def inverse_log_det_jacobian(self,

/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/bijectors/bijector.py in _reduce_jacobian_det_over_shape(self, unreduced, reduce_shape)
   1531     ones = tf.ones(reduce_shape, unreduced.dtype)
   1532     reduce_dims = ps.range(-ps.size(reduce_shape), 0)
-> 1533     return tf.reduce_sum(ones * unreduced, axis=reduce_dims)
   1534 
   1535   def _parameter_control_dependencies(self, is_init):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/math_ops.py in binary_op_wrapper(x, y)
   1232         #   r_binary_op_wrapper use different force_same_dtype values.
   1233         x, y = maybe_promote_tensors(x, y, force_same_dtype=False)
-> 1234         return func(x, y, name=name)
   1235       except (TypeError, ValueError) as e:
   1236         # Even if dispatching the op failed, the RHS may be a tensor aware

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/math_ops.py in _mul_dispatch(x, y, name)
   1573     return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
   1574   else:
-> 1575     return multiply(x, y, name=name)
   1576 
   1577 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    204     """Call target, and fall back on dispatchers if there is a TypeError."""
    205     try:
--> 206       return target(*args, **kwargs)
    207     except (TypeError, ValueError):
    208       # Note: convert_to_eager_tensor currently raises a ValueError, not a

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/math_ops.py in multiply(x, y, name)
    528   """
    529 
--> 530   return gen_math_ops.mul(x, y, name)
    531 
    532 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/gen_math_ops.py in mul(x, y, name)
   6238       return _result
   6239     except _core._NotOkStatusException as e:
-> 6240       _ops.raise_from_not_ok_status(e, name)
   6241     except _core._FallbackException:
   6242       pass

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6895   message = e.message + (" name: " + name if name is not None else "")
   6896   # pylint: disable=protected-access
-> 6897   six.raise_from(core._status_to_exception(e.code, message), None)
   6898   # pylint: enable=protected-access
   6899 

/usr/local/lib/python3.7/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: required broadcastable shapes at loc(unknown) [Op:Mul]
    

I'm not able to figure out from the stack trace where things are going wrong.

Can you help?



from Why is Tensorflow tensor indexing failing in Normalizing Flow "prob" method

No comments:

Post a Comment