diff options
Diffstat (limited to 'tests/validation/reference/ScatterLayer.cpp')
-rw-r--r-- | tests/validation/reference/ScatterLayer.cpp | 64 |
1 files changed, 59 insertions, 5 deletions
diff --git a/tests/validation/reference/ScatterLayer.cpp b/tests/validation/reference/ScatterLayer.cpp index 188cce100b..920f2b9990 100644 --- a/tests/validation/reference/ScatterLayer.cpp +++ b/tests/validation/reference/ScatterLayer.cpp @@ -32,16 +32,70 @@ namespace validation { namespace reference { +namespace +{ template <typename T> +T reduce_op(const T ¤t,const T &update,const ScatterFunction func) +{ + switch(func) + { + case ScatterFunction::Update: + return update; + break; + case ScatterFunction::Add: + return current + update; + break; + case ScatterFunction::Sub: + return current - update; + break; + case ScatterFunction::Max: + return std::max(current, update); + break; + case ScatterFunction::Min: + return std::min(current, update); + break; + default: + ARM_COMPUTE_ERROR("Unsupported Scatter function"); + break; + } +} + +template float reduce_op(const float ¤t,const float &update,const ScatterFunction func); +} + +// Note : This function currently only supports 1D src, 1D updates, 2D indices, 1D output tensors. +template <typename T> SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info) { - ARM_COMPUTE_UNUSED(src); - ARM_COMPUTE_UNUSED(updates); - ARM_COMPUTE_UNUSED(indices); - ARM_COMPUTE_UNUSED(info); - // Unimplemented reference. SimpleTensor<T> dst{ out_shape, src.data_type(), 1 }; + + // 1. If zero initialization variable is true, fill dst with 0 values. Else copy src data to dst. + if(info.zero_initialization) + { + for (int i = 0; i < src.num_elements(); ++i) + { + dst[i] = static_cast<T>(0); + } + } + else + { + std::copy_n(src.data(), src.num_elements(), dst.data()); + } + + // 2. Get max index of output tensor, then iterate over index tensor. + const auto x_bound = dst.shape().x(); + + + for(int i = 0; i < indices.num_elements(); ++i) + { + // 3. Check whether index is out of bounds for dst, if not then apply reduce op. + const auto index = indices[i]; + if (index < x_bound) // Note : index is always >= 0 as datatype is unsigned. + { + dst[index] = reduce_op(dst[index], updates[i], info.func); + } + } return dst; } |