diff options
Diffstat (limited to 'src/backends/neon/NeonBackendOptimizationUtils.hpp')
-rw-r--r-- | src/backends/neon/NeonBackendOptimizationUtils.hpp | 215 |
1 files changed, 215 insertions, 0 deletions
diff --git a/src/backends/neon/NeonBackendOptimizationUtils.hpp b/src/backends/neon/NeonBackendOptimizationUtils.hpp new file mode 100644 index 0000000000..3a8bf46599 --- /dev/null +++ b/src/backends/neon/NeonBackendOptimizationUtils.hpp @@ -0,0 +1,215 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <aclCommon/ArmComputeSubgraphUtils.hpp> + +namespace armnn +{ + +// Changes shapes of the form [1, 1, ..., W] to [ W ] +inline bool CollapseLeadingUnitDimensions(const TensorInfo& in, TensorInfo& out) +{ + unsigned int numDimensions = in.GetNumDimensions(); + for (unsigned int i = 0; i < (numDimensions-1); ++i) + { + if (in.GetShape()[i] != 1) + { + return false; + } + } + + unsigned int w = in.GetShape()[numDimensions-1]; + out = in; + out.SetShape({w}); + + return true; +} + +// +// Build slot and tensor info lists for Add/Mul/Add replacement +// +template<typename SlotListType> +void BuildAddMulAddSlotLists(bool handleReLu, + bool multipleOutputs, + std::vector<SlotListType>& inputLayersSlotLists, + std::vector<SlotListType>& outputLayersSlotLists) +{ + // Build input slot list + inputLayersSlotLists.push_back({0, 1}); // Add + inputLayersSlotLists.push_back({1}); // Mul + inputLayersSlotLists.push_back({1}); // Add + if (handleReLu) + { + inputLayersSlotLists.push_back({}); // Relu + } + + // Build output slot list + if (multipleOutputs) + { + outputLayersSlotLists.push_back({0}); // Add + } + else + { + outputLayersSlotLists.push_back({}); // Add + } + outputLayersSlotLists.push_back({}); // Mul + if (handleReLu) + { + outputLayersSlotLists.push_back({}); // Add + outputLayersSlotLists.push_back({0}); // Relu + } + else + { + outputLayersSlotLists.push_back({0}); // Add + } +} + +inline void GetFusedName(Layer *layerList[4], std::string& fusedName) +{ + // Build the fused name string + fusedName = "fused"; + for (unsigned int layerIdx = 0; layerIdx< 4; ++layerIdx) + { + if (! layerList[layerIdx]) + { + break; + } + fusedName += "-"; + fusedName += layerList[layerIdx]->GetNameStr(); + } +} + +template<typename Type> +bool BuildAddMulAddTensorInfoLists(Type* layerList[4], + unsigned int& numInputs, + unsigned int& numOutputs, + std::vector<TensorInfo>& inputInfos, + std::vector<TensorInfo>& outputInfos, + const ActivationDescriptor*& activationDescriptor, + bool& fuseReLu) +{ + ARMNN_THROW_INVALIDARG_IF_FALSE(layerList[0]); + ARMNN_THROW_INVALIDARG_IF_FALSE(layerList[1]); + ARMNN_THROW_INVALIDARG_IF_FALSE(layerList[2]); + + ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[0], BinaryOperation::Add)); + ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[1], BinaryOperation::Mul)); + ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[2], BinaryOperation::Add)); + + fuseReLu = (layerList[3] != nullptr); + if (fuseReLu) + { + activationDescriptor = &PolymorphicDowncast<ActivationLayer *>(layerList[3])->GetParameters(); + ARMNN_THROW_INVALIDARG_IF_FALSE((activationDescriptor->m_Function == ActivationFunction::ReLu) || + (activationDescriptor->m_Function == ActivationFunction::BoundedReLu)); + } + + numInputs = 0; + numOutputs = 0; + + // Ensure that there are 6 input slots in the add/mul/add layers + // we are going to replace + unsigned int layerIdx = 0; + unsigned int inputSlotCount = 0; + for (layerIdx = 0; layerIdx < 3; ++layerIdx) + { + for (unsigned int slotIdx = 0; slotIdx < layerList[layerIdx]->GetNumInputSlots(); ++slotIdx) + { + InputSlot* inputSlot = &layerList[layerIdx]->GetInputSlot(slotIdx); + OutputSlot* outputSlot = inputSlot->GetConnectedOutputSlot(); + if (outputSlot) + { + if (layerIdx == 0) + { + // Always count the input connections of the first add + inputInfos.push_back(inputSlot->GetTensorInfo()); + numInputs++; + } + else + { + // For subsequent layers, we skip connections to the previous layers in the counting + if (&outputSlot->GetOwningLayer() != layerList[layerIdx-1]) + { + TensorInfo inputSlotInfo = inputSlot->GetTensorInfo(); + if (numInputs == 2 || numInputs == 3) + { + // Workaround the broadcast optimization to collapse shapes such as + // [1, 1, 1, 2] to [2] as required by backend + if (CollapseLeadingUnitDimensions(inputSlot->GetTensorInfo(), inputSlotInfo)) + { + OutputSlot* previousLayerSlot = inputSlot->GetConnectedOutputSlot(); + if (previousLayerSlot) + { + if (previousLayerSlot->GetOwningLayer().GetType() == LayerType::Constant) + { + // First update the TensorInfo in the constant owning layer + previousLayerSlot->SetTensorInfo(inputSlotInfo); + // Then update the TensorInfo in the workload for the owning layer + ConstantLayer* layer = PolymorphicDowncast<ConstantLayer*>( + &previousLayerSlot->GetOwningLayer()); + layer->m_LayerOutput + = std::make_unique<ScopedTensorHandle>( + ConstTensor(inputSlotInfo, + layer->m_LayerOutput.get()->GetConstTensor<void>())); + } + } + } + } + inputInfos.push_back(inputSlotInfo); + numInputs++; + } + } + inputSlotCount++; + } + } + } + + // Check the input counts + bool validInputCount = (inputSlotCount == 6) && (inputInfos.size() == 4); + if (! validInputCount) + { + return false; + } + + const unsigned int maxIdx = (fuseReLu) ? 4 : 3; + for (layerIdx = 0; layerIdx < maxIdx; ++layerIdx) + { + for (unsigned int slotIdx = 0; slotIdx < layerList[layerIdx]->GetNumOutputSlots(); ++slotIdx) + { + OutputSlot* outputSlot = &layerList[layerIdx]->GetOutputSlot(slotIdx); + + for (unsigned int connectionIdx = 0; connectionIdx < outputSlot->GetNumConnections(); ++connectionIdx) + { + InputSlot* inputSlot = outputSlot->GetConnection(connectionIdx); + if (layerIdx < (maxIdx-1)) + { + if (&inputSlot->GetOwningLayer() != layerList[layerIdx+1]) + { + outputInfos.push_back(outputSlot->GetTensorInfo()); + numOutputs++; + } + } + else if (layerList[layerIdx] != nullptr) + { + outputInfos.push_back(outputSlot->GetTensorInfo()); + numOutputs++; + } + } + } + } + + // Check the output count + bool validOutputCount = (outputInfos.size() > 0); + if (! validOutputCount) + { + return false; + } + + return true; +} + +} |