Diffstat (limited to 'src/armnn/optimizations/AddBroadcastReshapeLayer.hpp')
-rw-r--r--  src/armnn/optimizations/AddBroadcastReshapeLayer.hpp | 54
1 file changed, 25 insertions, 29 deletions
diff --git a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
index 0a5ad9d152..aa00b9913c 100644
--- a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
+++ b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
@@ -15,14 +15,9 @@ namespace armnn
namespace optimizations
{
-static const std::set<armnn::LayerType> broadcastOps {
- LayerType::Addition,
- LayerType::Division,
- LayerType::Maximum,
- LayerType::Minimum,
- LayerType::Multiplication,
- LayerType::Subtraction
-};
+static const std::set<armnn::LayerType> broadcastOps{ LayerType::Addition, LayerType::Division,
+ LayerType::Maximum, LayerType::Minimum,
+ LayerType::Multiplication, LayerType::Subtraction };
class AddBroadcastReshapeLayerImpl
{
@@ -35,8 +30,8 @@ public:
layer.GetInputSlot(0).GetConnectedOutputSlot()->IsTensorInfoSet();
layer.GetInputSlot(1).GetConnectedOutputSlot()->IsTensorInfoSet();
- const TensorInfo &inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
- const TensorInfo &inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
+ const TensorInfo& inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
+ const TensorInfo& inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
{
@@ -44,14 +39,14 @@ public:
}
unsigned int reshapeSlot = 1;
- TensorInfo reshapeInfo = inputInfo1;
- TensorInfo inputInfo = inputInfo0;
+ TensorInfo reshapeInfo = inputInfo1;
+ TensorInfo inputInfo = inputInfo0;
if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
{
reshapeSlot = 0;
reshapeInfo = inputInfo0;
- inputInfo = inputInfo1;
+ inputInfo = inputInfo1;
}
uint32_t numDimensions = inputInfo.GetNumDimensions();
@@ -63,38 +58,39 @@ public:
}
std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
- std::copy_backward (reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
+ std::copy_backward(reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
- // If the parent layer is a Constant layer we just change the tensor info rather than adding a reshape layer
+ // If the parent layer is a Constant layer and it is only used once we can short circuit by just
+ // changing the tensor info rather than adding a reshape layer.
Layer& parentLayer = layer.GetInputSlot(reshapeSlot).GetConnectedOutputSlot()->GetOwningLayer();
- if (parentLayer.GetType() == armnn::LayerType::Constant)
+ if ((parentLayer.GetType() == armnn::LayerType::Constant) &&
+ (parentLayer.GetOutputSlot(0).GetNumConnections() == 1))
{
ConstantLayer& constantLayer = static_cast<ConstantLayer&>(parentLayer);
constantLayer.m_LayerOutput = std::make_unique<ScopedCpuTensorHandle>(
- ConstTensor(reshapeInfo,constantLayer.m_LayerOutput.get()->GetConstTensor<void>()));
+ ConstTensor(reshapeInfo, constantLayer.m_LayerOutput.get()->GetConstTensor<void>()));
constantLayer.GetOutputSlot().SetTensorInfo(reshapeInfo);
-
- return;
}
-
- const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
- const ReshapeDescriptor descriptor{reshapeInfo.GetShape()};
- ReshapeLayer *reshapeLayer = graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot),
- descriptor,
- layerName.c_str());
- reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
+ else
+ {
+ const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
+ const ReshapeDescriptor descriptor{ reshapeInfo.GetShape() };
+ ReshapeLayer* reshapeLayer =
+ graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot), descriptor, layerName.c_str());
+ reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
+ }
}
}
protected:
- AddBroadcastReshapeLayerImpl() = default;
+ AddBroadcastReshapeLayerImpl() = default;
~AddBroadcastReshapeLayerImpl() = default;
};
using AddBroadcastReshapeLayer = OptimizeForType<Layer, AddBroadcastReshapeLayerImpl>;
-} // namespace optimizations
-} // namespace armnn
+} // namespace optimizations
+} // namespace armnn
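
Note (not part of the patch): the heart of this optimization is padding the lower-rank input's shape with leading 1s until it matches the higher-rank input, which is what the std::copy_backward call into a vector pre-filled with 1s does above. Below is a minimal standalone sketch of just that shape-padding step; the helper name PadShapeWithLeadingOnes is hypothetical and plain std::vector is used in place of armnn::TensorShape.

#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the padding logic in the patch: copy the
// smaller shape into the tail of a vector of 1s sized to the larger rank,
// e.g. a {3, 4} input broadcast against a 4-D tensor becomes {1, 1, 3, 4}.
std::vector<unsigned int> PadShapeWithLeadingOnes(const std::vector<unsigned int>& smallShape,
                                                  unsigned int targetRank)
{
    assert(smallShape.size() <= targetRank);
    std::vector<unsigned int> reshapedDimensions(targetRank, 1);
    std::copy_backward(smallShape.begin(), smallShape.end(), reshapedDimensions.end());
    return reshapedDimensions;
}

int main()
{
    // Prints: 1 1 3 4
    for (unsigned int d : PadShapeWithLeadingOnes({ 3, 4 }, 4))
    {
        std::cout << d << ' ';
    }
    std::cout << '\n';
    return 0;
}

In the patch itself, the resulting shape is either written directly into a Constant parent layer's tensor info (only safe when that constant has a single connection, hence the new GetNumConnections() == 1 check) or applied via an inserted ReshapeLayer for all other producers.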