aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonConvolution2dWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/workloads/NeonConvolution2dWorkload.cpp')
-rw-r--r--src/backends/neon/workloads/NeonConvolution2dWorkload.cpp25
1 files changed, 22 insertions, 3 deletions
diff --git a/src/backends/neon/workloads/NeonConvolution2dWorkload.cpp b/src/backends/neon/workloads/NeonConvolution2dWorkload.cpp
index 83f761158a..d35b9685be 100644
--- a/src/backends/neon/workloads/NeonConvolution2dWorkload.cpp
+++ b/src/backends/neon/workloads/NeonConvolution2dWorkload.cpp
@@ -59,8 +59,10 @@ arm_compute::Status NeonConvolution2dWorkloadValidate(const TensorInfo& input,
}
NeonConvolution2dWorkload::NeonConvolution2dWorkload(
- const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info,
- std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
+ const Convolution2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info,
+ std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager,
+ const bool isFastMathEnabled)
: BaseWorkload<Convolution2dQueueDescriptor>(descriptor, info)
{
using arm_compute::NEDirectConvolutionLayer;
@@ -97,7 +99,19 @@ NeonConvolution2dWorkload::NeonConvolution2dWorkload(
&output,
padStrideInfo,
arm_compute::WeightsInfo(),
- aclDilationInfo);
+ aclDilationInfo,
+ arm_compute::ActivationLayerInfo(),
+ isFastMathEnabled);
+
+ m_ConvolutionMethod =
+ convolutionLayer->get_convolution_method(input.info(),
+ m_KernelTensor->info(),
+ output.info(),
+ padStrideInfo,
+ arm_compute::WeightsInfo(),
+ aclDilationInfo,
+ arm_compute::ActivationLayerInfo(),
+ isFastMathEnabled);
m_ConvolutionLayer.reset(convolutionLayer.release());
@@ -120,6 +134,11 @@ void NeonConvolution2dWorkload::Execute() const
m_ConvolutionLayer->run();
}
+arm_compute::ConvolutionMethod NeonConvolution2dWorkload::GetConvolutionMethod() const
+{
+ return m_ConvolutionMethod;
+}
+
void NeonConvolution2dWorkload::FreeUnusedTensors()
{
FreeTensorIfUnused(m_KernelTensor);