aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/workloads/ClSplitterWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/workloads/ClSplitterWorkload.cpp')
-rw-r--r--src/backends/cl/workloads/ClSplitterWorkload.cpp9
1 files changed, 6 insertions, 3 deletions
diff --git a/src/backends/cl/workloads/ClSplitterWorkload.cpp b/src/backends/cl/workloads/ClSplitterWorkload.cpp
index 9bbbcab797..296e0a3dde 100644
--- a/src/backends/cl/workloads/ClSplitterWorkload.cpp
+++ b/src/backends/cl/workloads/ClSplitterWorkload.cpp
@@ -9,6 +9,7 @@
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <aclCommon/ArmComputeUtils.hpp>
+#include <arm_compute/runtime/CL/functions/CLSplit.h>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <cl/ClTensorHandle.hpp>
@@ -84,7 +85,6 @@ ClSplitterWorkload::ClSplitterWorkload(const SplitterQueueDescriptor& descriptor
}
// Create the layer function
- m_Layer.reset(new arm_compute::CLSplit());
// Configure input and output tensors
std::set<unsigned int> splitAxis = ComputeSplitAxis(descriptor.m_Parameters, m_Data.m_Inputs[0]->GetShape());
@@ -94,10 +94,13 @@ ClSplitterWorkload::ClSplitterWorkload(const SplitterQueueDescriptor& descriptor
}
unsigned int aclAxis = CalcAclAxis(descriptor.m_Parameters.GetNumDimensions(), *splitAxis.begin());
- m_Layer->configure(&input, aclOutputs, aclAxis);
+ auto layer = std::make_unique<arm_compute::CLSplit>();
+ layer->configure(&input, aclOutputs, aclAxis);
// Prepare
- m_Layer->prepare();
+ layer->prepare();
+
+ m_Layer = std::move(layer);
}
void ClSplitterWorkload::Execute() const