Diffstat (limited to 'src/backends')
 src/backends/neon/NeonLayerSupport.cpp            | 5 -----
 src/backends/neon/test/NeonLayerTests.cpp         | 3 +++
 src/backends/neon/workloads/NeonWorkloadUtils.hpp | 3 +++
 3 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index ed0f41a888..4474b12d37 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -274,11 +274,6 @@ bool NeonLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
const Optional<TensorInfo>& biases,
Optional<std::string&> reasonIfUnsupported) const
{
- if (weights.HasPerAxisQuantization())
- {
- return false;
- }
-
// Multiplier > 1.0f currently not supported in ACL
if ((input.GetQuantizationScale() * weights.GetQuantizationScale()) / output.GetQuantizationScale() > 1.0f)
{
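
Note on the guard that remains in place above: it compares the effective requantization multiplier against 1.0, since multipliers greater than 1.0f are currently not supported in ACL. A minimal stand-alone sketch of that computation, using hypothetical scale values rather than ArmNN types:

#include <iostream>

int main()
{
    // Hypothetical quantization scales for the input, weight and output tensors.
    const float inputScale  = 0.5f;
    const float weightScale = 0.25f;
    const float outputScale = 0.1f;

    // Effective multiplier applied when requantizing the int32 accumulator
    // back into the output's quantized range.
    const float multiplier = (inputScale * weightScale) / outputScale;

    // Mirrors the layer-support check above: multipliers above 1.0f are rejected.
    std::cout << "multiplier = " << multiplier
              << (multiplier > 1.0f ? " -> unsupported" : " -> supported") << "\n";
    return 0;
}
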
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index 26c55365cf..d74a4c6ebe 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -110,6 +110,9 @@ ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2dMult2,
false,
armnn::DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(Convolution2dPerAxisQuantTestNchw, Convolution2dPerAxisQuantTest, DataLayout::NCHW);
+ARMNN_AUTO_TEST_CASE(Convolution2dPerAxisQuantTestNhwc, Convolution2dPerAxisQuantTest, DataLayout::NHWC);
+
// DepthToSpace
ARMNN_AUTO_TEST_CASE(DepthToSpaceNchwFloat32_1, DepthToSpaceTest1<DataType::Float32>, DataLayout::NCHW);
ARMNN_AUTO_TEST_CASE(DepthToSpaceNchwFloat32_2, DepthToSpaceTest2<DataType::Float32>, DataLayout::NCHW);
diff --git a/src/backends/neon/workloads/NeonWorkloadUtils.hpp b/src/backends/neon/workloads/NeonWorkloadUtils.hpp
index f63946ec07..e9edc8901e 100644
--- a/src/backends/neon/workloads/NeonWorkloadUtils.hpp
+++ b/src/backends/neon/workloads/NeonWorkloadUtils.hpp
@@ -46,6 +46,9 @@ inline void InitializeArmComputeTensorData(arm_compute::Tensor& tensor,
case DataType::QuantisedAsymm8:
CopyArmComputeTensorData(tensor, handle->GetConstTensor<uint8_t>());
break;
+ case DataType::QuantizedSymm8PerAxis:
+ CopyArmComputeTensorData(tensor, handle->GetConstTensor<int8_t>());
+ break;
case DataType::Signed32:
CopyArmComputeTensorData(tensor, handle->GetConstTensor<int32_t>());
break;
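
The new QuantizedSymm8PerAxis case above copies the signed 8-bit weight data into the ACL tensor. As a rough illustration of what per-axis (per-output-channel) symmetric quantization of convolution weights looks like, here is a small self-contained sketch. It is not ArmNN or ACL code; the flat [outputChannels][elementsPerChannel] layout and the helper name QuantizePerAxisSymm8 are assumptions made for the example.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize float weights to int8 with one symmetric scale per output channel
// (quantization axis 0). Assumed layout: outputChannels blocks of
// elementsPerChannel contiguous values.
std::vector<int8_t> QuantizePerAxisSymm8(const std::vector<float>& weights,
                                         std::size_t outputChannels,
                                         std::vector<float>& scalesOut)
{
    const std::size_t elementsPerChannel = weights.size() / outputChannels;
    std::vector<int8_t> quantized(weights.size());
    scalesOut.assign(outputChannels, 1.0f);

    for (std::size_t oc = 0; oc < outputChannels; ++oc)
    {
        const float* channel = &weights[oc * elementsPerChannel];

        // Symmetric scheme: pick the scale so the largest magnitude maps to 127.
        float maxAbs = 0.0f;
        for (std::size_t i = 0; i < elementsPerChannel; ++i)
        {
            maxAbs = std::max(maxAbs, std::fabs(channel[i]));
        }
        const float scale = (maxAbs > 0.0f) ? maxAbs / 127.0f : 1.0f;
        scalesOut[oc] = scale;

        for (std::size_t i = 0; i < elementsPerChannel; ++i)
        {
            const float q = std::round(channel[i] / scale);
            quantized[oc * elementsPerChannel + i] =
                static_cast<int8_t>(std::min(127.0f, std::max(-127.0f, q)));
        }
    }
    return quantized;
}

int main()
{
    // Two output channels with four (hypothetical) weights each.
    const std::vector<float> weights = { 0.1f, -0.2f,  0.3f, -0.4f,
                                         1.0f, -2.0f,  3.0f, -4.0f };
    std::vector<float> scales;
    const std::vector<int8_t> q = QuantizePerAxisSymm8(weights, 2, scales);
    // scales[0] is roughly 0.4/127 and scales[1] roughly 4.0/127;
    // q holds the quantized int8 weights for both channels.
    return 0;
}

Each output channel carries its own scale, so the weight tensor can no longer be described by a single quantization scale; this is why the first hunk of this change removes the blanket rejection of per-axis quantized weights from the layer-support check.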