I have a function that takes in an array, performs an arbitrary calculation, and returns a new shape into which the array can be reshaped. I would like to use this function in a `numba.njit` environment:
```python
import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    return tuple([2,2])

@nb.njit
def test():
    my_array = np.array([1,2,3,4])
    target_shape = generate_target_shape(my_array)
    reshaped = my_array.reshape(target_shape)
    print(reshaped)

test()
```
However, numba does not support calling `tuple()` on a list, and I get the following error message when converting a list to a tuple with the `tuple()` function inside `generate_target_shape`:
```
No implementation of function Function(<class 'tuple'>) found for signature:
>>> tuple(list(int64)<iv=None>)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload of function 'tuple': File: numba/core/typing/builtins.py: Line 572.
With argument(s): '(list(int64)<iv=None>)':
No match.
During: resolving callee type: Function(<class 'tuple'>
```
If I change the return type of `generate_target_shape` from `tuple` to `list` or `np.array`, I receive the following error message instead:

```
Invalid use of BoundFunction(array.reshape for array(float64, 1d, C)) with parameters (array(int64, 1d, C))
```
Is there a way to create an iterable object inside an `nb.njit` function that can be passed to `np.reshape`?
from Passing a shape to numpy.reshape in a numba njit environment fails, how can I create a suitable iterable for the target shape?