Thursday, 10 March 2022

How could I speed up this looping code with JAX?

I am trying JAX on the code from another SO question to evaluate its applicability and performance (that question has useful information about what the code does). To do this, I replaced the NumPy calls with their jax.numpy (jnp) equivalents. This was not as easy as I expected because I have little experience with JAX, and it could probably be written better. I checked the results against the earlier code (the optimized algorithm) and they match, but the JAX version takes 7.5 seconds on a sample case, whereas the earlier version took 0.10 seconds (both on Colab). I suspect the long runtime is related to the for loop, which might be replaced by JAX constructs such as fori_loop or vectorization, but I don't know which changes are needed, or how to make them, to get satisfactory performance with JAX.

import numpy as np
from scipy.spatial import cKDTree, distance
import jax
from jax import numpy as jnp
jax.config.update("jax_enable_x64", True)


# ---------------------------- input data ----------------------------
""" For testing by prepared files:
radii = np.load('a.npy')
poss = np.load('b.npy')
"""

rnd = np.random.RandomState(70)
data_volume = 1000

radii = rnd.uniform(0.0005, 0.122, data_volume)
dia_max = 2 * radii.max()

x = rnd.uniform(-1.02, 1.02, (data_volume, 1))
y = rnd.uniform(-3.52, 3.52, (data_volume, 1))
z = rnd.uniform(-1.02, -0.575, (data_volume, 1))
poss = np.hstack((x, y, z))
# --------------------------------------------------------------------


# @jax.jit
def ends_gap(poss, dia_max):
    particle_corsp_overlaps = jnp.array([], dtype=np.float64)

    kdtree = cKDTree(poss)                                                                                              # Using SciPy

    for particle_idx in range(len(poss)):

        cur_point = poss[particle_idx]
        nears_i_ind = jnp.array(kdtree.query_ball_point(cur_point, r=dia_max, return_sorted=True), dtype=np.int64)      # Using SciPy
        '''                                                                                                             # Using NumPy
        unshared_idx = jnp.delete(jnp.arange(len(poss)), particle_idx)
        poss_without = poss[unshared_idx]
        dist_max = radii[particle_idx] + radii.max()

        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dist_max
        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dist_max
        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dist_max
        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dist_max
        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dist_max
        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dist_max

        nears_i_ind = jnp.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        '''

        assert len(nears_i_ind) > 0

        if len(nears_i_ind) <= 1:
            continue

        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]

        dist_i = distance.cdist(poss[tuple(nears_i_ind[None, :])], cur_point[None, :]).squeeze()                        # Using SciPy
        # dist_i = jnp.linalg.norm(poss[tuple(nears_i_ind[None, :])] - cur_point[None, :], axis=-1)                     # Using NumPy
        contact_check = dist_i - (radii[tuple(nears_i_ind[None, :])] + radii[particle_idx])

        connected = contact_check[contact_check <= 0]
        particle_corsp_overlaps = jnp.concatenate((particle_corsp_overlaps, connected))

        contacts_ind = jnp.where(contact_check <= 0)[0]
        contacts_sec_ind = jnp.array(nears_i_ind)[contacts_ind]
        sphere_olps_ind = jnp.sort(contacts_sec_ind)

        ends_ind_mod_temp = jnp.array([jnp.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
        if particle_idx > 0:   # ---> these four lines might be better replaced by a single list append: "ends_ind.append(ends_ind_mod_temp)"
            ends_ind = jnp.concatenate((ends_ind, ends_ind_mod_temp))
        else:
            ends_ind = jnp.array(ends_ind_mod_temp, dtype=np.int64)

    ends_ind_org = ends_ind
    ends_ind, ends_ind_idx = jnp.unique(jnp.sort(ends_ind_org), axis=0, return_index=True)
    gap = jnp.array(particle_corsp_overlaps)[ends_ind_idx]

    return gap, ends_ind, ends_ind_idx, ends_ind_org

I tried applying @jax.jit to this function, but on a Colab TPU it raises either TracerArrayConversionError or ConcretizationTypeError, depending on which variant of the code is active:

Using SciPy:

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[1000,3])>with<DynamicJaxprTrace(level=0/1)>
While tracing the function ends_gap at :1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'poss'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Using NumPy:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
While tracing the function ends_gap at :1 for jit, this concrete value was not available in Python because it depends on the values of the arguments 'poss' and 'dia_max'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
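For reference, here is a minimal sketch of what the second error message seems to be asking for. Under jit every array shape must be known at trace time, so a one-argument jnp.where (which calls jnp.nonzero) needs a static size. The helper name near_indices_padded and the -1 padding value are my own assumptions for illustration, not part of the original code.

```python
import jax
from jax import numpy as jnp


@jax.jit
def near_indices_padded(mask):
    # Hypothetical helper: `size` fixes the output length at trace time.
    # Slots beyond the number of True entries are filled with -1 so that
    # they can be recognised and masked out by later steps.
    (idx,) = jnp.where(mask, size=mask.shape[0], fill_value=-1)
    return idx


mask = jnp.array([True, False, True, False])
idx = near_indices_padded(mask)
print(idx)  # indices of the True entries, padded with -1
```

The result always has the same length as the mask, so any later step has to ignore the padded -1 entries instead of relying on a variable-length index array.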

I would appreciate any help in speeding up this code and getting past these errors using JAX (and jax.jit if possible). How can JAX be used to get the best performance on CPU, GPU, or TPU?
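To make the question more concrete, here is a rough sketch of the kind of vectorized, jit-compatible formulation I have in mind. It is an assumption on my part, not a verified replacement for ends_gap: it drops the KD-tree and the Python loop and computes all pairwise gaps as a dense N x N matrix, which should be manageable for about 1000 particles but grows quadratically in memory. The function name pairwise_gaps and the demo arrays are hypothetical.

```python
import jax
from jax import numpy as jnp


@jax.jit
def pairwise_gaps(poss, radii):
    # Centre-to-centre distances between all pairs, shape (N, N).
    diff = poss[:, None, :] - poss[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))
    # Gap between sphere surfaces; a value <= 0 means contact or overlap.
    gap = dist - (radii[:, None] + radii[None, :])
    # Keep each pair once (i < j) and exclude self-pairs on the diagonal.
    n = poss.shape[0]
    upper = jnp.arange(n)[:, None] < jnp.arange(n)[None, :]
    overlap = (gap <= 0) & upper
    return gap, overlap


# Tiny hypothetical demo: spheres 0 and 1 overlap, sphere 2 is far away.
poss_demo = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
radii_demo = jnp.array([0.6, 0.6, 0.5])
gap, overlap = pairwise_gaps(poss_demo, radii_demo)

# The variable-length extraction stays outside jit, where dynamic shapes are allowed.
pairs = jnp.argwhere(overlap)
gaps_of_pairs = gap[overlap]
```

The fixed-shape part (distances and mask) is compiled, while the list of overlapping pairs, whose length depends on the data, is extracted afterwards. Two spheres can only touch if their centres are no farther apart than dia_max, so the pairs found this way should correspond to the unique sorted pairs in ends_ind, though I have not checked this against the original output.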


Prepared sample test data:
a.npy = Radii data
b.npy = Poss data


