diff options
Diffstat (limited to 'src/backends/cl/workloads/ClConvolution2dWorkload.cpp')
-rw-r--r-- | src/backends/cl/workloads/ClConvolution2dWorkload.cpp | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/src/backends/cl/workloads/ClConvolution2dWorkload.cpp b/src/backends/cl/workloads/ClConvolution2dWorkload.cpp index 7b52f2784f..50cb9ded37 100644 --- a/src/backends/cl/workloads/ClConvolution2dWorkload.cpp +++ b/src/backends/cl/workloads/ClConvolution2dWorkload.cpp @@ -25,7 +25,8 @@ arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input, const Convolution2dDescriptor& descriptor, const TensorInfo& weights, const Optional<TensorInfo>& biases, - bool isFastMathEnabled) + bool isFastMathEnabled, + const ActivationDescriptor* activationDescriptor) { const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout); const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout); @@ -47,6 +48,9 @@ arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input, arm_compute::PadStrideInfo layerInfo = BuildArmComputePadStrideInfo(descriptor); + const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo( + activationDescriptor); + return arm_compute::CLConvolutionLayer::validate(&aclInputInfo, &aclWeightsInfo, optionalAclBiasesInfo, @@ -54,7 +58,7 @@ arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input, layerInfo, arm_compute::WeightsInfo(), aclDilationInfo, - arm_compute::ActivationLayerInfo(), + activationInfo, isFastMathEnabled); } @@ -91,6 +95,8 @@ ClConvolution2dWorkload::ClConvolution2dWorkload(const Convolution2dQueueDescrip arm_compute::PadStrideInfo padStrideInfo = BuildArmComputePadStrideInfo(m_Data.m_Parameters); + const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor); + m_ConvolutionLayer.configure(&input, m_KernelTensor.get(), m_BiasTensor.get(), @@ -98,7 +104,7 @@ ClConvolution2dWorkload::ClConvolution2dWorkload(const Convolution2dQueueDescrip padStrideInfo, arm_compute::WeightsInfo(), aclDilationInfo, - arm_compute::ActivationLayerInfo(), + activationInfo, isFastMathEnabled); m_ConvolutionMethod = @@ -107,7 +113,7 @@ ClConvolution2dWorkload::ClConvolution2dWorkload(const Convolution2dQueueDescrip output.info(), padStrideInfo, arm_compute::WeightsInfo(), - arm_compute::ActivationLayerInfo(), + activationInfo, arm_compute::CLScheduler::get().target(), aclDilationInfo, isFastMathEnabled); |