I am building a Normalizing Flow (concatenation of Distribution and chain of Bijectors) in Tensorflow. Here is the code for the chain of Bijectors:
class Flow( tfb.Bijector ):
def __init__( self, theta, a, **kwargs ):
tfb.Bijector.__init__( self, forward_min_event_ndims = 0, **kwargs )
bijectors = [ tfb.Tanh() ]
self.chain = tfb.Chain( bijectors = bijectors )
def _forward( self, z ):
return self.chain( z )
def _inverse( self, x ):
result = self.chain.inverse( x )
return result
def _forward_log_det_jacobian( self, z ):
return self.chain._forward_log_det_jacobian( z, event_ndims = 2 )
Here's how I'm trying to test it, specifically, testing the prob
method of the base distribution plus Flow:
Z = tf.convert_to_tensor( [ [ [ 0.1, 0.2 ], [ 0.3, 0.4 ], [ 0.5, 0.6 ] ],
[ [ 0.8, 0.7 ], [ 0.6, 0.5 ], [ 0.4, 0.3 ] ],
[ [ 0.4, 0.7 ], [ 0.2, 0.1 ], [ 0.8, 0.0 ] ] ] )
print( "Z", Z )
nf = Flow( 1., 2. ) # ### theta, a
bd = tfd.MultivariateNormalDiag( loc=[0.,0.], scale_diag=[1.,1.] )
td = tfd.TransformedDistribution( bd, nf )
td.log_prob( Z )
The last statement fails with the following stack trace:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-29-9f91e9e1871a> in <module>()
24 bd = tfd.MultivariateNormalDiag( loc=[0.,0], scale_diag=[1.,1.] )
25 td = tfd.TransformedDistribution( bd, nf )
---> 26 td.prob( Z )
12 frames
/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/distribution.py in prob(self, value, name, **kwargs)
1322 values of type `self.dtype`.
1323 """
-> 1324 return self._call_prob(value, name, **kwargs)
1325
1326 def _call_unnormalized_log_prob(self, value, name, **kwargs):
/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/distribution.py in _call_prob(self, value, name, **kwargs)
1304 with self._name_and_control_scope(name, value, kwargs):
1305 if hasattr(self, '_prob'):
-> 1306 return self._prob(value, **kwargs)
1307 if hasattr(self, '_log_prob'):
1308 return tf.exp(self._log_prob(value, **kwargs))
/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/transformed_distribution.py in _prob(self, y, **kwargs)
371 )
372 ildj = self.bijector.inverse_log_det_jacobian(
--> 373 y, event_ndims=event_ndims, **bijector_kwargs)
374 if self.bijector._is_injective: # pylint: disable=protected-access
375 base_prob = self.distribution.prob(x, **distribution_kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/bijectors/bijector.py in inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs)
1318 ValueError: if the value of `event_ndims` is not valid for this bijector.
1319 """
-> 1320 return self._call_inverse_log_det_jacobian(y, event_ndims, name, **kwargs)
1321
1322 def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/bijectors/bijector.py in _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs)
1274 'is implemented. One or the other is required.')
1275
-> 1276 return self._reduce_jacobian_det_over_shape(ildj, reduce_shape)
1277
1278 def inverse_log_det_jacobian(self,
/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/bijectors/bijector.py in _reduce_jacobian_det_over_shape(self, unreduced, reduce_shape)
1531 ones = tf.ones(reduce_shape, unreduced.dtype)
1532 reduce_dims = ps.range(-ps.size(reduce_shape), 0)
-> 1533 return tf.reduce_sum(ones * unreduced, axis=reduce_dims)
1534
1535 def _parameter_control_dependencies(self, is_init):
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/math_ops.py in binary_op_wrapper(x, y)
1232 # r_binary_op_wrapper use different force_same_dtype values.
1233 x, y = maybe_promote_tensors(x, y, force_same_dtype=False)
-> 1234 return func(x, y, name=name)
1235 except (TypeError, ValueError) as e:
1236 # Even if dispatching the op failed, the RHS may be a tensor aware
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/math_ops.py in _mul_dispatch(x, y, name)
1573 return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
1574 else:
-> 1575 return multiply(x, y, name=name)
1576
1577
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
204 """Call target, and fall back on dispatchers if there is a TypeError."""
205 try:
--> 206 return target(*args, **kwargs)
207 except (TypeError, ValueError):
208 # Note: convert_to_eager_tensor currently raises a ValueError, not a
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/math_ops.py in multiply(x, y, name)
528 """
529
--> 530 return gen_math_ops.mul(x, y, name)
531
532
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/gen_math_ops.py in mul(x, y, name)
6238 return _result
6239 except _core._NotOkStatusException as e:
-> 6240 _ops.raise_from_not_ok_status(e, name)
6241 except _core._FallbackException:
6242 pass
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
6895 message = e.message + (" name: " + name if name is not None else "")
6896 # pylint: disable=protected-access
-> 6897 six.raise_from(core._status_to_exception(e.code, message), None)
6898 # pylint: enable=protected-access
6899
/usr/local/lib/python3.7/dist-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: required broadcastable shapes at loc(unknown) [Op:Mul]
I'm not able to figure out from the stack trace where things are going wrong.
Can you help?
from Why is Tensorflow tensor indexing failing in Normalizing Flow "prob" method
No comments:
Post a Comment