aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp')
-rw-r--r--src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp10
1 files changed, 5 insertions, 5 deletions
diff --git a/src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp b/src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp
index dbdc06f174..6b8a230912 100644
--- a/src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp
+++ b/src/armnn/backends/ClWorkloads/ClPooling2dBaseWorkload.cpp
@@ -25,10 +25,10 @@ arm_compute::Status ClPooling2dWorkloadValidate(const TensorInfo& input,
return arm_compute::CLPoolingLayer::validate(&aclInputInfo, &aclOutputInfo, layerInfo);
}
-template <armnn::DataType dataType>
-ClPooling2dBaseWorkload<dataType>::ClPooling2dBaseWorkload(
+template <armnn::DataType... dataTypes>
+ClPooling2dBaseWorkload<dataTypes...>::ClPooling2dBaseWorkload(
const Pooling2dQueueDescriptor& descriptor, const WorkloadInfo& info, const std::string& name)
- : TypedWorkload<Pooling2dQueueDescriptor, dataType>(descriptor, info)
+ : TypedWorkload<Pooling2dQueueDescriptor, dataTypes...>(descriptor, info)
{
m_Data.ValidateInputsOutputs(name, 1, 1);
@@ -37,11 +37,11 @@ ClPooling2dBaseWorkload<dataType>::ClPooling2dBaseWorkload(
arm_compute::PoolingLayerInfo layerInfo = BuildArmComputePoolingLayerInfo(m_Data.m_Parameters);
- // Run the layer
+ // Run the layer.
m_PoolingLayer.configure(&input, &output, layerInfo);
}
-template class ClPooling2dBaseWorkload<DataType::Float32>;
+template class ClPooling2dBaseWorkload<DataType::Float16, DataType::Float32>;
template class ClPooling2dBaseWorkload<DataType::QuantisedAsymm8>;
}