aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/WorkloadFactory.cpp')
-rw-r--r--src/armnn/backends/WorkloadFactory.cpp50
1 files changed, 46 insertions, 4 deletions
diff --git a/src/armnn/backends/WorkloadFactory.cpp b/src/armnn/backends/WorkloadFactory.cpp
index 32634a6d0f..4e94d7701c 100644
--- a/src/armnn/backends/WorkloadFactory.cpp
+++ b/src/armnn/backends/WorkloadFactory.cpp
@@ -10,7 +10,7 @@
#include "armnn/Types.hpp"
#include "armnn/LayerSupport.hpp"
#include "Layer.hpp"
-#include "Layers.hpp"
+#include "LayersFwd.hpp"
#include "CpuTensorHandle.hpp"
#include <boost/cast.hpp>
@@ -60,8 +60,50 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
{
auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsConvolution2dSupported(compute, input, cLayer->GetParameters(),
- cLayer->m_Weight->GetTensorInfo(), reason, reasonCapacity);
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
+
+ const TensorInfo * biasInfo = nullptr;
+ static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
+ static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
+
+ const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
+
+ if (descriptor.m_BiasEnabled)
+ {
+ BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
+ biasInfo = &(cLayer->m_Bias->GetTensorInfo());
+ }
+ else
+ {
+ // If biases are not enabled I pass a dummy tensorinfo for the validation
+ switch(input.GetDataType())
+ {
+ case DataType::Float32:
+ {
+ biasInfo = &dummyFloat32Bias;
+ break;
+ }
+ case DataType::QuantisedAsymm8:
+ {
+ biasInfo = &dummyQA8Bias;
+ break;
+ }
+ default:
+ {
+ BOOST_ASSERT_MSG(false, "Unexpected input type");
+ }
+ }
+ }
+
+ result = IsConvolution2dSupported(compute,
+ input,
+ output,
+ descriptor,
+ cLayer->m_Weight->GetTensorInfo(),
+ *biasInfo,
+ reason,
+ reasonCapacity);
break;
}
case LayerType::MemCopy:
@@ -211,4 +253,4 @@ bool IWorkloadFactory::IsLayerSupported(const Layer& layer, DataType dataType, s
return IsLayerSupported(layer.GetComputeDevice(), layer, dataType, outReasonIfUnsupported);
}
-} \ No newline at end of file
+}