diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 32 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 6 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 11 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.hpp | 3 | ||||
-rw-r--r-- | src/backends/reference/backend.mk | 1 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 14 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefTransposeWorkload.cpp | 35 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefTransposeWorkload.hpp | 35 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 1 |
10 files changed, 139 insertions, 1 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 8f1f170c5c..25334c3b52 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1388,9 +1388,10 @@ bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input, bool supported = true; // Define supported output and inputs types. - std::array<DataType,3> supportedTypes = + std::array<DataType, 4> supportedTypes = { DataType::Float32, + DataType::Float16, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1912,4 +1913,33 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input, + const TensorInfo& output, + const TransposeDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + ignore_unused(descriptor); + bool supported = true; + + // Define supported output and inputs types. + std::array<DataType, 4> supportedTypes = + { + DataType::Float32, + DataType::Float16, + DataType::QAsymmU8, + DataType::QSymmS16 + }; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference transpose: input is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference transpose: output is not a supported type."); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference transpose: input and output types are mismatched."); + + return supported; +} + } // namespace armnn diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 1551a55694..27f3f81489 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -318,6 +318,12 @@ public: const TensorInfo& weights, const Optional<TensorInfo>& biases, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + + bool IsTransposeSupported(const TensorInfo& input, + const TensorInfo& output, + const TransposeDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + }; } // namespace armnn diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 02dbbabf9f..2a415bfbf0 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -561,6 +561,17 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSubtraction(const Subtracti return std::make_unique<RefSubtractionWorkload>(descriptor, info); } +std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + if (IsQSymmS16(info)) + { + return std::make_unique<RefTransposeQSymm16Workload>(descriptor, info); + } + return MakeWorkloadHelper<RefTransposeFloat16Workload, RefTransposeFloat32Workload, RefTransposeQAsymm8Workload, + NullWorkload, NullWorkload, NullWorkload>(descriptor, info); +} + std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTransposeConvolution2d( const TransposeConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp index b5b9b0faf0..030ce6f03d 100644 --- a/src/backends/reference/RefWorkloadFactory.hpp +++ b/src/backends/reference/RefWorkloadFactory.hpp @@ -236,6 +236,9 @@ public: std::unique_ptr<IWorkload> CreateSubtraction(const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateTranspose(const TransposeQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 1987bd59fa..010d54871a 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -85,6 +85,7 @@ BACKEND_SOURCES := \ workloads/RefStridedSliceWorkload.cpp \ workloads/RefSplitterWorkload.cpp \ workloads/RefTransposeConvolution2dWorkload.cpp \ + workloads/RefTransposeWorkload.cpp \ workloads/Resize.cpp \ workloads/Slice.cpp \ workloads/SpaceToBatchNd.cpp \ diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index d5c67ef6c7..ed2b995bd5 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -1460,6 +1460,20 @@ ARMNN_AUTO_TEST_CASE(Slice3dInt16, Slice3dInt16Test) ARMNN_AUTO_TEST_CASE(Slice2dInt16, Slice2dInt16Test) ARMNN_AUTO_TEST_CASE(Slice1dInt16, Slice1dInt16Test) +// Transpose +ARMNN_AUTO_TEST_CASE(SimpleTransposeFloat32, SimpleTransposeTest<DataType::Float32>) +ARMNN_AUTO_TEST_CASE(TransposeFloat32ValueSet1Test, TransposeValueSet1Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE(TransposeFloat32ValueSet2Test, TransposeValueSet2Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE(TransposeFloat32ValueSet3Test, TransposeValueSet3Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE(SimpleTransposeQASymm8, SimpleTransposeTest<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE(TransposeQASymm8ValueSet1Test, TransposeValueSet1Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE(TransposeQASymm8ValueSet2Test, TransposeValueSet2Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE(TransposeQASymm8ValueSet3Test, TransposeValueSet3Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE(SimpleTransposeQSymm16, SimpleTransposeTest<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE(TransposeQSymm16ValueSet1Test, TransposeValueSet1Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE(TransposeQSymm16ValueSet2Test, TransposeValueSet2Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE(TransposeQSymm16ValueSet3Test, TransposeValueSet3Test<DataType::QSymmS16>) + // TransposeConvolution2d ARMNN_AUTO_TEST_CASE(SimpleTransposeConvolution2dFloatNchw, SimpleTransposeConvolution2dTest<DataType::Float32, DataType::Float32>, diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 6795204d59..b2d8938745 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -141,6 +141,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefStridedSliceWorkload.hpp RefTransposeConvolution2dWorkload.cpp RefTransposeConvolution2dWorkload.hpp + RefTransposeWorkload.cpp + RefTransposeWorkload.hpp RefWorkloads.hpp RefWorkloadUtils.hpp Resize.cpp diff --git a/src/backends/reference/workloads/RefTransposeWorkload.cpp b/src/backends/reference/workloads/RefTransposeWorkload.cpp new file mode 100644 index 0000000000..6bdfb2111d --- /dev/null +++ b/src/backends/reference/workloads/RefTransposeWorkload.cpp @@ -0,0 +1,35 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefTransposeWorkload.hpp" +#include "RefWorkloadUtils.hpp" + +#include <armnnUtils/Transpose.hpp> + +#include <ResolveType.hpp> + +namespace armnn +{ + +template <armnn::DataType DataType> +void RefTransposeWorkload<DataType>::Execute() const +{ + using T = ResolveType<DataType>; + + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute"); + + const ITensorHandle* src = m_Data.m_Inputs[0]; + ITensorHandle* dst = m_Data.m_Outputs[0]; + const PermutationVector& mappings = m_Data.m_Parameters.m_DimMappings; + + armnnUtils::Transpose(GetTensorInfo(src).GetShape(), mappings, src->Map(), dst->Map(), sizeof(T)); +} + +template class RefTransposeWorkload<DataType::Float16>; +template class RefTransposeWorkload<DataType::Float32>; +template class RefTransposeWorkload<DataType::QAsymmU8>; +template class RefTransposeWorkload<DataType::QSymmS16>; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefTransposeWorkload.hpp b/src/backends/reference/workloads/RefTransposeWorkload.hpp new file mode 100644 index 0000000000..4b1c3d303b --- /dev/null +++ b/src/backends/reference/workloads/RefTransposeWorkload.hpp @@ -0,0 +1,35 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/Workload.hpp> + +#include <armnn/TypesUtils.hpp> + +namespace armnn +{ + +template <armnn::DataType DataType> +class RefTransposeWorkload : public TypedWorkload<TransposeQueueDescriptor, DataType> +{ +public: + static const std::string& GetName() + { + static const std::string name = std::string("RefTranspose") + GetDataTypeName(DataType) + "Workload"; + return name; + } + + using TypedWorkload<TransposeQueueDescriptor, DataType>::m_Data; + using TypedWorkload<TransposeQueueDescriptor, DataType>::TypedWorkload; + void Execute() const override; +}; + +using RefTransposeFloat16Workload = RefTransposeWorkload<DataType::Float16>; +using RefTransposeFloat32Workload = RefTransposeWorkload<DataType::Float32>; +using RefTransposeQAsymm8Workload = RefTransposeWorkload<DataType::QAsymmU8>; +using RefTransposeQSymm16Workload = RefTransposeWorkload<DataType::QSymmS16>; + +} //namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 7034b67aa5..a0558ff06e 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -58,6 +58,7 @@ #include "RefStridedSliceWorkload.hpp" #include "RefSpaceToDepthWorkload.hpp" #include "RefTransposeConvolution2dWorkload.hpp" +#include "RefTransposeWorkload.hpp" #include "RefWorkloadUtils.hpp" #include "Resize.hpp" #include "Softmax.hpp" |