Tuesday, 30 October 2018

How to efficiently extract all slices of given length using tensorflow

I am trying to extract all slices of length 4 along the 0th axis of a 2-D tensor. So far I can only do it by mixing NumPy with TensorFlow:

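import numpy as np
import tensorflow as tf

r = test.shape[0]  # number of rows; `test` is the 2-D tensor
n = 4              # slice length
a_list = list(range(r))
# one row of indices per window: [0, 1, 2, 3], [1, 2, 3, 4], ...
the_list = np.array([a_list[slice(i, i+n)] for i in range(r - n+1)])
# tf.gather with a 2-D index array already returns the stacked slices, shape (r - n + 1, n, cols)
test_stacked = tf.gather(test, the_list)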

What would be an efficient way to do this in TensorFlow alone, without NumPy or a list comprehension?
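One possible pure-TensorFlow sketch (an assumption, not necessarily the best answer) replaces the list comprehension with broadcasting: a column of window start indices plus a row of offsets 0..n-1 gives the same index matrix as `the_list`, and `tf.gather` then collects the rows. The helper name `all_slices` is hypothetical, and the code assumes TensorFlow 2.x with eager execution.

```python
import tensorflow as tf

def all_slices(test, n):
    """Hypothetical helper: every length-n window along axis 0 of a 2-D tensor.

    Returns a tensor of shape (rows - n + 1, n, cols).
    """
    r = tf.shape(test)[0]                          # row count, also works for dynamic shapes
    starts = tf.range(r - n + 1)[:, tf.newaxis]    # (r - n + 1, 1): first row of each window
    offsets = tf.range(n)[tf.newaxis, :]           # (1, n): offsets inside a window
    indices = starts + offsets                     # (r - n + 1, n) by broadcasting, same as the_list
    return tf.gather(test, indices)                # gather rows -> (r - n + 1, n, cols)

test = tf.constant([[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
result = all_slices(test, 4)
print(tuple(result.shape))  # (3, 4, 2)
```

Because the index matrix is built from `tf.shape` rather than from a Python list, the sketch should also work when the number of rows is only known at run time.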

A complete NumPy-only example:

import numpy as np

array = np.array([[0, 1],[1, 2],[2, 3],[3, 4],[4, 5],[5, 6]])
array.shape # (6,2)

r = array.shape[0]
n = 4
a_list = list(range(r))
the_list = np.array([a_list[slice(i, i+n)] for i in range(r - n+1)])

result = array[the_list] # all possible slices of length 4 of the array along 0th axis
result.shape # (3, 4, 2)

result:
[[[0 1] [1 2] [2 3] [3 4]]
 [[1 2] [2 3] [3 4] [4 5]]
 [[2 3] [3 4] [4 5] [5 6]]]
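TensorFlow also provides a built-in framing op, `tf.signal.frame`, which is mostly used for audio but can frame any axis. A minimal sketch applying it to the example array above, assuming TensorFlow 2.x (the float32 dtype is an arbitrary choice for the sketch):

```python
import tensorflow as tf

test = tf.constant([[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6]], dtype=tf.float32)

# Windows of frame_length=4 rows, advancing frame_step=1 row at a time along axis 0.
# pad_end defaults to False, so only complete windows are returned.
frames = tf.signal.frame(test, frame_length=4, frame_step=1, axis=0)
print(tuple(frames.shape))  # (3, 4, 2)
```

With a step of 1 and no end padding, the frames should match `result` from the NumPy example; a larger `frame_step` would give non-overlapping or strided windows instead.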



