diff options
Diffstat (limited to 'src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp')
-rw-r--r-- | src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp b/src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp index dbdc06f174..6b8a230912 100644 --- a/src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp +++ b/src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp @@ -25,10 +25,10 @@ arm_compute::Status ClPooling2dWorkloadValidate(const TensorInfo& input, return arm_compute::CLPoolingLayer::validate(&aclInputInfo, &aclOutputInfo, layerInfo); } -template <armnn::DataType dataType> -ClPooling2dBaseWorkload<dataType>::ClPooling2dBaseWorkload( +template <armnn::DataType... dataTypes> +ClPooling2dBaseWorkload<dataTypes...>::ClPooling2dBaseWorkload( const Pooling2dQueueDescriptor& descriptor, const WorkloadInfo& info, const std::string& name) - : TypedWorkload<Pooling2dQueueDescriptor, dataType>(descriptor, info) + : TypedWorkload<Pooling2dQueueDescriptor, dataTypes...>(descriptor, info) { m_Data.ValidateInputsOutputs(name, 1, 1); @@ -37,11 +37,11 @@ ClPooling2dBaseWorkload<dataType>::ClPooling2dBaseWorkload( arm_compute::PoolingLayerInfo layerInfo = BuildArmComputePoolingLayerInfo(m_Data.m_Parameters); - // Run the layer + // Run the layer. m_PoolingLayer.configure(&input, &output, layerInfo); } -template class ClPooling2dBaseWorkload<DataType::Float32>; +template class ClPooling2dBaseWorkload<DataType::Float16, DataType::Float32>; template class ClPooling2dBaseWorkload<DataType::QuantisedAsymm8>; } |