aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/NeonBackendOptimizationUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/NeonBackendOptimizationUtils.hpp')
-rw-r--r--src/backends/neon/NeonBackendOptimizationUtils.hpp33
1 files changed, 32 insertions, 1 deletions
diff --git a/src/backends/neon/NeonBackendOptimizationUtils.hpp b/src/backends/neon/NeonBackendOptimizationUtils.hpp
index 3a8bf46599..34ab41f09c 100644
--- a/src/backends/neon/NeonBackendOptimizationUtils.hpp
+++ b/src/backends/neon/NeonBackendOptimizationUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -100,6 +100,37 @@ bool BuildAddMulAddTensorInfoLists(Type* layerList[4],
ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[1], BinaryOperation::Mul));
ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[2], BinaryOperation::Add));
+ // Reports whether a tensor counts as 1D for the checks below.
+ // If CollapseLeadingUnitDimensions() returns true, the collapsed TensorInfo it filled in
+ // is tested; otherwise the tensor's own dimension count is tested.
+ // The argument is taken by const reference so that each query does not copy the TensorInfo.
+ auto is1D = [](const TensorInfo& expanded)
+ {
+ TensorInfo collapsed;
+ if (CollapseLeadingUnitDimensions(expanded, collapsed))
+ {
+ return (collapsed.GetNumDimensions() == 1);
+ }
+ else
+ {
+ return (expanded.GetNumDimensions() == 1);
+ }
+ };
+
+ // At least one of the two inputs to MUL, and at least one of the two inputs to the second ADD, must be 1D
+ // ref: clframework/src/cpu/kernels/CpuAddMulAddKernel.cpp
+ auto& mulLayer = *(PolymorphicDowncast<ElementwiseBinaryLayer*>(layerList[1]));
+ auto& add2Layer = *(PolymorphicDowncast<ElementwiseBinaryLayer*>(layerList[2]));
+
+ Layer& mulInput0 = mulLayer.GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
+ Layer& mulInput1 = mulLayer.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer();
+ Layer& add2Input0 = add2Layer.GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
+ Layer& add2Input1 = add2Layer.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer();
+ if (!is1D(mulInput0.GetOutputSlot(0).GetTensorInfo()) && !is1D(mulInput1.GetOutputSlot(0).GetTensorInfo()))
+ {
+ return false;
+ }
+ if (!is1D(add2Input0.GetOutputSlot(0).GetTensorInfo()) && !is1D(add2Input1.GetOutputSlot(0).GetTensorInfo()))
+ {
+ return false;
+ }
+
fuseReLu = (layerList[3] != nullptr);
if (fuseReLu)
{