I am trying to apply JAX to the code from another SO question to evaluate JAX's applicability and performance (that question has useful information about what the code does). To do so, I replaced the NumPy calls with their jax.numpy (jnp) equivalents. This was not as easy as I expected because I have little experience with JAX, and the result could probably be written better. I checked the results against the earlier code (the optimized algorithm) and they match, but for a sample case on Colab the JAX version takes 7.5 seconds where the earlier version took 0.10 seconds. I think the long runtime is related to the for loop in the code, which might be replaced by JAX constructs such as fori_loop or vectorization, but I don't know what changes are needed, or how to make them, to get satisfactory performance from this code with JAX.
import numpy as np
from scipy.spatial import cKDTree, distance
import jax
from jax import numpy as jnp
jax.config.update("jax_enable_x64", True)
# ---------------------------- input data ----------------------------
""" For testing by prepared files:
radii = np.load('a.npy')
poss = np.load('b.npy')
"""
rnd = np.random.RandomState(70)
data_volume = 1000
radii = rnd.uniform(0.0005, 0.122, data_volume)
dia_max = 2 * radii.max()
x = rnd.uniform(-1.02, 1.02, (data_volume, 1))
y = rnd.uniform(-3.52, 3.52, (data_volume, 1))
z = rnd.uniform(-1.02, -0.575, (data_volume, 1))
poss = np.hstack((x, y, z))
# --------------------------------------------------------------------
# @jax.jit
def ends_gap(poss, dia_max):
    particle_corsp_overlaps = jnp.array([], dtype=np.float64)
    kdtree = cKDTree(poss)  # Using SciPy
    for particle_idx in range(len(poss)):
        cur_point = poss[particle_idx]
        nears_i_ind = jnp.array(kdtree.query_ball_point(cur_point, r=dia_max, return_sorted=True), dtype=np.int64)  # Using SciPy
        ''' # Using NumPy
        unshared_idx = jnp.delete(jnp.arange(len(poss)), particle_idx)
        poss_without = poss[unshared_idx]
        dist_max = radii[particle_idx] + radii.max()
        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dist_max
        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dist_max
        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dist_max
        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dist_max
        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dist_max
        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dist_max
        nears_i_ind = jnp.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        '''
        assert len(nears_i_ind) > 0
        if len(nears_i_ind) <= 1:
            continue
        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = distance.cdist(poss[tuple(nears_i_ind[None, :])], cur_point[None, :]).squeeze()  # Using SciPy
        # dist_i = jnp.linalg.norm(poss[tuple(nears_i_ind[None, :])] - cur_point[None, :], axis=-1)  # Using NumPy
        contact_check = dist_i - (radii[tuple(nears_i_ind[None, :])] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]
        particle_corsp_overlaps = jnp.concatenate((particle_corsp_overlaps, connected))
        contacts_ind = jnp.where(contact_check <= 0)[0]
        contacts_sec_ind = jnp.array(nears_i_ind)[contacts_ind]
        sphere_olps_ind = jnp.sort(contacts_sec_ind)
        ends_ind_mod_temp = jnp.array([jnp.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
        if particle_idx > 0:  # ---> these 4-lines perhaps be better to be substituted by just one-line list appending as "ends_ind.append(ends_ind_mod_temp)"
            ends_ind = jnp.concatenate((ends_ind, ends_ind_mod_temp))
        else:
            ends_ind = jnp.array(ends_ind_mod_temp, dtype=np.int64)
    ends_ind_org = ends_ind
    ends_ind, ends_ind_idx = jnp.unique(jnp.sort(ends_ind_org), axis=0, return_index=True)
    gap = jnp.array(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org
I have tried to use @jax.jit on this code, but it raises errors, either TracerArrayConversionError or ConcretizationTypeError, on a Colab TPU:
Using SciPy:
TracerArrayConversionError: The numpy.ndarray conversion method array() was called on the JAX Tracer object Traced<ShapedArray(float64[1000,3])>with<DynamicJaxprTrace(level=0/1)> While tracing the function ends_gap at :1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'poss'. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
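This error appears because cKDTree needs a concrete NumPy array, while inside jit `poss` is an abstract tracer. One possible workaround, sketched here under assumptions and not taken from the original code, is to run the kd-tree search outside jit and jit only the arithmetic on a fixed-shape array of candidate pairs. The names `pair_gaps` and `contacts_kdtree` are hypothetical, and `query_pairs` is swapped in for the per-particle `query_ball_point` loop.

```python
import numpy as np
from scipy.spatial import cKDTree
import jax
import jax.numpy as jnp

@jax.jit
def pair_gaps(poss, radii, pairs):
    # pairs has a concrete shape (P, 2) at call time, so jit can trace it.
    i, j = pairs[:, 0], pairs[:, 1]
    dist = jnp.linalg.norm(poss[i] - poss[j], axis=-1)
    # Gap between sphere surfaces; <= 0 means touching or overlapping.
    return dist - (radii[i] + radii[j])

def contacts_kdtree(poss, radii):
    # The neighbour search stays in SciPy, on concrete NumPy arrays (outside jit).
    pairs = cKDTree(poss).query_pairs(2 * radii.max(), output_type='ndarray')
    gaps = np.asarray(pair_gaps(jnp.asarray(poss), jnp.asarray(radii), jnp.asarray(pairs)))
    # Variable-length filtering is also done outside jit.
    keep = gaps <= 0
    return gaps[keep], pairs[keep]

# Tiny hand-made example (assumed data): only spheres 0 and 1 overlap.
demo_poss = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
demo_radii = np.array([0.6, 0.6, 0.5])
demo_gaps, demo_pairs = contacts_kdtree(demo_poss, demo_radii)
```

Note that jit recompiles `pair_gaps` whenever the number of candidate pairs changes, so this layout pays off mainly when the function is called repeatedly with the same shapes.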
Using NumPy:
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)> The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. While tracing the function ends_gap at :1 for jit, this concrete value was not available in Python because it depends on the values of the arguments 'poss' and 'dia_max'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
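The message refers to the `size` argument that jnp.where (with a single argument) and jnp.nonzero accept: under jit the output shape must be known in advance, so the result is padded to a fixed length. A minimal sketch, where the function name `near_indices` and the padding value -1 are illustrative choices:

```python
import jax
import jax.numpy as jnp

@jax.jit
def near_indices(mask):
    # size is a static Python int (here the mask length);
    # unused slots are filled with fill_value.
    return jnp.where(mask, size=mask.shape[0], fill_value=-1)[0]

mask = jnp.array([False, True, False, True, True])
idx = near_indices(mask)
print(idx)  # indices 1, 3, 4 followed by -1 padding
```

The padded entries then have to be masked out in every later step, which is why a variable number of neighbours per particle does not map directly onto jit.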
I would appreciate any help speeding up this code and getting past these problems using JAX (and jax.jit if possible). How can I use JAX to get the best performance on CPU, GPU, and TPU?
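As an illustration of what "vectorization" could mean here, the sketch below replaces the Python loop with a dense all-pairs computation under jit. It is a brute-force substitute for the kd-tree search, not the original algorithm. It needs O(N^2) memory, which is fine for about 1000 particles but may not scale much further. The function names `pairwise_gaps` and `contacts_dense` are hypothetical, and each contact is reported once as (i, j) with i < j.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def pairwise_gaps(poss, radii):
    # All-pairs centre distances, shape (N, N).
    diff = poss[:, None, :] - poss[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff * diff, axis=-1))
    # Surface gap; <= 0 means the two spheres touch or overlap.
    gap = dist - (radii[:, None] + radii[None, :])
    # Keep each unordered pair once (i < j), which also drops self-pairs.
    idx = jnp.arange(poss.shape[0])
    upper = idx[:, None] < idx[None, :]
    return gap, upper & (gap <= 0)

def contacts_dense(poss, radii):
    gap, mask = pairwise_gaps(jnp.asarray(poss), jnp.asarray(radii))
    # The variable-length extraction happens outside jit, on concrete arrays.
    i, j = np.nonzero(np.asarray(mask))
    return np.asarray(gap)[i, j], np.stack((i, j), axis=1)

# Tiny hand-made example (assumed data): only spheres 0 and 1 overlap.
small_poss = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
small_radii = np.array([0.6, 0.6, 0.5])
small_gaps, small_pairs = contacts_dense(small_poss, small_radii)
```

Whether this beats the SciPy kd-tree version depends on N and on the device. The first call also includes compilation time, so timings should be taken from a second call.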
Prepared sample test data:
a.npy = Radii data
b.npy = Poss data
from How could I speed up this looping code by JAX