aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-09-14 16:12:44 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-09-15 10:19:00 +0100
commit16f82f987b44b090a01807a2c79ed7fcc6bf80ea (patch)
tree5e26fccece92956c19e14d0d5c106e5d38ea4576 /src/armnn/optimizations
parent919c14ef132986aa1514b2070ce6d19b5579a6ab (diff)
downloadarmnn-16f82f987b44b090a01807a2c79ed7fcc6bf80ea.tar.gz
IVGCVSW-5305 AddBroadcastReshapeLayer as optimizer
* Remove AddBroadcastReshapeLayer from TfLiteParser
* Add AddBroadcastReshapeLayer as optimizer
* AddBroadcastReshapeLayer optimizer unit tests
* Load-scope dynamic tensor broadcasting unit tests

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I3549e85b71b41cbd4d96c0f1ece7887acbca76d1
Diffstat (limited to 'src/armnn/optimizations')
-rw-r--r--src/armnn/optimizations/AddBroadcastReshapeLayer.hpp85
-rw-r--r--src/armnn/optimizations/All.hpp1
2 files changed, 86 insertions, 0 deletions
diff --git a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
new file mode 100644
index 0000000000..6bb53d0f12
--- /dev/null
+++ b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
@@ -0,0 +1,85 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "Optimization.hpp"
+
+#include <armnn/utility/IgnoreUnused.hpp>
+#include <armnn/utility/PolymorphicDowncast.hpp>
+
+#include <algorithm>
+#include <set>
+#include <string>
+#include <vector>
+
+namespace armnn
+{
+namespace optimizations
+{
+
+// Element-wise binary layer types that support input broadcasting.
+// The AddBroadcastReshapeLayer optimization only applies to layers of these types.
+// NOTE(review): header-scope `static const` gives every including TU its own copy
+// of this set — consider an inline variable or function-local static; verify usage.
+static const std::set<armnn::LayerType> broadcastOps {
+    LayerType::Addition,
+    LayerType::Division,
+    LayerType::Maximum,
+    LayerType::Minimum,
+    LayerType::Multiplication,
+    LayerType::Subtraction
+};
+
+class AddBroadcastReshapeLayerImpl
+{
+public:
+    /// Run for every layer whose type is in broadcastOps. If the two input shapes
+    /// have different ranks, insert a ReshapeLayer in front of the lower-rank input
+    /// that prepends 1s to its shape (e.g. [3,4] against a 4D input becomes [1,1,3,4]),
+    /// so both inputs of the elementwise layer have the same number of dimensions.
+    /// @param graph Graph the layer lives in; receives the inserted ReshapeLayer.
+    /// @param layer Candidate layer; layers not in broadcastOps are left untouched.
+    void Run(Graph& graph, Layer& layer) const
+    {
+        // Member find is O(log n) on std::set (the old free std::find scanned linearly).
+        if (broadcastOps.find(layer.GetType()) != broadcastOps.end())
+        {
+            // Guard: both tensor infos must be set before their shapes can be read.
+            // (The previous code called IsTensorInfoSet() but discarded the result.)
+            if (!layer.GetInputSlot(0).GetConnectedOutputSlot()->IsTensorInfoSet() ||
+                !layer.GetInputSlot(1).GetConnectedOutputSlot()->IsTensorInfoSet())
+            {
+                return;
+            }
+
+            const TensorInfo& inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
+            const TensorInfo& inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
+
+            // Equal ranks: no reshape required.
+            if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
+            {
+                return;
+            }
+
+            // Reshape whichever input has the smaller rank; default to slot 1.
+            unsigned int reshapeSlot = 1;
+            TensorInfo reshapeInfo   = inputInfo1;
+            TensorInfo inputInfo     = inputInfo0;
+
+            if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
+            {
+                reshapeSlot = 0;
+                reshapeInfo = inputInfo0;
+                inputInfo   = inputInfo1;
+            }
+
+            const unsigned int numDimensions = inputInfo.GetNumDimensions();
+
+            // Build the broadcast shape: leading 1s, original dimensions at the tail.
+            const unsigned int rankDiff = numDimensions - reshapeInfo.GetNumDimensions();
+            std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
+            for (unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i)
+            {
+                reshapedDimensions[rankDiff + i] = reshapeInfo.GetShape()[i];
+            }
+
+            reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
+            const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
+            const ReshapeDescriptor descriptor{ reshapeInfo.GetShape() };
+            ReshapeLayer* reshapeLayer = graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot),
+                                                                            descriptor,
+                                                                            layerName.c_str());
+            reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
+        }
+    }
+
+protected:
+    // Not directly instantiable: used only through OptimizeForType.
+    AddBroadcastReshapeLayerImpl()  = default;
+    ~AddBroadcastReshapeLayerImpl() = default;
+};
+
+/// Optimization wrapper that applies AddBroadcastReshapeLayerImpl to every Layer in the graph.
+using AddBroadcastReshapeLayer = OptimizeForType<Layer, AddBroadcastReshapeLayerImpl>;
+
+} // namespace optimizations
+} // namespace armnn
diff --git a/src/armnn/optimizations/All.hpp b/src/armnn/optimizations/All.hpp
index cb484d5a59..e89c36b834 100644
--- a/src/armnn/optimizations/All.hpp
+++ b/src/armnn/optimizations/All.hpp
@@ -4,6 +4,7 @@
//
#pragma once
+#include "AddBroadcastReshapeLayer.hpp"
#include "AddDebug.hpp"
#include "ConvertConstants.hpp"
#include "ConvertFp32NetworkToBf16.hpp"