aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon')
-rw-r--r--src/backends/neon/NeonLayerSupport.cpp22
-rw-r--r--src/backends/neon/workloads/NeonSplitterWorkload.cpp3
2 files changed, 19 insertions, 6 deletions
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index ee8f6f28f0..b6db52342e 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -19,6 +19,7 @@
#if defined(ARMCOMPUTENEON_ENABLED)
#include <aclCommon/ArmComputeUtils.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <backendsCommon/WorkloadUtils.hpp>
#include "workloads/NeonAbsWorkload.hpp"
#include "workloads/NeonAdditionWorkload.hpp"
#include "workloads/NeonActivationWorkload.hpp"
@@ -101,11 +102,22 @@ const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> typ
{
return info;
}
- return TensorInfo(info.GetShape(),
- type.value(),
- info.GetQuantizationScale(),
- info.GetQuantizationOffset(),
- info.IsConstant());
+ if (info.HasMultipleQuantizationScales())
+ {
+ return TensorInfo(info.GetShape(),
+ type.value(),
+ info.GetQuantizationScales(),
+ info.GetQuantizationDim().value(),
+ info.IsConstant());
+ }
+ else
+ {
+ return TensorInfo(info.GetShape(),
+ type.value(),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset(),
+ info.IsConstant());
+ }
}
template< typename ... Args>
diff --git a/src/backends/neon/workloads/NeonSplitterWorkload.cpp b/src/backends/neon/workloads/NeonSplitterWorkload.cpp
index c307822325..bfde497640 100644
--- a/src/backends/neon/workloads/NeonSplitterWorkload.cpp
+++ b/src/backends/neon/workloads/NeonSplitterWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2019-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -12,6 +12,7 @@
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <armnn/backends/TensorHandle.hpp>
#include <neon/NeonTensorHandle.hpp>
+#include <backendsCommon/WorkloadUtils.hpp>
#include "NeonWorkloadUtils.hpp"