aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorSimon Obute <simon.obute@arm.com>2021-09-03 15:50:13 +0100
committerTeresa Charlin <teresa.charlinreyes@arm.com>2021-09-24 16:06:30 +0100
commit51f67776a695c217a32596af806afeeb080f5528 (patch)
tree33ccfd87ba365bcc6fc86d5a2181991a130b3061 /include
parentf10b15a8946f39bdf3f60cebc59d2963069eedca (diff)
downloadarmnn-51f67776a695c217a32596af806afeeb080f5528.tar.gz
IVGCVSW-3705 Add Channel Shuffle Front end and Ref Implementation
* Add front end * Add reference workload * Add unit tests * Add Serializer and Deserializer * Update ArmNN Versioning Signed-off-by: Simon Obute <simon.obute@arm.com> Change-Id: I9ac1f953af3974382eac8e8d62d794d2344e8f47
Diffstat (limited to 'include')
-rw-r--r--include/armnn/BackendHelper.hpp5
-rw-r--r--include/armnn/Descriptors.hpp22
-rw-r--r--include/armnn/DescriptorsFwd.hpp1
-rw-r--r--include/armnn/INetwork.hpp7
-rw-r--r--include/armnn/Types.hpp2
-rw-r--r--include/armnn/Version.hpp2
-rw-r--r--include/armnn/backends/ILayerSupport.hpp5
-rw-r--r--include/armnnOnnxParser/Version.hpp2
-rw-r--r--include/armnnTfLiteParser/Version.hpp2
9 files changed, 44 insertions, 4 deletions
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index dee3b48b81..e3478a79c5 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -66,6 +66,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+ bool IsChannelShuffleSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const ChannelShuffleDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+
bool IsComparisonSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 341dbecd4f..d571f2297b 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1342,4 +1342,26 @@ struct ReduceDescriptor : BaseDescriptor
ReduceOperation m_ReduceOperation;
};
+/// A ChannelShuffleDescriptor for the ChannelShuffle operator
+struct ChannelShuffleDescriptor : BaseDescriptor
+{
+ ChannelShuffleDescriptor()
+ : m_NumGroups(0), m_Axis(0)
+ {}
+
+ ChannelShuffleDescriptor(const uint32_t& numGroups, const uint32_t& axis)
+ : m_NumGroups(numGroups), m_Axis(axis)
+ {}
+
+ bool operator ==(const ChannelShuffleDescriptor& rhs) const
+ {
+ return m_NumGroups == rhs.m_NumGroups;
+ }
+
+ /// Number of groups for the channel shuffle operation
+ uint32_t m_NumGroups;
+ /// Axis to apply channel shuffle operation on
+ uint32_t m_Axis;
+};
+
} // namespace armnn
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 3b43c42d23..396b7285fd 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -13,6 +13,7 @@ struct ActivationDescriptor;
struct ArgMinMaxDescriptor;
struct BatchNormalizationDescriptor;
struct BatchToSpaceNdDescriptor;
+struct ChannelShuffleDescriptor;
struct ComparisonDescriptor;
struct Convolution2dDescriptor;
struct DepthwiseConvolution2dDescriptor;
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 3bbc406bff..37aeaf47fe 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -704,6 +704,13 @@ public:
const LstmInputParams& params,
const char* name = nullptr);
+ /// Add a ChannelShuffle layer to the network
+ /// @param descriptor - Parameters for the ChannelShuffle operation
+ /// @param name - Optional name for the layer
+ /// @return - Interface for configuring the layer
+ IConnectableLayer* AddChannelShuffleLayer(const ChannelShuffleDescriptor& descriptor,
+ const char* name = nullptr);
+
void Accept(ILayerVisitor& visitor) const;
void ExecuteStrategy(IStrategy& strategy) const;
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index c3b439a2d9..2fab6b44a9 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -421,7 +421,7 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
X(Cast) \
X(Shape) \
X(UnidirectionalSequenceLstm) \
-
+ X(ChannelShuffle) \
// New layers should be added at last to minimize instability.
/// When adding a new layer, adapt also the LastLayer enum value in the
diff --git a/include/armnn/Version.hpp b/include/armnn/Version.hpp
index 5347097982..3a5b568169 100644
--- a/include/armnn/Version.hpp
+++ b/include/armnn/Version.hpp
@@ -10,7 +10,7 @@
#define STRINGIFY_MACRO(s) #s
// ArmNN version components
-#define ARMNN_MAJOR_VERSION 26
+#define ARMNN_MAJOR_VERSION 27
#define ARMNN_MINOR_VERSION 0
#define ARMNN_PATCH_VERSION 0
diff --git a/include/armnn/backends/ILayerSupport.hpp b/include/armnn/backends/ILayerSupport.hpp
index 7ba565a138..f511ee4c89 100644
--- a/include/armnn/backends/ILayerSupport.hpp
+++ b/include/armnn/backends/ILayerSupport.hpp
@@ -64,6 +64,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsChannelShuffleSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const ChannelShuffleDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
virtual bool IsComparisonSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/include/armnnOnnxParser/Version.hpp b/include/armnnOnnxParser/Version.hpp
index 78b4b0453d..da3e392bc8 100644
--- a/include/armnnOnnxParser/Version.hpp
+++ b/include/armnnOnnxParser/Version.hpp
@@ -14,7 +14,7 @@ namespace armnnOnnxParser
// OnnxParser version components
#define ONNX_PARSER_MAJOR_VERSION 24
-#define ONNX_PARSER_MINOR_VERSION 2
+#define ONNX_PARSER_MINOR_VERSION 3
#define ONNX_PARSER_PATCH_VERSION 0
/// ONNX_PARSER_VERSION: "X.Y.Z"
diff --git a/include/armnnTfLiteParser/Version.hpp b/include/armnnTfLiteParser/Version.hpp
index c781b5809e..b0490cebec 100644
--- a/include/armnnTfLiteParser/Version.hpp
+++ b/include/armnnTfLiteParser/Version.hpp
@@ -14,7 +14,7 @@ namespace armnnTfLiteParser
// TfLiteParser version components
#define TFLITE_PARSER_MAJOR_VERSION 24
-#define TFLITE_PARSER_MINOR_VERSION 2
+#define TFLITE_PARSER_MINOR_VERSION 3
#define TFLITE_PARSER_PATCH_VERSION 0
/// TFLITE_PARSER_VERSION: "X.Y.Z"