diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 55 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 4 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 2 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.hpp | 2 | ||||
-rw-r--r-- | src/backends/reference/backend.mk | 2 | ||||
-rw-r--r-- | src/backends/reference/test/RefCreateWorkloadTests.cpp | 70 | ||||
-rw-r--r-- | src/backends/reference/test/RefEndToEndTests.cpp | 34 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 4 | ||||
-rw-r--r-- | src/backends/reference/workloads/Concatenate.cpp (renamed from src/backends/reference/workloads/Merger.cpp) | 6 | ||||
-rw-r--r-- | src/backends/reference/workloads/Concatenate.hpp (renamed from src/backends/reference/workloads/Merger.hpp) | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefConcatWorkload.cpp | 4 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefConcatWorkload.hpp | 4 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 2 |
13 files changed, 95 insertions, 96 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 2adcb1099d..9a691a6fa7 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -316,18 +316,38 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, - const OriginsDescriptor& descriptor, + const ConcatDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported) const { - ARMNN_NO_DEPRECATE_WARN_BEGIN - return IsMergerSupported(inputs, output, descriptor, reasonIfUnsupported); - ARMNN_NO_DEPRECATE_WARN_END + ignore_unused(descriptor); + + bool supported = true; + std::array<DataType,3> supportedTypes = + { + DataType::Float32, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference concatenation: output type not supported"); + for (const TensorInfo* input : inputs) + { + supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported, + "Reference concatenation: input type not supported"); + + supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported, + "Reference concatenation: input and output types mismatched."); + } + + return supported; } bool RefLayerSupport::IsConstantSupported(const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { - std::array<DataType,4> supportedTypes = { + std::array<DataType,4> supportedTypes = + { DataType::Float32, DataType::Signed32, DataType::QuantisedAsymm8, @@ -815,31 +835,10 @@ bool RefLayerSupport::IsMeanSupported(const TensorInfo& input, bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, - const OriginsDescriptor& descriptor, + const MergerDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported) const { - ignore_unused(descriptor); - - bool supported = true; - std::array<DataType,3> supportedTypes = - { - DataType::Float32, - DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 - }; - - supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, - "Reference concatenation: output type not supported"); - for (const TensorInfo* input : inputs) - { - supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported, - "Reference concatenation: input type not supported"); - - supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported, - "Reference concatenation: input and output types mismatched."); - } - - return supported; + return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 944061d5a6..8850c6e105 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -38,7 +38,7 @@ public: bool IsConcatSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, - const OriginsDescriptor& descriptor, + const ConcatDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; bool IsConstantSupported(const TensorInfo& output, @@ -170,7 +170,7 @@ public: ARMNN_DEPRECATED_MSG("Use IsConcatSupported instead") bool IsMergerSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, - const OriginsDescriptor& descriptor, + const MergerDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; bool IsMemCopySupported(const TensorInfo& input, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 1243328852..a21becdb13 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -245,7 +245,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2Nor return MakeWorkload<RefL2NormalizationFloat32Workload, NullWorkload>(descriptor, info); } -std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConcat(const MergerQueueDescriptor& descriptor, +std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info) const { if (IsFloat16(info)) diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp index 985b634d77..78f6bab92c 100644 --- a/src/backends/reference/RefWorkloadFactory.hpp +++ b/src/backends/reference/RefWorkloadFactory.hpp @@ -115,7 +115,7 @@ public: std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr<IWorkload> CreateConcat(const MergerQueueDescriptor& descriptor, + std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info) const override; std::unique_ptr<IWorkload> CreateConstant(const ConstantQueueDescriptor& descriptor, diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 1c7f8dc22c..9a4cf146c6 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -21,7 +21,7 @@ BACKEND_SOURCES := \ workloads/FullyConnected.cpp \ workloads/Gather.cpp \ workloads/Mean.cpp \ - workloads/Merger.cpp \ + workloads/Concatenate.cpp \ workloads/Pad.cpp \ workloads/Pooling2d.cpp \ workloads/RefActivationWorkload.cpp \ diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 3f4cc75fea..a96d656d9b 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -473,28 +473,28 @@ BOOST_AUTO_TEST_CASE(CreateSplitterUint8Workload) RefCreateSplitterWorkloadTest<RefSplitterUint8Workload, armnn::DataType::QuantisedAsymm8>(); } -template <typename SplitterWorkloadType, typename MergerWorkloadType, armnn::DataType DataType> -static void RefCreateSplitterMergerWorkloadTest() +template <typename SplitterWorkloadType, typename ConcatWorkloadType, armnn::DataType DataType> +static void RefCreateSplitterConcatWorkloadTest() { // Tests that it is possible to decide which output of the splitter layer - // should be lined to which input of the merger layer. + // should be lined to which input of the concat layer. // We tested that is is possible to specify 0th output - // of the splitter to be the 1st input to the merger and the 1st output of the splitter to be 0th input - // of the merger. + // of the splitter to be the 1st input to the concat and the 1st output of the splitter to be 0th input + // of the concat. Graph graph; RefWorkloadFactory factory; - auto workloads = CreateSplitterMergerWorkloadTest<SplitterWorkloadType, MergerWorkloadType, DataType> - (factory, graph); + auto workloads = CreateSplitterConcatWorkloadTest<SplitterWorkloadType, ConcatWorkloadType, DataType> + (factory, graph); auto wlSplitter = std::move(workloads.first); - auto wlMerger = std::move(workloads.second); + auto wlConcat = std::move(workloads.second); //Checks that the index of inputs/outputs matches what we declared on InputDescriptor construction. armnn::CpuTensorHandle* sOut0 = dynamic_cast<armnn::CpuTensorHandle*>(wlSplitter->GetData().m_Outputs[0]); armnn::CpuTensorHandle* sOut1 = dynamic_cast<armnn::CpuTensorHandle*>(wlSplitter->GetData().m_Outputs[1]); - armnn::CpuTensorHandle* mIn0 = dynamic_cast<armnn::CpuTensorHandle*>(wlMerger->GetData().m_Inputs[0]); - armnn::CpuTensorHandle* mIn1 = dynamic_cast<armnn::CpuTensorHandle*>(wlMerger->GetData().m_Inputs[1]); + armnn::CpuTensorHandle* mIn0 = dynamic_cast<armnn::CpuTensorHandle*>(wlConcat->GetData().m_Inputs[0]); + armnn::CpuTensorHandle* mIn1 = dynamic_cast<armnn::CpuTensorHandle*>(wlConcat->GetData().m_Inputs[1]); BOOST_TEST(sOut0); BOOST_TEST(sOut1); @@ -506,14 +506,14 @@ static void RefCreateSplitterMergerWorkloadTest() BOOST_TEST(validDataPointers); } -BOOST_AUTO_TEST_CASE(CreateSplitterMergerFloat32) +BOOST_AUTO_TEST_CASE(CreateSplitterConcatFloat32) { - RefCreateSplitterMergerWorkloadTest<RefSplitterFloat32Workload, RefConcatWorkload, DataType::Float32>(); + RefCreateSplitterConcatWorkloadTest<RefSplitterFloat32Workload, RefConcatWorkload, DataType::Float32>(); } -BOOST_AUTO_TEST_CASE(CreateSplitterMergerUint8) +BOOST_AUTO_TEST_CASE(CreateSplitterConcatUint8) { - RefCreateSplitterMergerWorkloadTest<RefSplitterUint8Workload, RefConcatWorkload, DataType::QuantisedAsymm8>(); + RefCreateSplitterConcatWorkloadTest<RefSplitterUint8Workload, RefConcatWorkload, DataType::QuantisedAsymm8>(); } template <typename SplitterWorkloadType, typename ActivationWorkloadType, armnn::DataType DataType> @@ -671,13 +671,13 @@ BOOST_AUTO_TEST_CASE(CreateReshapeUint8Workload) RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QuantisedAsymm8>(); } -template <typename MergerWorkloadType, armnn::DataType DataType> -static void RefCreateMergerWorkloadTest(const armnn::TensorShape& outputShape, +template <typename ConcatWorkloadType, armnn::DataType DataType> +static void RefCreateConcatWorkloadTest(const armnn::TensorShape& outputShape, unsigned int concatAxis) { Graph graph; RefWorkloadFactory factory; - auto workload = CreateMergerWorkloadTest<MergerWorkloadType, DataType>(factory, graph, outputShape, concatAxis); + auto workload = CreateConcatWorkloadTest<ConcatWorkloadType, DataType>(factory, graph, outputShape, concatAxis); CheckInputsOutput(std::move(workload), TensorInfo({ 2, 3, 2, 5 }, DataType), @@ -685,49 +685,49 @@ static void RefCreateMergerWorkloadTest(const armnn::TensorShape& outputShape, TensorInfo(outputShape, DataType)); } -BOOST_AUTO_TEST_CASE(CreateMergerDim0Float32Workload) +BOOST_AUTO_TEST_CASE(CreateConcatDim0Float32Workload) { - RefCreateMergerWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 4, 3, 2, 5 }, 0); + RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 4, 3, 2, 5 }, 0); } -BOOST_AUTO_TEST_CASE(CreateMergerDim0Uint8Workload) +BOOST_AUTO_TEST_CASE(CreateConcatDim0Uint8Workload) { - RefCreateMergerWorkloadTest<RefConcatWorkload, armnn::DataType::QuantisedAsymm8>({ 4, 3, 2, 5 }, 0); + RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QuantisedAsymm8>({ 4, 3, 2, 5 }, 0); } -BOOST_AUTO_TEST_CASE(CreateMergerDim0Uint16Workload) +BOOST_AUTO_TEST_CASE(CreateConcatDim0Uint16Workload) { - RefCreateMergerWorkloadTest<RefConcatWorkload, armnn::DataType::QuantisedSymm16>({ 4, 3, 2, 5 }, 0); + RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QuantisedSymm16>({ 4, 3, 2, 5 }, 0); } -BOOST_AUTO_TEST_CASE(CreateMergerDim1Float32Workload) +BOOST_AUTO_TEST_CASE(CreateConcatDim1Float32Workload) { - RefCreateMergerWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 6, 2, 5 }, 1); + RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 6, 2, 5 }, 1); } -BOOST_AUTO_TEST_CASE(CreateMergerDim1Uint8Workload) +BOOST_AUTO_TEST_CASE(CreateConcatDim1Uint8Workload) { - RefCreateMergerWorkloadTest<RefConcatWorkload, armnn::DataType::QuantisedAsymm8>({ 2, 6, 2, 5 }, 1); + RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QuantisedAsymm8>({ 2, 6, 2, 5 }, 1); } -BOOST_AUTO_TEST_CASE(CreateMergerDim2Float32Workload) +BOOST_AUTO_TEST_CASE(CreateConcatDim2Float32Workload) { - RefCreateMergerWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 4, 5 }, 2); + RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 4, 5 }, 2); } -BOOST_AUTO_TEST_CASE(CreateMergerDim2Uint8Workload) +BOOST_AUTO_TEST_CASE(CreateConcatDim2Uint8Workload) { - RefCreateMergerWorkloadTest<RefConcatWorkload, armnn::DataType::QuantisedAsymm8>({ 2, 3, 4, 5 }, 2); + RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QuantisedAsymm8>({ 2, 3, 4, 5 }, 2); } -BOOST_AUTO_TEST_CASE(CreateMergerDim3Float32Workload) +BOOST_AUTO_TEST_CASE(CreateConcatDim3Float32Workload) { - RefCreateMergerWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 2, 10 }, 3); + RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 2, 10 }, 3); } -BOOST_AUTO_TEST_CASE(CreateMergerDim3Uint8Workload) +BOOST_AUTO_TEST_CASE(CreateConcatDim3Uint8Workload) { - RefCreateMergerWorkloadTest<RefConcatWorkload, armnn::DataType::QuantisedAsymm8>({ 2, 3, 2, 10 }, 3); + RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QuantisedAsymm8>({ 2, 3, 2, 10 }, 3); } template <typename ConstantWorkloadType, armnn::DataType DataType> diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 6dacfab4d1..2b7fb774b5 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -7,7 +7,7 @@ #include <backendsCommon/test/DetectionPostProcessTestImpl.hpp> #include <backendsCommon/test/GatherEndToEndTestImpl.hpp> -#include <backendsCommon/test/MergerTestImpl.hpp> +#include <backendsCommon/test/ConcatTestImpl.hpp> #include <backendsCommon/test/ArithmeticTestImpl.hpp> #include <backendsCommon/test/SplitterEndToEndTestImpl.hpp> @@ -396,44 +396,44 @@ BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test) expectedOutput); } -BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim0Test) +BOOST_AUTO_TEST_CASE(RefConcatEndToEndDim0Test) { - MergerDim0EndToEnd<armnn::DataType::Float32>(defaultBackends); + ConcatDim0EndToEnd<armnn::DataType::Float32>(defaultBackends); } -BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim0Uint8Test) +BOOST_AUTO_TEST_CASE(RefConcatEndToEndDim0Uint8Test) { - MergerDim0EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); + ConcatDim0EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); } -BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim1Test) +BOOST_AUTO_TEST_CASE(RefConcatEndToEndDim1Test) { - MergerDim1EndToEnd<armnn::DataType::Float32>(defaultBackends); + ConcatDim1EndToEnd<armnn::DataType::Float32>(defaultBackends); } -BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim1Uint8Test) +BOOST_AUTO_TEST_CASE(RefConcatEndToEndDim1Uint8Test) { - MergerDim1EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); + ConcatDim1EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); } -BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim2Test) +BOOST_AUTO_TEST_CASE(RefConcatEndToEndDim2Test) { - MergerDim2EndToEnd<armnn::DataType::Float32>(defaultBackends); + ConcatDim2EndToEnd<armnn::DataType::Float32>(defaultBackends); } -BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim2Uint8Test) +BOOST_AUTO_TEST_CASE(RefConcatEndToEndDim2Uint8Test) { - MergerDim2EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); + ConcatDim2EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); } -BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim3Test) +BOOST_AUTO_TEST_CASE(RefConcatEndToEndDim3Test) { - MergerDim3EndToEnd<armnn::DataType::Float32>(defaultBackends); + ConcatDim3EndToEnd<armnn::DataType::Float32>(defaultBackends); } -BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim3Uint8Test) +BOOST_AUTO_TEST_CASE(RefConcatEndToEndDim3Uint8Test) { - MergerDim3EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); + ConcatDim3EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); } BOOST_AUTO_TEST_CASE(RefGatherFloatTest) diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 508dfdc293..3db0314346 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -28,8 +28,8 @@ list(APPEND armnnRefBackendWorkloads_sources Gather.hpp LstmUtils.hpp Maximum.hpp - Merger.hpp - Merger.cpp + Concatenate.hpp + Concatenate.cpp Minimum.hpp Pad.cpp Pad.hpp diff --git a/src/backends/reference/workloads/Merger.cpp b/src/backends/reference/workloads/Concatenate.cpp index e0b70ee5cb..bb55424c0c 100644 --- a/src/backends/reference/workloads/Merger.cpp +++ b/src/backends/reference/workloads/Concatenate.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: MIT // -#include "Merger.hpp" +#include "Concatenate.hpp" #include "RefWorkloadUtils.hpp" #include "Decoders.hpp" #include "Encoders.hpp" @@ -11,7 +11,7 @@ namespace armnn { -void Merger(const MergerQueueDescriptor& data) +void Concatenate(const ConcatQueueDescriptor &data) { const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]); @@ -34,7 +34,7 @@ void Merger(const MergerQueueDescriptor& data) for (unsigned int viewIdx = 0; viewIdx < data.m_ViewOrigins.size(); ++viewIdx) { - MergerQueueDescriptor::ViewOrigin const& view = data.m_ViewOrigins[viewIdx]; + ConcatQueueDescriptor::ViewOrigin const& view = data.m_ViewOrigins[viewIdx]; //Split view extents are defined by the size of (the corresponding) input tensor. const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[viewIdx]); diff --git a/src/backends/reference/workloads/Merger.hpp b/src/backends/reference/workloads/Concatenate.hpp index eaa154d25a..ac82a87af3 100644 --- a/src/backends/reference/workloads/Merger.hpp +++ b/src/backends/reference/workloads/Concatenate.hpp @@ -10,5 +10,5 @@ namespace armnn { -void Merger(const MergerQueueDescriptor& data); +void Concatenate(const ConcatQueueDescriptor &data); } //namespace armnn diff --git a/src/backends/reference/workloads/RefConcatWorkload.cpp b/src/backends/reference/workloads/RefConcatWorkload.cpp index 9abddc0ff8..152eae93b3 100644 --- a/src/backends/reference/workloads/RefConcatWorkload.cpp +++ b/src/backends/reference/workloads/RefConcatWorkload.cpp @@ -5,7 +5,7 @@ #include "RefConcatWorkload.hpp" -#include "Merger.hpp" +#include "Concatenate.hpp" #include "Profiling.hpp" @@ -15,7 +15,7 @@ namespace armnn void RefConcatWorkload::Execute() const { ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConcatWorkload_Execute"); - Merger(m_Data); + Concatenate(m_Data); } } //namespace armnn diff --git a/src/backends/reference/workloads/RefConcatWorkload.hpp b/src/backends/reference/workloads/RefConcatWorkload.hpp index 9fc9c7ef7e..7d0b6b7cd1 100644 --- a/src/backends/reference/workloads/RefConcatWorkload.hpp +++ b/src/backends/reference/workloads/RefConcatWorkload.hpp @@ -11,10 +11,10 @@ namespace armnn { -class RefConcatWorkload : public BaseWorkload<MergerQueueDescriptor> +class RefConcatWorkload : public BaseWorkload<ConcatQueueDescriptor> { public: - using BaseWorkload<MergerQueueDescriptor>::BaseWorkload; + using BaseWorkload<ConcatQueueDescriptor>::BaseWorkload; virtual void Execute() const override; }; diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 20649d93ce..6ffec2bd06 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -38,7 +38,7 @@ #include "RefPooling2dUint8Workload.hpp" #include "BatchNormImpl.hpp" #include "Activation.hpp" -#include "Merger.hpp" +#include "Concatenate.hpp" #include "RefSpaceToBatchNdWorkload.hpp" #include "RefSplitterFloat32Workload.hpp" #include "RefStridedSliceWorkload.hpp" |