aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/workloads/ClConvolution2dWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/workloads/ClConvolution2dWorkload.cpp')
-rw-r--r--src/backends/cl/workloads/ClConvolution2dWorkload.cpp14
1 files changed, 10 insertions, 4 deletions
diff --git a/src/backends/cl/workloads/ClConvolution2dWorkload.cpp b/src/backends/cl/workloads/ClConvolution2dWorkload.cpp
index 7b52f2784f..50cb9ded37 100644
--- a/src/backends/cl/workloads/ClConvolution2dWorkload.cpp
+++ b/src/backends/cl/workloads/ClConvolution2dWorkload.cpp
@@ -25,7 +25,8 @@ arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input,
const Convolution2dDescriptor& descriptor,
const TensorInfo& weights,
const Optional<TensorInfo>& biases,
- bool isFastMathEnabled)
+ bool isFastMathEnabled,
+ const ActivationDescriptor* activationDescriptor)
{
const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
@@ -47,6 +48,9 @@ arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input,
arm_compute::PadStrideInfo layerInfo = BuildArmComputePadStrideInfo(descriptor);
+ const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
+ activationDescriptor);
+
return arm_compute::CLConvolutionLayer::validate(&aclInputInfo,
&aclWeightsInfo,
optionalAclBiasesInfo,
@@ -54,7 +58,7 @@ arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input,
layerInfo,
arm_compute::WeightsInfo(),
aclDilationInfo,
- arm_compute::ActivationLayerInfo(),
+ activationInfo,
isFastMathEnabled);
}
@@ -91,6 +95,8 @@ ClConvolution2dWorkload::ClConvolution2dWorkload(const Convolution2dQueueDescrip
arm_compute::PadStrideInfo padStrideInfo = BuildArmComputePadStrideInfo(m_Data.m_Parameters);
+ const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
+
m_ConvolutionLayer.configure(&input,
m_KernelTensor.get(),
m_BiasTensor.get(),
@@ -98,7 +104,7 @@ ClConvolution2dWorkload::ClConvolution2dWorkload(const Convolution2dQueueDescrip
padStrideInfo,
arm_compute::WeightsInfo(),
aclDilationInfo,
- arm_compute::ActivationLayerInfo(),
+ activationInfo,
isFastMathEnabled);
m_ConvolutionMethod =
@@ -107,7 +113,7 @@ ClConvolution2dWorkload::ClConvolution2dWorkload(const Convolution2dQueueDescrip
output.info(),
padStrideInfo,
arm_compute::WeightsInfo(),
- arm_compute::ActivationLayerInfo(),
+ activationInfo,
arm_compute::CLScheduler::get().target(),
aclDilationInfo,
isFastMathEnabled);