diff options
author | James Conroy <james.conroy@arm.com> | 2018-10-31 11:47:53 +0000 |
---|---|---|
committer | James Conroy <james.conroy@arm.com> | 2018-10-31 12:06:53 +0000 |
commit | 45a9b775bf63283320315d90e4e9a6c641df6e20 (patch) | |
tree | e1f0d33d98410255a6804ea9cccf16805fc6080f /src/armnn/test/TensorHelpers.hpp | |
parent | d84216a013445e86183e39c8b5b904836c71a95b (diff) | |
download | armnn-45a9b775bf63283320315d90e4e9a6c641df6e20.tar.gz |
IVGCVSW-2102: Fix Pooling2D CpuRef indexing bug
* Fixes bug when calcuating indexes for NHWC in
Pooling2D CpuRef implementation, it now uses
TensorBufferArrayView.
* Adds 2-Channel unit tests for Pooling2d on CpuRef,
Cl and Neon. The single channel tests were not
properly exercising Pooling2d using NHWC data layout.
* Refactors Pooling2D NHWC tests so that the input and
output data are permuted to NHWC when necessary,
instead of hard coding the data in NHWC format.
Change-Id: I5b9d41ed425ff283ea8c8ef6b1266ae0bc80f43b
Diffstat (limited to 'src/armnn/test/TensorHelpers.hpp')
-rw-r--r-- | src/armnn/test/TensorHelpers.hpp | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/src/armnn/test/TensorHelpers.hpp b/src/armnn/test/TensorHelpers.hpp index 7f3ac9ec95..f1ab6c99b5 100644 --- a/src/armnn/test/TensorHelpers.hpp +++ b/src/armnn/test/TensorHelpers.hpp @@ -210,3 +210,22 @@ boost::multi_array<T, n> MakeRandomTensor(const armnn::TensorInfo& tensorInfo, int32_t qOffset = tensorInfo.GetQuantizationOffset(); return MakeTensor<T, n>(tensorInfo, QuantizedVector<T>(qScale, qOffset, init)); } + +template<typename T> +armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches, + unsigned int numberOfChannels, + unsigned int height, + unsigned int width, + const armnn::DataLayoutIndexed& dataLayout) +{ + switch (dataLayout.GetDataLayout()) + { + case armnn::DataLayout::NCHW: + return armnn::TensorInfo({numberOfBatches, numberOfChannels, height, width}, armnn::GetDataType<T>()); + case armnn::DataLayout::NHWC: + return armnn::TensorInfo({numberOfBatches, height, width, numberOfChannels}, armnn::GetDataType<T>()); + default: + throw armnn::InvalidArgumentException("unknown data layout [" + + std::to_string(static_cast<int>(dataLayout.GetDataLayout())) + "]"); + } +} |