diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 15 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 4 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 6 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.hpp | 3 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 40 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 3 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefRankWorkload.hpp | 32 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 3 |
8 files changed, 102 insertions, 4 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 696c6d9dac..877d200208 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1651,6 +1651,21 @@ bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsRankSupported(const TensorInfo& input, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported) const +{ + IgnoreUnused(input); + // Define supported output types. + std::array<DataType,1> supportedOutputTypes = + { + DataType::Signed32, + }; + + return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported, + "Reference rank: input type not supported."); +} + bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input, const TensorInfo& output, const ReshapeDescriptor& descriptor, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 7d2bbf240e..a233082aaa 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -265,6 +265,10 @@ public: const LstmInputParamsInfo& paramsInfo, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsRankSupported(const TensorInfo& input, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsReshapeSupported(const TensorInfo& input, const TensorInfo& output, const ReshapeDescriptor& descriptor, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index dcdabe17ff..cac1d1bd8a 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -549,6 +549,12 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateQuantize(const QuantizeQueu return std::make_unique<RefQuantizeWorkload>(descriptor, info); } +std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRank(const RankQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique<RefRankWorkload>(descriptor, info); +} + std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp index 941f1a6636..e2eab072e3 100644 --- a/src/backends/reference/RefWorkloadFactory.hpp +++ b/src/backends/reference/RefWorkloadFactory.hpp @@ -209,6 +209,9 @@ public: std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateRank(const RankQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateReshape(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index d96fa8be59..53df9a36b3 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -9,7 +9,6 @@ #include <reference/RefWorkloadFactory.hpp> -#include <test/TensorHelpers.hpp> #include <test/UnitTests.hpp> #include <boost/test/unit_test.hpp> @@ -797,6 +796,43 @@ ARMNN_AUTO_TEST_CASE(BatchNormUint8Nhwc, BatchNormUint8NhwcTest) ARMNN_AUTO_TEST_CASE(BatchNormInt16, BatchNormInt16Test) ARMNN_AUTO_TEST_CASE(BatchNormInt16Nhwc, BatchNormInt16NhwcTest) +// Rank +ARMNN_AUTO_TEST_CASE(RankDimSize1Float16, RankDimSize1Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE(RankDimSize1Float32, RankDimSize1Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE(RankDimSize1QAsymmU8, RankDimSize1Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE(RankDimSize1Signed32, RankDimSize1Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE(RankDimSize1QSymmS16, RankDimSize1Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE(RankDimSize1QSymmS8, RankDimSize1Test<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE(RankDimSize1QAsymmS8, RankDimSize1Test<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE(RankDimSize1BFloat16, RankDimSize1Test<DataType::BFloat16>) + +ARMNN_AUTO_TEST_CASE(RankDimSize2Float16, RankDimSize2Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE(RankDimSize2Float32, RankDimSize2Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE(RankDimSize2QAsymmU8, RankDimSize2Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE(RankDimSize2Signed32, RankDimSize2Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE(RankDimSize2QSymmS16, RankDimSize2Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE(RankDimSize2QSymmS8, RankDimSize2Test<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE(RankDimSize2QAsymmS8, RankDimSize2Test<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE(RankDimSize2BFloat16, RankDimSize2Test<DataType::BFloat16>) + +ARMNN_AUTO_TEST_CASE(RankDimSize3Float16, RankDimSize3Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE(RankDimSize3Float32, RankDimSize3Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE(RankDimSize3QAsymmU8, RankDimSize3Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE(RankDimSize3Signed32, RankDimSize3Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE(RankDimSize3QSymmS16, RankDimSize3Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE(RankDimSize3QSymmS8, RankDimSize3Test<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE(RankDimSize3QAsymmS8, RankDimSize3Test<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE(RankDimSize3BFloat16, RankDimSize3Test<DataType::BFloat16>) + +ARMNN_AUTO_TEST_CASE(RankDimSize4Float16, RankDimSize4Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE(RankDimSize4Float32, RankDimSize4Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE(RankDimSize4QAsymmU8, RankDimSize4Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE(RankDimSize4Signed32, RankDimSize4Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE(RankDimSize4QSymmS16, RankDimSize4Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE(RankDimSize4QSymmS8, RankDimSize4Test<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE(RankDimSize4QAsymmS8, RankDimSize4Test<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE(RankDimSize4BFloat16, RankDimSize4Test<DataType::BFloat16>) + // Resize Bilinear - NCHW ARMNN_AUTO_TEST_CASE(SimpleResizeBilinear, SimpleResizeBilinearTest<DataType::Float32>, diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index d51db365cc..937a32029e 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright © 2017 Arm Ltd. All rights reserved. +# Copyright © 2017 Arm Ltd and Contributors. All rights reserved. # SPDX-License-Identifier: MIT # @@ -129,6 +129,7 @@ list(APPEND armnnRefBackendWorkloads_sources RefQuantizeWorkload.hpp RefQLstmWorkload.cpp RefQLstmWorkload.hpp + RefRankWorkload.hpp RefReshapeWorkload.cpp RefReshapeWorkload.hpp RefResizeBilinearWorkload.cpp diff --git a/src/backends/reference/workloads/RefRankWorkload.hpp b/src/backends/reference/workloads/RefRankWorkload.hpp new file mode 100644 index 0000000000..780d3be533 --- /dev/null +++ b/src/backends/reference/workloads/RefRankWorkload.hpp @@ -0,0 +1,32 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/Workload.hpp> +#include <backendsCommon/WorkloadData.hpp> + +#include "RefWorkloadUtils.hpp" + +namespace armnn +{ + +struct RefRankWorkload : public BaseWorkload<RankQueueDescriptor> +{ +public: + using BaseWorkload<RankQueueDescriptor>::BaseWorkload; + virtual void Execute() const override + { + const int32_t rank = static_cast<int32_t>(GetTensorInfo(m_Data.m_Inputs[0]).GetNumDimensions()); + + std::memcpy(GetOutputTensorData<void>(0, m_Data), &rank, sizeof(int32_t)); + } +}; + +} //namespace armnn + + + + diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 951e3a1e29..fc47cff84f 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -51,6 +51,7 @@ #include "RefPreluWorkload.hpp" #include "RefQLstmWorkload.hpp" #include "RefQuantizeWorkload.hpp" +#include "RefRankWorkload.hpp" #include "RefReshapeWorkload.hpp" #include "RefResizeBilinearWorkload.hpp" #include "RefResizeWorkload.hpp" |