aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2021-03-12 15:05:49 +0000
committerJim Flynn <jim.flynn@arm.com>2021-03-16 09:50:58 +0000
commit9cd4ce1e6f76c070ac20ebcf4c67fc7ba8ba358a (patch)
treef5be908ea8c47ce36dd16758d08245e63c5e7d50 /src/armnn/optimizations
parentc2d9559287bd9df0bb361d4d977c170e80dd4475 (diff)
downloadarmnn-9cd4ce1e6f76c070ac20ebcf4c67fc7ba8ba358a.tar.gz
IVGCVSW-5754 Change the behaviour of the AddBroadcastReshapeLayer Optimisation when the input is a const tensor
Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I8b1357bdefc45880d064d7e448af364ac8644c0d
Diffstat (limited to 'src/armnn/optimizations')
-rw-r--r--src/armnn/optimizations/AddBroadcastReshapeLayer.hpp15
1 files changed, 15 insertions, 0 deletions
diff --git a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
index 6bb53d0f12..26661cfcde 100644
--- a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
+++ b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
@@ -8,6 +8,7 @@
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
+#include <backendsCommon/CpuTensorHandle.hpp>
namespace armnn
{
@@ -65,6 +66,20 @@ public:
std::copy_backward (reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
+
+ // If the parent layer is a Constant layer we just change the tensor info rather than adding a reshape layer
+ Layer& parentLayer = layer.GetInputSlot(reshapeSlot).GetConnectedOutputSlot()->GetOwningLayer();
+ if (parentLayer.GetType() == armnn::LayerType::Constant)
+ {
+ ConstantLayer& constantLayer = static_cast<ConstantLayer&>(parentLayer);
+
+ constantLayer.m_LayerOutput = std::make_unique<ScopedCpuTensorHandle>(
+ ConstTensor(reshapeInfo,constantLayer.m_LayerOutput.get()->GetTensor<void>()));
+ constantLayer.GetOutputSlot().SetTensorInfo(reshapeInfo);
+
+ return;
+ }
+
const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
const ReshapeDescriptor descriptor{reshapeInfo.GetShape()};
ReshapeLayer *reshapeLayer = graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot),