diff options
author | Keith Davis <keith.davis@arm.com> | 2020-02-11 16:51:50 +0000 |
---|---|---|
committer | James Conroy <james.conroy@arm.com> | 2020-02-17 21:53:29 +0000 |
commit | 0c2eeac6347533a1d3d456aebea492f5123388f3 (patch) | |
tree | f218fc236137791c491b680dfd24fb9706c171a6 /src/backends/reference/RefLayerSupport.cpp | |
parent | 4c3c1f486ab775eacb1f6455f8468f9be2c3e4f7 (diff) | |
download | armnn-0c2eeac6347533a1d3d456aebea492f5123388f3.tar.gz |
IVGCVSW-4436 Add ExecuteNetwork test for mobilenet_v2_int8
* Add QAsymmS8 to QueueDescriptor supportedTypes
* Add QSymmS8/QAsymmS8 to RefLayerSupport supportedTypes
* Some additional comments and refactoring
Change-Id: I8567314452e6e8f6f69cb6e458ee147d3fc92fab
Signed-off-by: Keith Davis <keith.davis@arm.com>
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 43 |
1 file changed, 26 insertions(+), 17 deletions(-)
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index c60348e529..bba83e23d4 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -4,15 +4,11 @@ // #include "RefLayerSupport.hpp" -#include "RefBackendId.hpp" +#include <armnn/TypesUtils.hpp> #include <armnn/Types.hpp> #include <armnn/Descriptors.hpp> -#include <armnn/BackendRegistry.hpp> -#include <armnnUtils/DataLayoutIndexed.hpp> - -#include <InternalTypes.hpp> #include <LayerSupportCommon.hpp> #include <backendsCommon/LayerSupportRules.hpp> @@ -21,7 +17,6 @@ #include <boost/core/ignore_unused.hpp> #include <vector> -#include <algorithm> #include <array> using namespace boost; @@ -84,9 +79,11 @@ bool RefLayerSupport::IsActivationSupported(const TensorInfo& input, bool supported = true; // Define supported types. - std::array<DataType,4> supportedTypes = { + std::array<DataType,6> supportedTypes = { DataType::Float32, DataType::Float16, + DataType::QSymmS8, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -147,10 +144,11 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0, { bool supported = true; - std::array<DataType,5> supportedTypes = { + std::array<DataType,6> supportedTypes = { DataType::Float32, DataType::Float16, DataType::QSymmS8, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -420,11 +418,12 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, bool supported = true; // Define supported types. 
- std::array<DataType,5> supportedTypes = + std::array<DataType,6> supportedTypes = { DataType::Float32, DataType::Float16, DataType::QAsymmU8, + DataType::QAsymmS8, DataType::QSymmS8, DataType::QSymmS16 }; @@ -439,13 +438,14 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, "Reference Convolution2d: input and output types mismatched."); const DataType inputType = input.GetDataType(); - if (inputType == DataType::QAsymmU8) + if (IsQuantized8BitType(inputType)) { ARMNN_NO_DEPRECATE_WARN_BEGIN - std::array<DataType, 3> supportedWeightTypes = + std::array<DataType, 4> supportedWeightTypes = { DataType::QAsymmU8, DataType::QSymmS8, + DataType::QAsymmS8, DataType::QuantizedSymm8PerAxis // deprecated }; ARMNN_NO_DEPRECATE_WARN_END @@ -485,11 +485,12 @@ bool RefLayerSupport::IsDebugSupported(const TensorInfo& input, { bool supported = true; - std::array<DataType, 6> supportedTypes = + std::array<DataType, 7> supportedTypes = { DataType::Float16, DataType::Float32, DataType::QAsymmU8, + DataType::QAsymmS8, DataType::QSymmS8, DataType::QSymmS16, DataType::Signed32 @@ -545,10 +546,12 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input, bool supported = true; // Define supported types. - std::array<DataType,4> supportedTypes = + std::array<DataType,6> supportedTypes = { DataType::Float32, DataType::Float16, + DataType::QSymmS8, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -572,7 +575,7 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input, ARMNN_NO_DEPRECATE_WARN_END const DataType inputType = input.GetDataType(); - if (inputType == DataType::QAsymmU8) + if (IsQuantized8BitType(inputType)) { supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported, @@ -1413,10 +1416,12 @@ bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input, bool supported = true; // Define supported output and inputs types. 
- std::array<DataType,4> supportedTypes = + std::array<DataType,6> supportedTypes = { DataType::Float32, DataType::Float16, + DataType::QSymmS8, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1476,15 +1481,17 @@ bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input, ignore_unused(output); ignore_unused(descriptor); // Define supported output types. - std::array<DataType,6> supportedOutputTypes = + std::array<DataType,7> supportedOutputTypes = { DataType::Float32, DataType::Float16, DataType::Signed32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS8, DataType::QSymmS16 }; + return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported, "Reference reshape: input type not supported."); } @@ -1586,10 +1593,12 @@ bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input, { boost::ignore_unused(descriptor); bool supported = true; - std::array<DataType,4> supportedTypes = + std::array<DataType,6> supportedTypes = { DataType::Float32, DataType::Float16, + DataType::QSymmS8, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; |