tf.strided_slice is the TensorFlow function for slicing tensors.
def strided_slice(input_,
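begin is where the slice starts, end is where it stops, and strides is the step size in each dimension. Note that the stepping is applied starting from the last dimension.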
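For a single dimension, the three arguments behave roughly like Python's own slice notation `values[begin:end:stride]`. The sketch below uses a plain list with made-up values to show which indices get picked; it is only an illustration of the indexing rule, not TensorFlow code.

```python
# Hypothetical 1-D data, chosen so that each value reveals its index.
values = [0, 10, 20, 30, 40, 50, 60]
begin, end, stride = 1, 6, 2

# Start at `begin`, add `stride` each time, stop before reaching `end`.
picked = [values[i] for i in range(begin, end, stride)]
print(picked)  # [10, 30, 50]

# Python's slice notation follows the same rule.
same_as_slice = picked == values[begin:end:stride]
print(same_as_slice)  # True

# A negative stride walks backwards from begin towards end (a reverse slice).
reversed_pick = values[5:0:-2]
print(reversed_pick)  # [50, 30, 10]
```

For a tensor with several dimensions, the same rule is applied to each dimension with its own begin, end and stride entry.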
Question:
I am wondering what the tf.strided_slice() operator actually does.
The doc says,
To a first order, this operation extracts a slice of size end - begin from a tensor input starting at the location specified by begin. The slice continues by adding stride to the begin index until all dimensions are not less than end. Note that components of stride can be negative, which causes a reverse slice.
And in the sample,
# 'input' is [[[1, 1, 1], [2, 2, 2]],
So in my understanding of the doc, for the first sample (tf.strided_slice(input, begin=[1, 0, 0], end=[2, 1, 3], strides=[1, 1, 1])),
- the resulting size is end - begin = [1, 1, 3]. The sample result shows [[[3, 3, 3]]], whose shape is [1, 1, 3], so that seems OK.
- the first element of the result is at begin = [1, 0, 0]. The first element of the sample result is 3, which is input[1,0,0], so that seems OK.
- the slice continues by adding stride to the begin index. So the second element of the result should be input[begin + strides] = input[2, 1, 1] = 6, but the sample shows the second element is 3.
What does strided_slice() do?
Answer
The mistake in your argument is the fact that you are directly adding the lists strides and begin element by element. This will make the function a lot less useful. Instead, it increments the begin list one dimension at a time, starting from the last dimension.
Let’s solve the first example part by part. begin = [1, 0, 0] and end = [2, 1, 3]. Also, all the strides are 1. Work your way backwards, from the last dimension.
Start with element [1,0,0]. Now increase only the last dimension by its stride, giving you [1,0,1]. Keep doing this until you reach the limit: [1,0,2], then [1,0,3], which equals end[2] and ends the loop. In the next iteration, increment the second-to-last dimension and reset the last one, giving you [1,1,0]. Here the second-to-last dimension is equal to end[1], so move to the first dimension (third to last) and reset the rest, giving you [2,0,0]. Again you are at the first dimension’s limit, so quit the loop.
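The walk described above visits indices in the same order as nested loops with the last dimension innermost. As a small sketch (the variable names are chosen here for illustration), itertools.product over one range per dimension lists the indices that end up in the result of the first example:

```python
from itertools import product

begin = [1, 0, 0]
end = [2, 1, 3]
strides = [1, 1, 1]

# One range per dimension; product() varies the last range fastest,
# so the last dimension is exhausted before the one before it advances.
ranges = [range(b, e, s) for b, e, s in zip(begin, end, strides)]
visited = [list(index) for index in product(*ranges)]

print(visited)  # [[1, 0, 0], [1, 0, 1], [1, 0, 2]]
```

The indices [1,0,3], [1,1,0] and [2,0,0] from the walk are the points where a limit is hit, so they do not appear in the list.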
The following code is a recursive implementation of what I described above,
# Assume global `begin`, `end` and `stride`
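Only the first line of that listing appears here, so the following is a self-contained sketch in the same spirit rather than the original code. It passes begin, end and strides as arguments instead of using globals, works on nested Python lists, and assumes the input is the 3 x 2 x 3 tensor from the doc sample quoted in the question. The function name iterate_slice is hypothetical.

```python
def iterate_slice(tensor, begin, end, strides, dim=0):
    """Recursively copy the strided region of a nested list (illustrative sketch)."""
    if dim == len(begin):
        # Every dimension has been indexed: `tensor` is now a single element.
        return tensor
    # Step through the current dimension and recurse into the next one.
    return [
        iterate_slice(tensor[i], begin, end, strides, dim + 1)
        for i in range(begin[dim], end[dim], strides[dim])
    ]


# Assumed to be the full input of the doc sample (an assumption of this sketch).
input_ = [[[1, 1, 1], [2, 2, 2]],
          [[3, 3, 3], [4, 4, 4]],
          [[5, 5, 5], [6, 6, 6]]]

first = iterate_slice(input_, [1, 0, 0], [2, 1, 3], [1, 1, 1])
print(first)  # [[[3, 3, 3]]]

# Widening end[1] from 1 to 2 lets the second-to-last dimension take two steps.
second = iterate_slice(input_, [1, 0, 0], [2, 2, 3], [1, 1, 1])
print(second)  # [[[3, 3, 3], [4, 4, 4]]]
```

Each recursion level handles one dimension, so the innermost loop runs over the last dimension, matching the order of the walk above.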