diff options
Diffstat (limited to 'src/backends/neon')
-rw-r--r-- | src/backends/neon/NeonLayerSupport.cpp | 22 | ||||
-rw-r--r-- | src/backends/neon/workloads/NeonSplitterWorkload.cpp | 3 |
2 files changed, 19 insertions, 6 deletions
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp index ee8f6f28f0..b6db52342e 100644 --- a/src/backends/neon/NeonLayerSupport.cpp +++ b/src/backends/neon/NeonLayerSupport.cpp @@ -19,6 +19,7 @@ #if defined(ARMCOMPUTENEON_ENABLED) #include <aclCommon/ArmComputeUtils.hpp> #include <aclCommon/ArmComputeTensorUtils.hpp> +#include <backendsCommon/WorkloadUtils.hpp> #include "workloads/NeonAbsWorkload.hpp" #include "workloads/NeonAdditionWorkload.hpp" #include "workloads/NeonActivationWorkload.hpp" @@ -101,11 +102,22 @@ const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> typ { return info; } - return TensorInfo(info.GetShape(), - type.value(), - info.GetQuantizationScale(), - info.GetQuantizationOffset(), - info.IsConstant()); + if (info.HasMultipleQuantizationScales()) + { + return TensorInfo(info.GetShape(), + type.value(), + info.GetQuantizationScales(), + info.GetQuantizationDim().value(), + info.IsConstant()); + } + else + { + return TensorInfo(info.GetShape(), + type.value(), + info.GetQuantizationScale(), + info.GetQuantizationOffset(), + info.IsConstant()); + } } template< typename ... Args> diff --git a/src/backends/neon/workloads/NeonSplitterWorkload.cpp b/src/backends/neon/workloads/NeonSplitterWorkload.cpp index c307822325..bfde497640 100644 --- a/src/backends/neon/workloads/NeonSplitterWorkload.cpp +++ b/src/backends/neon/workloads/NeonSplitterWorkload.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2019-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -12,6 +12,7 @@ #include <armnn/utility/PolymorphicDowncast.hpp> #include <armnn/backends/TensorHandle.hpp> #include <neon/NeonTensorHandle.hpp> +#include <backendsCommon/WorkloadUtils.hpp> #include "NeonWorkloadUtils.hpp" |