Friday, 28 April 2023

Computing a norm in a loop slows down the computation with Dask

I was trying to implement a conjugate gradient algorithm using Dask (for didactic purposes) when I realized that the performance was much worse than that of a simple NumPy implementation. After a few experiments, I was able to reduce the problem to the following snippet:

import numpy as np
import dask.array as da
from time import time


def test_operator(f, test_vector, library=np):
    for n in (10, 20, 30):
        v = test_vector()

        start_time = time()
        for i in range(n):
            v = f(v)
            k = library.linalg.norm(v)
    
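            # If v is a Dask array, the norm is lazy and must be computed explicitly.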
            try:
                k = k.compute()
            except AttributeError:
                pass
            print(k)
        end_time = time()

        print('Time for {} iterations: {}'.format(n, end_time - start_time))

print('NUMPY!')
test_operator(
    lambda x: x + x,
    lambda: np.random.rand(4_000, 4_000)
)

print('DASK!')
test_operator(
    lambda x: x + x,
    lambda: da.from_array(np.random.rand(4_000, 4_000), chunks=(2_000, 2_000)),
    da
)

In the code, I simply multiply a vector by 2 (this is what f does) and print its norm. When running with Dask, each iteration is a bit slower than the previous one. This problem does not happen if I do not compute k, the norm of v.
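My guess (and it is only a guess) is that v stays lazy across iterations, so every call to compute() re-executes all the previous doublings, and the graph behind v keeps growing. Here is a small variant of the Dask loop that also prints the number of tasks in that graph at each iteration, to check whether this is what happens:

import numpy as np
import dask.array as da

# Same doubling loop as above, but also print how many tasks are in the
# graph behind v. If that number grows with i, every norm computation is
# re-executing all the previous additions from scratch.
v = da.from_array(np.random.rand(4_000, 4_000), chunks=(2_000, 2_000))
for i in range(10):
    v = v + v
    k = da.linalg.norm(v).compute()       # forces evaluation of the whole graph
    print(i, k, len(v.__dask_graph__()))  # number of tasks behind v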

Unfortunately, in my case, that k is the norm of the residual that I use to stop the conjugate gradient algorithm. How can I avoid this problem? And why does it happen?
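For context, the loop I am ultimately trying to write looks roughly like the sketch below (simplified, not my actual code; cg_sketch is just a placeholder name). The residual norm is the value that decides when to stop, so I need it as a concrete number at every iteration:

import numpy as np
import dask.array as da

def cg_sketch(A, b, tol=1e-6, max_iter=1000):
    # Textbook conjugate gradient; A is assumed symmetric positive definite.
    x = da.zeros_like(b)
    r = b - A @ x                       # residual
    p = r
    rs_old = (r @ r).compute()
    for _ in range(max_iter):
        Ap = A @ p
        alpha = rs_old / (p @ Ap).compute()
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = (r @ r).compute()      # residual norm squared, needed every iteration
        if np.sqrt(rs_new) < tol:       # stopping criterion based on the residual norm
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

I assume each of those compute() calls would hit the same kind of slowdown as the snippet above.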

Thank you!



from Computing a norm in a loop slows down the computation with Dask
