aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2018-10-11 12:39:05 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-22 16:57:53 +0100
commit595408218a0e17f04d91ff131a8227a4f352ff61 (patch)
tree515316e28abbed3dce388bc99be5ff52bc042765 /include
parenta0944791e87902b35e06c306c7b1a6f0f5bbfbd7 (diff)
downloadarmnn-595408218a0e17f04d91ff131a8227a4f352ff61.tar.gz
IVGCVSW-1978: Support NHWC for ResizeBilinear CpuRef
* Adds implementation to plumb DataLayout parameter for ResizeBilinear on CpuRef. * Adds unit tests to execute ResizeBilinear on CpuRef using the NHWC data layout. * Adds DataLayoutIndexed API, allowing easy access to the Channels, Height and Width of a tensor based on its data layout. This reduces code duplication. * Refactors original ResizeBilinear implementation and tests to use the DataLayoutIndexed API when required. Change-Id: Ic2b8916cdd2e370d070175547079d774daf6d7bf
Diffstat (limited to 'include')
-rw-r--r--include/armnn/Descriptors.hpp6
-rw-r--r--include/armnn/Types.hpp44
2 files changed, 47 insertions, 3 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index c2efa211a6..c0510055f2 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -313,9 +313,9 @@ struct ResizeBilinearDescriptor
, m_DataLayout(DataLayout::NCHW)
{}
- uint32_t m_TargetWidth;
- uint32_t m_TargetHeight;
- DataLayout m_DataLayout;
+ uint32_t m_TargetWidth;
+ uint32_t m_TargetHeight;
+ DataLayoutIndexed m_DataLayout;
};
struct ReshapeDescriptor
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 4afc50b1b1..bb0b1e6ca7 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -31,12 +31,56 @@ enum class DataType
Signed32 = 3
};
+// Begin: DataLayout
+
enum class DataLayout
{
NCHW = 1,
NHWC = 2
};
+// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout
+class DataLayoutIndexed
+{
+public:
+ DataLayoutIndexed(DataLayout dataLayout) : m_DataLayout(dataLayout)
+ {
+ switch (dataLayout)
+ {
+ case DataLayout::NHWC:
+ m_ChannelsIndex = 3;
+ m_HeightIndex = 1;
+ m_WidthIndex = 2;
+ break;
+ case DataLayout::NCHW:
+ m_ChannelsIndex = 1;
+ m_HeightIndex = 2;
+ m_WidthIndex = 3;
+ break;
+ default:
+ throw InvalidArgumentException("Unknown DataLayout value: " +
+ std::to_string(static_cast<int>(dataLayout)));
+ }
+ }
+
+ DataLayout GetDataLayout() const { return m_DataLayout; }
+ unsigned int GetChannelsIndex() const { return m_ChannelsIndex; }
+ unsigned int GetHeightIndex() const { return m_HeightIndex; }
+ unsigned int GetWidthIndex() const { return m_WidthIndex; }
+
+private:
+ DataLayout m_DataLayout;
+ unsigned int m_ChannelsIndex;
+ unsigned int m_HeightIndex;
+ unsigned int m_WidthIndex;
+};
+
+// Conversion methods - implementations in src/armnn/InternalTypes.cpp
+bool operator==(const DataLayout& dataLayout, const DataLayoutIndexed& indexed);
+bool operator==(const DataLayoutIndexed& indexed, const DataLayout& dataLayout);
+
+// End: DataLayout
+
enum class ActivationFunction
{
Sigmoid = 0,