From a68241066c3e797dab70f515d2c55aaa74abf564 Mon Sep 17 00:00:00 2001
From: arovir01
Date: Tue, 28 Aug 2018 17:40:45 +0100
Subject: IVGCVSW-1324: Make biases truly optional for Conv2d and DepthwiseConvolution

!android-nn-driver:145961

Change-Id: I039ab0adc61725859514246365b5e5b5fa603eaa
---
 src/armnn/backends/WorkloadFactory.cpp | 90 ++++++----------------------------
 1 file changed, 15 insertions(+), 75 deletions(-)

(limited to 'src/armnn/backends/WorkloadFactory.cpp')

diff --git a/src/armnn/backends/WorkloadFactory.cpp b/src/armnn/backends/WorkloadFactory.cpp
index 1b3f29421a..5708dc0b0c 100644
--- a/src/armnn/backends/WorkloadFactory.cpp
+++ b/src/armnn/backends/WorkloadFactory.cpp
@@ -130,49 +130,20 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, boo
         case LayerType::Convolution2d:
         {
             auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
-            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType);
+
+            const TensorInfo input  = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
+                                                       dataType);
             const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
 
-            TensorInfo biasInfo;
-            const TensorInfo * biasInfoPtr = nullptr;
-            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
-            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
-            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
-
-            const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
+            const Convolution2dDescriptor& descriptor  = cLayer->GetParameters();
 
+            // Construct optional biases object based on the value of m_BiasEnabled
+            boost::optional<TensorInfo> biases(boost::none);
             if (descriptor.m_BiasEnabled)
             {
-                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
-                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
-                biasInfoPtr = &biasInfo;
-            }
-            else
-            {
-                // If biases are not enabled pass a dummy tensorinfo for the validation.
-                switch(input.GetDataType())
-                {
-                    case DataType::Float16:
-                    {
-                        biasInfoPtr = &dummyFloat16Bias;
-                        break;
-                    }
-                    case DataType::Float32:
-                    {
-                        biasInfoPtr = &dummyFloat32Bias;
-                        break;
-                    }
-                    case DataType::QuantisedAsymm8:
-                    {
-                        biasInfoPtr = &dummyQA8Bias;
-                        break;
-                    }
-                    default:
-                    {
-                        BOOST_ASSERT_MSG(false, "Unexpected input type");
-                    }
-                }
+                biases = boost::make_optional(
+                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType)));
             }
 
             result = IsConvolution2dSupported(compute,
@@ -180,7 +151,7 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, boo
                                               output,
                                               descriptor,
                                               OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
-                                              *biasInfoPtr,
+                                              biases,
                                               reason,
                                               reasonCapacity);
             break;
@@ -202,53 +173,22 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, boo
             const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
 
-            TensorInfo biasInfo;
-            const TensorInfo * biasInfoPtr = nullptr;
-            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
-            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
-            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
-
             const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
+
+            // Construct optional biases object based on the value of m_BiasEnabled
+            boost::optional<TensorInfo> biases(boost::none);
             if (descriptor.m_BiasEnabled)
             {
-                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
-                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
-                biasInfoPtr = &biasInfo;
-            }
-            else
-            {
-                // If biases are not enabled pass a dummy tensorinfo for the validation
-                switch(input.GetDataType())
-                {
-                    case DataType::Float16:
-                    {
-                        biasInfoPtr = &dummyFloat16Bias;
-                        break;
-                    }
-                    case DataType::Float32:
-                    {
-                        biasInfoPtr = &dummyFloat32Bias;
-                        break;
-                    }
-                    case DataType::QuantisedAsymm8:
-                    {
-                        biasInfoPtr = &dummyQA8Bias;
-                        break;
-                    }
-                    default:
-                    {
-                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
-                    }
-                }
+                biases = boost::make_optional(
+                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType)));
             }
-
             result = IsDepthwiseConvolutionSupported(compute,
                                                      input,
                                                      output,
                                                      descriptor,
                                                      OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
-                                                     *biasInfoPtr,
+                                                     biases,
                                                      reason,
                                                      reasonCapacity);
             break;
--
cgit v1.2.1
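
For context, the standalone sketch below illustrates the pattern this patch adopts: instead of fabricating a dummy bias TensorInfo keyed off the input data type, the bias is carried as a boost::optional<TensorInfo> that is boost::none when m_BiasEnabled is false, so the validation path can tell "no bias" apart from "bias present". The simplified TensorInfo struct and the IsConvolution2dSupportedStub function are stand-ins invented for this example and are not the ArmNN API; only the boost::optional usage mirrors the patch. The same idea applies to the IsDepthwiseConvolutionSupported call changed above.

// Illustrative sketch only: TensorInfo and IsConvolution2dSupportedStub are
// simplified stand-ins for this example, not the real ArmNN types/functions.
#include <boost/optional.hpp>
#include <iostream>

struct TensorInfo                 // stand-in for armnn::TensorInfo
{
    unsigned int m_NumElements;
};

// A validation hook taking an optional bias can distinguish "no bias" from
// "bias present", rather than receiving a dummy TensorInfo it cannot tell
// apart from a real one.
bool IsConvolution2dSupportedStub(const TensorInfo& weights,
                                  const boost::optional<TensorInfo>& biases)
{
    if (biases)
    {
        // Bias provided: validate it alongside the weights.
        return weights.m_NumElements > 0 && biases->m_NumElements > 0;
    }
    // Bias omitted: nothing bias-related to validate.
    return weights.m_NumElements > 0;
}

int main()
{
    TensorInfo weights = { 16 };

    // Biases disabled: pass boost::none, no dummy tensor info required.
    boost::optional<TensorInfo> biases(boost::none);
    std::cout << IsConvolution2dSupportedStub(weights, biases) << std::endl;

    // Biases enabled: wrap the real bias info, mirroring
    // biases = boost::make_optional(OverrideDataType(...)) in the patch.
    TensorInfo biasInfo = { 16 };
    biases = boost::make_optional(biasInfo);
    std::cout << IsConvolution2dSupportedStub(weights, biases) << std::endl;

    return 0;
}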