diff options
Diffstat (limited to 'src/backends/cl/workloads')
-rw-r--r-- | src/backends/cl/workloads/ClConcatWorkload.cpp | 12 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClConcatWorkload.hpp | 6 |
2 files changed, 9 insertions, 9 deletions
diff --git a/src/backends/cl/workloads/ClConcatWorkload.cpp b/src/backends/cl/workloads/ClConcatWorkload.cpp index ee4ba6b65f..fb28946549 100644 --- a/src/backends/cl/workloads/ClConcatWorkload.cpp +++ b/src/backends/cl/workloads/ClConcatWorkload.cpp @@ -19,7 +19,7 @@ using namespace armcomputetensorutils; namespace { -size_t CalcAxis(const MergerDescriptor& desc) +size_t CalcAxis(const OriginsDescriptor& desc) { return (desc.GetNumDimensions() - desc.GetConcatAxis()) - 1; } @@ -27,7 +27,7 @@ size_t CalcAxis(const MergerDescriptor& desc) arm_compute::Status ClConcatWorkloadValidate(const std::vector<const TensorInfo*>& inputs, const TensorInfo& output, - const MergerDescriptor& descriptor) + const OriginsDescriptor& descriptor) { std::vector<arm_compute::TensorInfo> aclInputs; for (const TensorInfo* input : inputs) @@ -46,8 +46,8 @@ arm_compute::Status ClConcatWorkloadValidate(const std::vector<const TensorInfo* return arm_compute::CLConcatenateLayer::validate(aclInputPtrs, &aclOutputInfo, aclAxis); } -ClConcatWorkload::ClConcatWorkload(const MergerQueueDescriptor& descriptor, const WorkloadInfo& info) -: BaseWorkload<MergerQueueDescriptor>(descriptor, info) +ClConcatWorkload::ClConcatWorkload(const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info) +: BaseWorkload<ConcatQueueDescriptor>(descriptor, info) { bool allInputsAreSubtensors = true; @@ -56,7 +56,7 @@ ClConcatWorkload::ClConcatWorkload(const MergerQueueDescriptor& descriptor, cons { if (!input->GetParent()) { - // Non sub-tensor input found so we need to execute the merger function + // Non sub-tensor input found so we need to execute the concat function allInputsAreSubtensors = false; break; } @@ -64,7 +64,7 @@ ClConcatWorkload::ClConcatWorkload(const MergerQueueDescriptor& descriptor, cons if (allInputsAreSubtensors) { - // Can skip configuring the merger function since it's not executed + // Can skip configuring the concat function since it's not executed return; } diff --git a/src/backends/cl/workloads/ClConcatWorkload.hpp b/src/backends/cl/workloads/ClConcatWorkload.hpp index 106193d090..c34de9ff9a 100644 --- a/src/backends/cl/workloads/ClConcatWorkload.hpp +++ b/src/backends/cl/workloads/ClConcatWorkload.hpp @@ -14,12 +14,12 @@ namespace armnn arm_compute::Status ClConcatWorkloadValidate(const std::vector<const TensorInfo*>& inputs, const TensorInfo& output, - const MergerDescriptor& descriptor); + const OriginsDescriptor& descriptor); -class ClConcatWorkload : public BaseWorkload<MergerQueueDescriptor> +class ClConcatWorkload : public BaseWorkload<ConcatQueueDescriptor> { public: - ClConcatWorkload(const MergerQueueDescriptor& descriptor, const WorkloadInfo& info); + ClConcatWorkload(const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info); void Execute() const override; |