Diffstat (limited to 'src/armnn/optimizations/FuseBatchNorm.hpp')
-rw-r--r-- src/armnn/optimizations/FuseBatchNorm.hpp | 68
1 file changed, 24 insertions(+), 44 deletions(-)
diff --git a/src/armnn/optimizations/FuseBatchNorm.hpp b/src/armnn/optimizations/FuseBatchNorm.hpp
index 6a50fc4a0c..bca0c7d00a 100644
--- a/src/armnn/optimizations/FuseBatchNorm.hpp
+++ b/src/armnn/optimizations/FuseBatchNorm.hpp
@@ -14,8 +14,8 @@ namespace armnn
namespace optimizations
{
-template <typename ConvLayer, armnn::DataType ArmnnType,
- typename T = armnn::ResolveType<ArmnnType>>
+template<typename ConvLayer, armnn::DataType ArmnnType,
+ typename T = armnn::ResolveType<ArmnnType>>
class FuseBatchNorm
{
public:
@@ -26,7 +26,7 @@ public:
/// combined with the parameters of the child BatchNorm layer.
void Run(Graph& graph, InputSlot& connection) const
{
- Layer& base  = connection.GetConnectedOutputSlot()->GetOwningLayer();
+ Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
Layer& child = connection.GetOwningLayer();
bool depthwise = (base.GetType() == LayerType::DepthwiseConvolution2d);
@@ -37,7 +37,7 @@ public:
if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType)
{
OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
- auto convLayer      = PolymorphicDowncast<ConvLayer*>(&base);
+ auto convLayer = PolymorphicDowncast<ConvLayer*>(&base);
auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child);
// Read convolution and batch norm parameters
@@ -50,25 +50,16 @@ public:
ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true));
ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true));
- auto convDescriptor  = convLayer->GetParameters();
+ auto convDescriptor = convLayer->GetParameters();
ConstTensor weightsTensor;
- if (convLayer->GetNumInputSlots() > 1)
- {
- ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[1].GetConnection() != nullptr,
- "FuseBatchNorm: Weight data should not be null.");
- InputSlot & oldSlotWeights = const_cast<InputSlot&>(convLayer->GetInputSlots()[1]);
- OutputSlot & constantSlotWeights = const_cast<OutputSlot&>(*oldSlotWeights.GetConnectedOutputSlot());
- ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>(
- &constantSlotWeights.GetOwningLayer());
- weightsTensor = ConstTensor(weightLayer->m_LayerOutput->GetTensorInfo(),
- weightLayer->m_LayerOutput->Map(true));
- }
- else
- {
- ARMNN_ASSERT_MSG(convLayer->m_Weight != nullptr,
- "FuseBatchNorm: Bias data should not be null if bias is enabled.");
- weightsTensor = ConstTensor(convLayer->m_Weight->GetTensorInfo(), convLayer->m_Weight->Map(true));
- }
+ ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[1].GetConnection() != nullptr,
+ "FuseBatchNorm: Weight data should not be null.");
+
+ ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>(
+ &base.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer());
+
+ weightsTensor = ConstTensor(weightLayer->m_LayerOutput->GetTensorInfo(),
+ weightLayer->m_LayerOutput->Map(true));
armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
auto weightsShape = weightsTensor.GetInfo().GetShape();
@@ -76,9 +67,9 @@ public:
const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
const unsigned int outputChannels = depthwise ? weightsShape[3] : weightsShape[0];
const unsigned int weightsHeight = depthwise ? weightsShape[1] :
-                                   weightsShape[dataLayout.GetHeightIndex()];
+                                  weightsShape[dataLayout.GetHeightIndex()];
const unsigned int weightsWidth = depthwise ? weightsShape[2] :
-                                  weightsShape[dataLayout.GetWidthIndex()];
+                                 weightsShape[dataLayout.GetWidthIndex()];
const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea());
const auto* betaBuffer = static_cast<const T*>(betaTensor.GetMemoryArea());
@@ -99,7 +90,7 @@ public:
{
for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
{
- T mult = gammaVector[cOut] / static_cast<T>(sqrtf (varianceVector[cOut] + epsilon));
+ T mult = gammaVector[cOut] / static_cast<T>(sqrtf(varianceVector[cOut] + epsilon));
for (unsigned int h = 0; h < weightsHeight; ++h)
{
@@ -140,23 +131,14 @@ public:
if (biasWasEnabledBeforeOpt)
{
ConstTensor biasTensor;
- if (convLayer->GetNumInputSlots() > 1)
- {
- ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[2].GetConnection() != nullptr,
- "FuseBatchNorm: Bias data should not be null if bias is enabled.");
- InputSlot & oldSlotBias = const_cast<InputSlot&>(convLayer->GetInputSlots()[2]);
- OutputSlot & constantSlotBias = const_cast<OutputSlot&>(*oldSlotBias.GetConnectedOutputSlot());
- ConstantLayer* biasLayer = PolymorphicDowncast<ConstantLayer*>(
- &constantSlotBias.GetOwningLayer());
- biasTensor = ConstTensor(biasLayer->m_LayerOutput->GetTensorInfo(),
- biasLayer->m_LayerOutput->Map(true));
- }
- else
- {
- ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr,
- "FuseBatchNorm: Bias data should not be null if bias is enabled.");
- biasTensor = ConstTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true));
- }
+ ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[2].GetConnection() != nullptr,
+ "FuseBatchNorm: Bias data should not be null if bias is enabled.");
+
+ ConstantLayer* biasLayer = PolymorphicDowncast<ConstantLayer*>(
+ &base.GetInputSlot(2).GetConnectedOutputSlot()->GetOwningLayer());
+
+ biasTensor = ConstTensor(biasLayer->m_LayerOutput->GetTensorInfo(),
+ biasLayer->m_LayerOutput->Map(true));
const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea());
std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());
@@ -192,8 +174,6 @@ public:
// This optimization will always have 3 input slots on the Conv2d base layer
if (newConv2dLayer.GetNumInputSlots() > 1)
{
- ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>(
- &base.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer());
// Remove old connection and connect to new layer2d
weightLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(1));
weightLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(1));
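
For reference, the arithmetic this optimization applies is unchanged by the refactor above: each output channel's weights are scaled by gamma / sqrt(variance + epsilon), and mean, beta, and any existing bias are folded into a new bias term. Below is a minimal standalone sketch of that per-channel computation for the non-depthwise case, assuming a row-major [outputChannels x weightsPerChannel] weight buffer; FuseBatchNormParams and its signature are illustrative only and not part of the ArmNN API.

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative helper: folds batch-norm parameters into convolution
// weights and bias, channel by channel. If the convolution had no bias,
// pass a zero-filled bias vector of size outputChannels.
void FuseBatchNormParams(std::vector<float>& weights,        // [outCh * weightsPerChannel]
                         std::vector<float>& bias,           // [outCh]
                         const std::vector<float>& beta,
                         const std::vector<float>& gamma,
                         const std::vector<float>& mean,
                         const std::vector<float>& variance,
                         float epsilon,
                         std::size_t outputChannels)
{
    const std::size_t weightsPerChannel = weights.size() / outputChannels;
    for (std::size_t cOut = 0; cOut < outputChannels; ++cOut)
    {
        // Same per-channel multiplier as in the Run() loop above:
        // mult = gamma / sqrt(variance + epsilon)
        const float mult = gamma[cOut] / std::sqrt(variance[cOut] + epsilon);
        for (std::size_t i = 0; i < weightsPerChannel; ++i)
        {
            weights[cOut * weightsPerChannel + i] *= mult;
        }
        // Fold mean, beta, and the old bias (if any) into the new bias:
        // BN(Wx + b) = (mult * W) x + (b - mean) * mult + beta
        bias[cOut] = (bias[cOut] - mean[cOut]) * mult + beta[cOut];
    }
}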