aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp')
-rw-r--r--src/armnn/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp22
1 files changed, 9 insertions, 13 deletions
diff --git a/src/armnn/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp b/src/armnn/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
index a3c6ac9dca..cf419e752e 100644
--- a/src/armnn/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
+++ b/src/armnn/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
@@ -14,8 +14,9 @@ namespace armnn
using namespace armcomputetensorutils;
ClConvolution2dUint8Workload::ClConvolution2dUint8Workload(const Convolution2dQueueDescriptor& descriptor,
- const WorkloadInfo& info)
+ const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
: Uint8Workload<Convolution2dQueueDescriptor>(descriptor, info)
+ , m_ConvolutionLayer(memoryManager)
{
// todo: check tensor shapes match
@@ -42,16 +43,11 @@ ClConvolution2dUint8Workload::ClConvolution2dUint8Workload(const Convolution2dQu
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
- BOOST_ASSERT_MSG(IsClDirectConvolution2dSupported(weightInfo, m_Data.m_Parameters),
- "Unsupported parameters for u8 convolution");
-
- m_pConvolutionLayer = std::make_unique<arm_compute::CLDirectConvolutionLayer>();
- static_cast<arm_compute::CLDirectConvolutionLayer*>(m_pConvolutionLayer.get())->configure(&input,
- &m_KernelTensor,
- optionalBias,
- &output,
- padStrideInfo);
- BOOST_ASSERT(m_pConvolutionLayer);
+ m_ConvolutionLayer.configure(&input,
+ &m_KernelTensor,
+ optionalBias,
+ &output,
+ padStrideInfo);
InitialiseArmComputeClTensorData(m_KernelTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
@@ -64,9 +60,9 @@ ClConvolution2dUint8Workload::ClConvolution2dUint8Workload(const Convolution2dQu
void ClConvolution2dUint8Workload::Execute() const
{
ARMNN_SCOPED_PROFILING_EVENT(Compute::GpuAcc, "ClConvolution2dUint8Workload_Execute");
- BOOST_ASSERT(m_pConvolutionLayer);
- m_pConvolutionLayer->run();
+ m_ConvolutionLayer.run();
}
} //namespace armnn
+