diff options
Diffstat (limited to 'chapters/scatter_gather.adoc')
-rw-r--r-- | chapters/scatter_gather.adoc | 51 |
1 files changed, 50 insertions, 1 deletions
diff --git a/chapters/scatter_gather.adoc b/chapters/scatter_gather.adoc index cfee60b..e1be77f 100644 --- a/chapters/scatter_gather.adoc +++ b/chapters/scatter_gather.adoc @@ -36,7 +36,7 @@ for_each(0<=n<N, 0<=w<W, 0<=c<C) { index_t k = tensor_read<index_t>(indices, [N,W], [n,w]) assert(0<=k && k<K) value_t value = tensor_read<value_t>(values, [N,K,C], [n, k, c]) - tensor_write<value_t>(output, [N,W,C], [n,w,c]) + tensor_write<value_t>(output, [N,W,C], [n,w,c], value) } ---- @@ -51,3 +51,52 @@ for_each(0<=n<N, 0<=w<W, 0<=c<C) { |MI,MT|float|int32|float |=== +==== SCATTER + +The values_out tensor is set to the values_in tensor with data modified as follows: data from the input tensor is inserted at the positions specified by the indices tensor. +N is the number of batches, W the number of indices in each batch, K the range of each index and C the number data channels for each index. + +*Arguments:* + +|=== +|Argument|Type|Name|Shape|Description + +|Input|value_t*|values_in|[N,K,C]|3D values in tensor +|Input|index_t*|indices|[N,W]|2D index tensor +|Input|value_t*|input|[N,W,C]|3D input tensor +|Output|value_t*|values_out|[N,K,C]|3D values out tensor +|=== + +*Quantization Parameters:* + +None + +*Operation Function:* + +[source,c] +---- +// Copy the values_in tensor to the values_out tensor. +// Values not written by the scatter operation are unchanged in the output. +for_each(0<=n<N, 0<=k<K, 0<=c<C) { + value_t value = tensor_read<value_t>(values_in, [N,K,C], [n,k,c]) + tensor_write<value_t>(values_out, [N,K,C], [n, k, c], value) +} +// Now perform the SCATTER operation, writing to the positions from the indices tensor +for_each(0<=n<N, 0<=w<W, 0<=c<C) { + index_t k = tensor_read<index_t>(indices, [N,W], [n,w]) + assert(0<=k && k<K) + value_t value = tensor_read<value_t>(input, [N,W,C], [n,w,c]) + tensor_write<value_t>(values_out, [N,K,C], [n, k, c], value) +} +---- + +*Supported Data Types:* + +|=== +|Profile|Mode|index_t|value_t + +|Any|signed 8|int32|aint8 +|Any|signed 16|int32|int16 +|Any|signed 32|int32|int32 +|MI,MT|float|int32|float +|=== |