aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/TensorHelpers.hpp
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2018-10-31 11:47:53 +0000
committerJames Conroy <james.conroy@arm.com>2018-10-31 12:06:53 +0000
commit45a9b775bf63283320315d90e4e9a6c641df6e20 (patch)
treee1f0d33d98410255a6804ea9cccf16805fc6080f /src/armnn/test/TensorHelpers.hpp
parentd84216a013445e86183e39c8b5b904836c71a95b (diff)
downloadarmnn-45a9b775bf63283320315d90e4e9a6c641df6e20.tar.gz
IVGCVSW-2102: Fix Pooling2D CpuRef indexing bug
* Fixes bug when calcuating indexes for NHWC in Pooling2D CpuRef implementation, it now uses TensorBufferArrayView. * Adds 2-Channel unit tests for Pooling2d on CpuRef, Cl and Neon. The single channel tests were not properly exercising Pooling2d using NHWC data layout. * Refactors Pooling2D NHWC tests so that the input and output data are permuted to NHWC when necessary, instead of hard coding the data in NHWC format. Change-Id: I5b9d41ed425ff283ea8c8ef6b1266ae0bc80f43b
Diffstat (limited to 'src/armnn/test/TensorHelpers.hpp')
-rw-r--r--src/armnn/test/TensorHelpers.hpp19
1 files changed, 19 insertions, 0 deletions
diff --git a/src/armnn/test/TensorHelpers.hpp b/src/armnn/test/TensorHelpers.hpp
index 7f3ac9ec95..f1ab6c99b5 100644
--- a/src/armnn/test/TensorHelpers.hpp
+++ b/src/armnn/test/TensorHelpers.hpp
@@ -210,3 +210,22 @@ boost::multi_array<T, n> MakeRandomTensor(const armnn::TensorInfo& tensorInfo,
int32_t qOffset = tensorInfo.GetQuantizationOffset();
return MakeTensor<T, n>(tensorInfo, QuantizedVector<T>(qScale, qOffset, init));
}
+
+template<typename T>
+armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches,
+ unsigned int numberOfChannels,
+ unsigned int height,
+ unsigned int width,
+ const armnn::DataLayoutIndexed& dataLayout)
+{
+ switch (dataLayout.GetDataLayout())
+ {
+ case armnn::DataLayout::NCHW:
+ return armnn::TensorInfo({numberOfBatches, numberOfChannels, height, width}, armnn::GetDataType<T>());
+ case armnn::DataLayout::NHWC:
+ return armnn::TensorInfo({numberOfBatches, height, width, numberOfChannels}, armnn::GetDataType<T>());
+ default:
+ throw armnn::InvalidArgumentException("unknown data layout ["
+ + std::to_string(static_cast<int>(dataLayout.GetDataLayout())) + "]");
+ }
+}