diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2022-04-12 22:07:09 +0100 |
---|---|---|
committer | Teresa Charlin <teresa.charlinreyes@arm.com> | 2022-05-03 21:24:52 +0100 |
commit | b2d3ec5b1e938ef34facfdbcff83fc8e845d5f7c (patch) | |
tree | 74ee2c47e76fddff249a9f25db01960a52eb2360 /src/backends/reference | |
parent | 04cd60384f5fc8455bb7cf64416daa7b001754d1 (diff) | |
download | armnn-b2d3ec5b1e938ef34facfdbcff83fc8e845d5f7c.tar.gz |
IVGCVSW-6856 Add GATHERNd FrontEnd and Ref Implementation
* Add front end
* Add reference workload
* Add unit tests
* Add EndToEnd test
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I4cebd17b18476df86162e2dda3366c10e80bd2f8
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 37 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 5 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/backend.mk | 1 | ||||
-rw-r--r-- | src/backends/reference/test/RefEndToEndTests.cpp | 31 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 12 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefGatherNdWorkload.cpp | 91 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefGatherNdWorkload.hpp | 24 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefGatherWorkload.cpp | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 1 |
11 files changed, 210 insertions, 1 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index b55adfa958..3bc4affb28 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -212,6 +212,11 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type, infos[2], *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)), reasonIfUnsupported); + case LayerType::GatherNd: + return IsGatherNdSupported(infos[0], + infos[1], + infos[2], + reasonIfUnsupported); case LayerType::Input: return IsInputSupported(infos[0], reasonIfUnsupported); case LayerType::InstanceNormalization: @@ -1591,6 +1596,38 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0, + const armnn::TensorInfo& input1, + const armnn::TensorInfo& output, + armnn::Optional<std::string&> reasonIfUnsupported) const +{ + bool supported = true; + std::array<DataType,7> supportedTypes = + { + DataType::BFloat16, + DataType::Float32, + DataType::Float16, + DataType::QAsymmS8, + DataType::QAsymmU8, + DataType::QSymmS16, + DataType::Signed32 + }; + + supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported, + "Reference GatherNd: input type not supported"); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference GatherNd: output type not supported"); + + supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported, + "Reference GatherNd: indices (input1) type not supported"); + + supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported, + "Reference GatherNd: input and output types not matching"); + + return supported; +} + bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0, const armnn::TensorInfo& input1, const armnn::TensorInfo& output, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index b787d25fbd..98770ad64a 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -169,6 +169,11 @@ public: const FullyConnectedDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsGatherNdSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const; + bool IsGatherSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 9db81fc9cb..2d956582db 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -353,6 +353,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateWorkload(LayerType type, auto gatherQueueDescriptor = PolymorphicDowncast<const GatherQueueDescriptor*>(&descriptor); return std::make_unique<RefGatherWorkload>(*gatherQueueDescriptor, info); } + case LayerType::GatherNd: + { + auto gatherNdQueueDescriptor = PolymorphicDowncast<const GatherNdQueueDescriptor*>(&descriptor); + return std::make_unique<RefGatherNdWorkload>(*gatherNdQueueDescriptor, info); + } case LayerType::Input: { auto inputQueueDescriptor = PolymorphicDowncast<const InputQueueDescriptor*>(&descriptor); diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 33e161c6d8..d9a5a1d32c 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -73,6 +73,7 @@ BACKEND_SOURCES := \ workloads/RefFillWorkload.cpp \ workloads/RefFloorWorkload.cpp \ workloads/RefFullyConnectedWorkload.cpp \ + workloads/RefGatherNdWorkload.cpp \ workloads/RefGatherWorkload.cpp \ workloads/RefInstanceNormalizationWorkload.cpp \ workloads/RefL2NormalizationWorkload.cpp \ diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index e1c2e2f2a7..2ed5ad812c 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -19,6 +19,7 @@ #include <backendsCommon/test/FillEndToEndTestImpl.hpp> #include <backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp> #include <backendsCommon/test/GatherEndToEndTestImpl.hpp> +#include <backendsCommon/test/GatherNdEndToEndTestImpl.hpp> #include <backendsCommon/test/InstanceNormalizationEndToEndTestImpl.hpp> #include <backendsCommon/test/LogSoftmaxEndToEndTestImpl.hpp> #include <backendsCommon/test/PreluEndToEndTestImpl.hpp> @@ -720,6 +721,36 @@ TEST_CASE("RefGatherMultiDimInt16Test") GatherMultiDimEndToEnd<armnn::DataType::QSymmS16>(defaultBackends); } +TEST_CASE("RefGatherNdFloatTest") +{ + GatherNdEndToEnd<armnn::DataType::Float32>(defaultBackends); +} + +TEST_CASE("RefGatherNdUint8Test") +{ + GatherNdEndToEnd<armnn::DataType::QAsymmU8>(defaultBackends); +} + +TEST_CASE("RefGatherNdInt16Test") +{ + GatherNdEndToEnd<armnn::DataType::QSymmS16>(defaultBackends); +} + +TEST_CASE("RefGatherNdMultiDimFloatTest") +{ + GatherNdMultiDimEndToEnd<armnn::DataType::Float32>(defaultBackends); +} + +TEST_CASE("RefGatherNdMultiDimUint8Test") +{ + GatherNdMultiDimEndToEnd<armnn::DataType::QAsymmU8>(defaultBackends); +} + +TEST_CASE("RefGatherNdMultiDimInt16Test") +{ + GatherNdMultiDimEndToEnd<armnn::DataType::QSymmS16>(defaultBackends); +} + // DepthToSpace TEST_CASE("DephtToSpaceEndToEndNchwFloat32") { diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 9dca621e13..496b11db91 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -2155,6 +2155,18 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesUint8, GatherMu ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesInt16, GatherMultiDimParamsMultiDimIndicesInt16Test) ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesInt32, GatherMultiDimParamsMultiDimIndicesInt32Test) + +// GatherNd +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd2dFloat32, SimpleGatherNd2dTest<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd3dFloat32, SimpleGatherNd3dTest<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd4dFloat32, SimpleGatherNd4dTest<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd2dInt8, SimpleGatherNd2dTest<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd3dInt8, SimpleGatherNd3dTest<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd4dInt8, SimpleGatherNd4dTest<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd2dInt32, SimpleGatherNd2dTest<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd3dInt32, SimpleGatherNd3dTest<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd4dInt32, SimpleGatherNd4dTest<DataType::Signed32>) + // Abs ARMNN_AUTO_TEST_CASE_WITH_THF(Abs2d, Abs2dTest<DataType::Float32>) ARMNN_AUTO_TEST_CASE_WITH_THF(Abs3d, Abs3dTest<DataType::Float32>) diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index c18342fb73..b1f6d8b250 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -118,6 +118,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefFloorWorkload.hpp RefFullyConnectedWorkload.cpp RefFullyConnectedWorkload.hpp + RefGatherNdWorkload.cpp + RefGatherNdWorkload.hpp RefGatherWorkload.cpp RefGatherWorkload.hpp RefInstanceNormalizationWorkload.cpp diff --git a/src/backends/reference/workloads/RefGatherNdWorkload.cpp b/src/backends/reference/workloads/RefGatherNdWorkload.cpp new file mode 100644 index 0000000000..4c6b559943 --- /dev/null +++ b/src/backends/reference/workloads/RefGatherNdWorkload.cpp @@ -0,0 +1,91 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefGatherNdWorkload.hpp" + +#include "Gather.hpp" +#include "Profiling.hpp" +#include "RefWorkloadUtils.hpp" +#include "backendsCommon/WorkloadUtils.hpp" + +namespace armnn +{ + +void RefGatherNdWorkload::Execute() const +{ + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefGatherNdWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) +{ + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); +} + +void RefGatherNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefGatherNdWorkload_Execute"); + + const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); + + std::unique_ptr<Decoder<float>> params_decoderPtr = MakeDecoder<float>(inputInfo0, inputs[0]->Map()); + + const int32_t* indicesDataPtr = reinterpret_cast<int32_t*>(inputs[1]->Map()); + std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.GetNumElements()); + + std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map()); + + std::map<std::string, unsigned int> keyIndices = CalculateGatherNdKeyIndices(inputInfo0, inputInfo1); + + /// Calculate flattened indices: flattenedIndices = indices * flattenedCoefficients + // Calculate the flattened coefficients to use in the multiplication + // to calculate the flattened indices needed by gather + TensorShape paramsShape = inputInfo0.GetShape(); + std::vector<unsigned int> flattenedCoeff(keyIndices["ND"], 1); + for (unsigned int i = 1; i < keyIndices["ND"]; ++i) + { + flattenedCoeff[i-1] = paramsShape[i]; + } + for (unsigned int i = keyIndices["ND"]-1; i > 0; --i) + { + flattenedCoeff[i-1] *= flattenedCoeff[i]; + } + + // Prepare the vector to store the output of the matrix multiplication, + // which will represent the flattened indices needed by gather + armnn::TensorInfo flattenedIndices_Info = inputInfo1; + flattenedIndices_Info.SetShape({ keyIndices["W"] }); + std::vector<int32_t> flattenedIndices(flattenedIndices_Info.GetNumElements(), 0); + + // Multiplication to calculate the flattened indices, which are the indices needed by gather. + for (unsigned int i = 0; i < keyIndices["W"]; ++i) + { + for (unsigned int j = 0; j < keyIndices["ND"]; ++j) + { + flattenedIndices[i] += indices[i * keyIndices["ND"] + j] * static_cast<int32_t>(flattenedCoeff[j]); + } + } + + /// Call Gather with adequate shapes + // Reshape params into {K, C} + armnn::TensorInfo params_K_C_Info = inputInfo0; + params_K_C_Info.SetShape({ keyIndices["K"], keyIndices["C"] }); + + // Reshape indices into {N, W} + armnn::TensorInfo indices_N_W_Info = inputInfo1; + indices_N_W_Info.SetShape({ keyIndices["N"], keyIndices["W"] }); + + // Reshape output to have the shape given by gather {N, W, C} + // (the original outputInfo has the shape given by gatherNd) + armnn::TensorInfo outputGather_Info = outputInfo; + outputGather_Info.SetShape({ keyIndices["N"], keyIndices["W"], keyIndices["C"] }); + + // output_gather = gather(params_K_C, indices_N_W) + Gather(params_K_C_Info, indices_N_W_Info, outputGather_Info, + *params_decoderPtr, flattenedIndices.data(), *output_encoderPtr, 0); +} + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefGatherNdWorkload.hpp b/src/backends/reference/workloads/RefGatherNdWorkload.hpp new file mode 100644 index 0000000000..a0d91586cc --- /dev/null +++ b/src/backends/reference/workloads/RefGatherNdWorkload.hpp @@ -0,0 +1,24 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "RefBaseWorkload.hpp" + +namespace armnn +{ + +class RefGatherNdWorkload : public RefBaseWorkload<GatherNdQueueDescriptor> +{ +public: + using RefBaseWorkload<GatherNdQueueDescriptor>::RefBaseWorkload; + void Execute() const override; + void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; +private: + void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; + +}; + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp index be3274f00a..8ad36e43b4 100644 --- a/src/backends/reference/workloads/RefGatherWorkload.cpp +++ b/src/backends/reference/workloads/RefGatherWorkload.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 700a1d6184..3e83304616 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -42,6 +42,7 @@ #include "RefFillWorkload.hpp" #include "RefFloorWorkload.hpp" #include "RefFullyConnectedWorkload.hpp" +#include "RefGatherNdWorkload.hpp" #include "RefGatherWorkload.hpp" #include "RefInstanceNormalizationWorkload.hpp" #include "RefL2NormalizationWorkload.hpp" |