Diffstat (limited to 'src/armnn/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp')
-rw-r--r-- | src/armnn/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp | 69 lines changed
1 file changed, 50 insertions(+), 19 deletions(-)
diff --git a/src/armnn/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp b/src/armnn/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp
index 423f02bcb0..e76afb6cf7 100644
--- a/src/armnn/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp
+++ b/src/armnn/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp
@@ -9,6 +9,9 @@
 #include "NeonConvolution2dBaseWorkload.hpp"
 
+#include "armnn/Types.hpp"
+#include "Half.hpp"
+
 namespace armnn
 {
@@ -41,28 +44,28 @@ arm_compute::Status NeonConvolution2dWorkloadValidate(const TensorInfo& input,
                                         layerInfo);
 }
 
-template<armnn::DataType dataType>
-NeonConvolution2dBaseWorkload<dataType>::NeonConvolution2dBaseWorkload(const Convolution2dQueueDescriptor& descriptor,
-    const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
-    : TypedWorkload<Convolution2dQueueDescriptor, dataType>(descriptor, info)
+template<armnn::DataType... dataTypes>
+NeonConvolution2dBaseWorkload<dataTypes...>::NeonConvolution2dBaseWorkload(
+    const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info,
+    std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
+    : TypedWorkload<Convolution2dQueueDescriptor, dataTypes...>(descriptor, info)
 {
     using arm_compute::NEDirectConvolutionLayer;
-    using namespace armcomputetensorutils;
 
     ValidateData();
 
-    // todo: check tensor shapes match
+    // todo: check tensor shapes match.
 
     arm_compute::ITensor& input = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
     arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
 
-    BuildArmComputeTensor(m_KernelTensor, m_Data.m_Weight->GetTensorInfo());
+    m_KernelTensor = std::make_unique<arm_compute::Tensor>();
+    BuildArmComputeTensor(*m_KernelTensor, m_Data.m_Weight->GetTensorInfo());
 
-    arm_compute::Tensor* optionalBiasTensor = nullptr;
     if (m_Data.m_Parameters.m_BiasEnabled)
     {
-        BuildArmComputeTensor(m_BiasTensor, m_Data.m_Bias->GetTensorInfo());
-        optionalBiasTensor = &m_BiasTensor;
+        m_BiasTensor = std::make_unique<arm_compute::Tensor>();
+        BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo());
     }
 
     arm_compute::PadStrideInfo padStrideInfo(m_Data.m_Parameters.m_StrideX,
@@ -81,8 +84,8 @@ NeonConvolution2dBaseWorkload<dataType>::NeonConvolution2dBaseWorkload(const Con
     {
         auto directConvolutionLayer = std::make_unique<arm_compute::NEDirectConvolutionLayer>(memoryManager);
         directConvolutionLayer->configure(&input,
-                                          &m_KernelTensor,
-                                          optionalBiasTensor,
+                                          m_KernelTensor.get(),
+                                          m_BiasTensor.get(),
                                           &output,
                                           padStrideInfo);
         m_ConvolutionLayer.reset(directConvolutionLayer.release());
@@ -91,22 +94,50 @@ NeonConvolution2dBaseWorkload<dataType>::NeonConvolution2dBaseWorkload(const Con
     {
         auto convolutionLayer = std::make_unique<arm_compute::NEConvolutionLayer>(memoryManager);
         convolutionLayer->configure(&input,
-                                    &m_KernelTensor,
-                                    optionalBiasTensor,
+                                    m_KernelTensor.get(),
+                                    m_BiasTensor.get(),
                                     &output,
                                     padStrideInfo);
         m_ConvolutionLayer.reset(convolutionLayer.release());
     }
 
     BOOST_ASSERT(m_ConvolutionLayer);
 
-    using Type = ResolveType<dataType>;
+    armnn::DataType dataType = m_Data.m_Weight->GetTensorInfo().GetDataType();
+
+    switch (dataType)
+    {
+        case DataType::Float16:
+        {
+            InitialiseArmComputeTensorData(*m_KernelTensor, m_Data.m_Weight->template GetConstTensor<Half>());
+            break;
+        }
+        case DataType::Float32:
+        {
+            InitialiseArmComputeTensorData(*m_KernelTensor, m_Data.m_Weight->template GetConstTensor<float>());
+            break;
+        }
+        case DataType::QuantisedAsymm8:
+        {
+            InitialiseArmComputeTensorData(*m_KernelTensor, m_Data.m_Weight->template GetConstTensor<uint8_t>());
+            break;
+        }
+        default:
+        {
+            BOOST_ASSERT_MSG(false, "Unknown DataType.");
+        }
+    }
+}
 
-    InitialiseArmComputeTensorData(m_KernelTensor, m_Data.m_Weight->template GetConstTensor<Type>());
+template<armnn::DataType... dataTypes>
+void NeonConvolution2dBaseWorkload<dataTypes...>::FreeUnusedTensors()
+{
+    FreeTensorIfUnused(m_KernelTensor);
+    FreeTensorIfUnused(m_BiasTensor);
 }
 
-// Generate known implementations for linker
-template class NeonConvolution2dBaseWorkload<DataType::Float32>;
-template class NeonConvolution2dBaseWorkload<DataType::QuantisedAsymm8>;
+// Generates known implementations for linker.
+template class NeonConvolution2dBaseWorkload<armnn::DataType::Float16, armnn::DataType::Float32>;
+template class NeonConvolution2dBaseWorkload<armnn::DataType::QuantisedAsymm8>;
 
 } //namespace armnn
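
Note on the new FreeUnusedTensors() method: the diff holds the kernel and bias tensors as std::unique_ptr members and frees them through a FreeTensorIfUnused helper whose definition is not part of this diff. A minimal sketch of what such a helper might look like, assuming it lives alongside the other armcomputetensorutils helpers and relying on arm_compute::ITensor::is_used(), which the Arm Compute Library exposes to indicate whether a layer still needs a tensor:

#include <memory>

// Sketch only, not taken from this diff: release a heap-allocated tensor
// once the Arm Compute Library has marked it as no longer used, e.g. after
// the convolution layer has internally repacked the weights.
template <typename Tensor>
void FreeTensorIfUnused(std::unique_ptr<Tensor>& tensor)
{
    // is_used() is an ACL ITensor query; a null pointer means the tensor
    // was never allocated (e.g. bias disabled) and there is nothing to free.
    if (tensor != nullptr && !tensor->is_used())
    {
        tensor.reset(nullptr);
    }
}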
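The runtime switch on m_Data.m_Weight->GetTensorInfo().GetDataType() replaces the old compile-time `using Type = ResolveType<dataType>;` because the class now takes a pack of data types: a single instantiation can serve several tensor types (Float16 and Float32 share one above), so the weight type is no longer fixed by the template argument. A hypothetical sketch of how a pack-based TypedWorkload could validate the incoming type; BaseWorkload and the exact member names are assumptions, not taken from this diff:

#include <algorithm>
#include <array>

// Sketch: a workload templated on a pack of supported DataTypes checks at
// construction time that the actual input type is one of them.
template <typename QueueDescriptor, armnn::DataType... DataTypes>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:
    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        // Expand the pack into an array so it can be searched at runtime.
        std::array<armnn::DataType, sizeof...(DataTypes)> supported = {DataTypes...};
        const armnn::DataType actual = info.m_InputTensorInfos[0].GetDataType();
        BOOST_ASSERT_MSG(std::find(supported.begin(), supported.end(), actual) != supported.end(),
                         "Trying to create workload with incorrect type");
    }
};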