diff options
Diffstat (limited to 'src/backends/neon/workloads')
-rw-r--r-- | src/backends/neon/workloads/NeonConcatWorkload.cpp | 12 | ||||
-rw-r--r-- | src/backends/neon/workloads/NeonConcatWorkload.hpp | 8 |
2 files changed, 10 insertions, 10 deletions
diff --git a/src/backends/neon/workloads/NeonConcatWorkload.cpp b/src/backends/neon/workloads/NeonConcatWorkload.cpp index 91f81090ce..8ea535b40a 100644 --- a/src/backends/neon/workloads/NeonConcatWorkload.cpp +++ b/src/backends/neon/workloads/NeonConcatWorkload.cpp @@ -19,7 +19,7 @@ using namespace armcomputetensorutils; namespace { -size_t CalcAxis(const armnn::MergerDescriptor& desc) +size_t CalcAxis(const armnn::OriginsDescriptor& desc) { return (desc.GetNumDimensions() - desc.GetConcatAxis()) - 1; } @@ -27,7 +27,7 @@ size_t CalcAxis(const armnn::MergerDescriptor& desc) arm_compute::Status NeonConcatWorkloadValidate(const std::vector<const TensorInfo*>& inputs, const TensorInfo& output, - const MergerDescriptor& descriptor) + const OriginsDescriptor& descriptor) { std::vector<arm_compute::TensorInfo> aclInputs; @@ -48,8 +48,8 @@ arm_compute::Status NeonConcatWorkloadValidate(const std::vector<const TensorInf } NeonConcatWorkload::NeonConcatWorkload( -const MergerQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload<MergerQueueDescriptor>(descriptor, info) +const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info) + : BaseWorkload<ConcatQueueDescriptor>(descriptor, info) { bool allInputsAreSubtensors = true; @@ -58,7 +58,7 @@ const MergerQueueDescriptor& descriptor, const WorkloadInfo& info) { if (!input->GetParent()) { - // Non sub-tensor input found so we need to execute the merger function + // Non sub-tensor input found so we need to execute the concat function allInputsAreSubtensors = false; break; } @@ -66,7 +66,7 @@ const MergerQueueDescriptor& descriptor, const WorkloadInfo& info) if (allInputsAreSubtensors) { - // Can skip configuring the merger function since it's not executed + // Can skip configuring the concat function since it's not executed return; } diff --git a/src/backends/neon/workloads/NeonConcatWorkload.hpp b/src/backends/neon/workloads/NeonConcatWorkload.hpp index e5a8d15055..bf0733b431 100644 --- a/src/backends/neon/workloads/NeonConcatWorkload.hpp +++ b/src/backends/neon/workloads/NeonConcatWorkload.hpp @@ -17,14 +17,14 @@ namespace armnn { arm_compute::Status NeonConcatWorkloadValidate(const std::vector<const TensorInfo*>& inputs, const TensorInfo& output, - const MergerDescriptor& descriptor); + const OriginsDescriptor& descriptor); -class NeonConcatWorkload : public BaseWorkload<MergerQueueDescriptor> +class NeonConcatWorkload : public BaseWorkload<ConcatQueueDescriptor> { public: - NeonConcatWorkload(const MergerQueueDescriptor& descriptor, const WorkloadInfo& info); + NeonConcatWorkload(const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info); - using BaseWorkload<MergerQueueDescriptor>::BaseWorkload; + using BaseWorkload<ConcatQueueDescriptor>::BaseWorkload; void Execute() const override; private: |