From 3537c2ca7ebf31c1673b9ec2bb0c17b0406bbae0 Mon Sep 17 00:00:00 2001 From: surmeh01 Date: Fri, 18 May 2018 16:31:43 +0100 Subject: Release 18.05 --- src/armnn/backends/WorkloadFactory.cpp | 50 +++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 4 deletions(-) (limited to 'src/armnn/backends/WorkloadFactory.cpp') diff --git a/src/armnn/backends/WorkloadFactory.cpp b/src/armnn/backends/WorkloadFactory.cpp index 32634a6d0f..4e94d7701c 100644 --- a/src/armnn/backends/WorkloadFactory.cpp +++ b/src/armnn/backends/WorkloadFactory.cpp @@ -10,7 +10,7 @@ #include "armnn/Types.hpp" #include "armnn/LayerSupport.hpp" #include "Layer.hpp" -#include "Layers.hpp" +#include "LayersFwd.hpp" #include "CpuTensorHandle.hpp" #include @@ -60,8 +60,50 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat { auto cLayer = boost::polymorphic_downcast(&layer); const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); - result = IsConvolution2dSupported(compute, input, cLayer->GetParameters(), - cLayer->m_Weight->GetTensorInfo(), reason, reasonCapacity); + const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo(); + BOOST_ASSERT(cLayer->m_Weight.get() != nullptr); + + const TensorInfo * biasInfo = nullptr; + static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32); + static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32); + + const Convolution2dDescriptor& descriptor = cLayer->GetParameters(); + + if (descriptor.m_BiasEnabled) + { + BOOST_ASSERT(cLayer->m_Bias.get() != nullptr); + biasInfo = &(cLayer->m_Bias->GetTensorInfo()); + } + else + { + // If biases are not enabled I pass a dummy tensorinfo for the validation + switch(input.GetDataType()) + { + case DataType::Float32: + { + biasInfo = &dummyFloat32Bias; + break; + } + case DataType::QuantisedAsymm8: + { + biasInfo = &dummyQA8Bias; + break; + } + default: + { + BOOST_ASSERT_MSG(false, "Unexpected input type"); + } + } + } + + result = IsConvolution2dSupported(compute, + input, + output, + descriptor, + cLayer->m_Weight->GetTensorInfo(), + *biasInfo, + reason, + reasonCapacity); break; } case LayerType::MemCopy: @@ -211,4 +253,4 @@ bool IWorkloadFactory::IsLayerSupported(const Layer& layer, DataType dataType, s return IsLayerSupported(layer.GetComputeDevice(), layer, dataType, outReasonIfUnsupported); } -} \ No newline at end of file +} -- cgit v1.2.1