aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2020-02-28 18:11:58 +0000
committermike.kelly <mike.kelly@arm.com>2020-03-02 16:44:09 +0000
commitc9ea45adefdde2890e9aa191a5b31563a3dd35ea (patch)
tree2ea65c972d24cc2d823ea39eb105d4062db54934 /src/backends/reference
parent510f6183d289b176702a18f020449c68be6f1075 (diff)
downloadarmnn-c9ea45adefdde2890e9aa191a5b31563a3dd35ea.tar.gz
IVGCVSW-4375 Add support for Transpose
* Added TransposeLayer * Added CL, Neon and Ref Workloads * Added Transpose utilities * Added Serializer and Deserializer support * Added Quantizer support Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I04c755ba7cb5b1edf72b3c9f3c0314878032e3c7
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp32
-rw-r--r--src/backends/reference/RefLayerSupport.hpp6
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp11
-rw-r--r--src/backends/reference/RefWorkloadFactory.hpp3
-rw-r--r--src/backends/reference/backend.mk1
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp14
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt2
-rw-r--r--src/backends/reference/workloads/RefTransposeWorkload.cpp35
-rw-r--r--src/backends/reference/workloads/RefTransposeWorkload.hpp35
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp1
10 files changed, 139 insertions, 1 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 8f1f170c5c..25334c3b52 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1388,9 +1388,10 @@ bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
bool supported = true;
// Define supported output and inputs types.
- std::array<DataType,3> supportedTypes =
+ std::array<DataType, 4> supportedTypes =
{
DataType::Float32,
+ DataType::Float16,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -1912,4 +1913,33 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
return supported;
}
+bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const TransposeDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ ignore_unused(descriptor);
+ bool supported = true;
+
+ // Define supported output and inputs types.
+ std::array<DataType, 4> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference transpose: input is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference transpose: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference transpose: input and output types are mismatched.");
+
+ return supported;
+}
+
} // namespace armnn
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 1551a55694..27f3f81489 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -318,6 +318,12 @@ public:
const TensorInfo& weights,
const Optional<TensorInfo>& biases,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
+ bool IsTransposeSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const TransposeDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
};
} // namespace armnn
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 02dbbabf9f..2a415bfbf0 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -561,6 +561,17 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSubtraction(const Subtracti
return std::make_unique<RefSubtractionWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ if (IsQSymmS16(info))
+ {
+ return std::make_unique<RefTransposeQSymm16Workload>(descriptor, info);
+ }
+ return MakeWorkloadHelper<RefTransposeFloat16Workload, RefTransposeFloat32Workload, RefTransposeQAsymm8Workload,
+ NullWorkload, NullWorkload, NullWorkload>(descriptor, info);
+}
+
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTransposeConvolution2d(
const TransposeConvolution2dQueueDescriptor& descriptor,
const WorkloadInfo& info) const
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index b5b9b0faf0..030ce6f03d 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -236,6 +236,9 @@ public:
std::unique_ptr<IWorkload> CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateTranspose(const TransposeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 1987bd59fa..010d54871a 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -85,6 +85,7 @@ BACKEND_SOURCES := \
workloads/RefStridedSliceWorkload.cpp \
workloads/RefSplitterWorkload.cpp \
workloads/RefTransposeConvolution2dWorkload.cpp \
+ workloads/RefTransposeWorkload.cpp \
workloads/Resize.cpp \
workloads/Slice.cpp \
workloads/SpaceToBatchNd.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index d5c67ef6c7..ed2b995bd5 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1460,6 +1460,20 @@ ARMNN_AUTO_TEST_CASE(Slice3dInt16, Slice3dInt16Test)
ARMNN_AUTO_TEST_CASE(Slice2dInt16, Slice2dInt16Test)
ARMNN_AUTO_TEST_CASE(Slice1dInt16, Slice1dInt16Test)
+// Transpose
+ARMNN_AUTO_TEST_CASE(SimpleTransposeFloat32, SimpleTransposeTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(TransposeFloat32ValueSet1Test, TransposeValueSet1Test<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(TransposeFloat32ValueSet2Test, TransposeValueSet2Test<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(TransposeFloat32ValueSet3Test, TransposeValueSet3Test<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(SimpleTransposeQASymm8, SimpleTransposeTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE(TransposeQASymm8ValueSet1Test, TransposeValueSet1Test<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE(TransposeQASymm8ValueSet2Test, TransposeValueSet2Test<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE(TransposeQASymm8ValueSet3Test, TransposeValueSet3Test<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE(SimpleTransposeQSymm16, SimpleTransposeTest<DataType::QSymmS16>)
+ARMNN_AUTO_TEST_CASE(TransposeQSymm16ValueSet1Test, TransposeValueSet1Test<DataType::QSymmS16>)
+ARMNN_AUTO_TEST_CASE(TransposeQSymm16ValueSet2Test, TransposeValueSet2Test<DataType::QSymmS16>)
+ARMNN_AUTO_TEST_CASE(TransposeQSymm16ValueSet3Test, TransposeValueSet3Test<DataType::QSymmS16>)
+
// TransposeConvolution2d
ARMNN_AUTO_TEST_CASE(SimpleTransposeConvolution2dFloatNchw,
SimpleTransposeConvolution2dTest<DataType::Float32, DataType::Float32>,
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 6795204d59..b2d8938745 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -141,6 +141,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefStridedSliceWorkload.hpp
RefTransposeConvolution2dWorkload.cpp
RefTransposeConvolution2dWorkload.hpp
+ RefTransposeWorkload.cpp
+ RefTransposeWorkload.hpp
RefWorkloads.hpp
RefWorkloadUtils.hpp
Resize.cpp
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.cpp b/src/backends/reference/workloads/RefTransposeWorkload.cpp
new file mode 100644
index 0000000000..6bdfb2111d
--- /dev/null
+++ b/src/backends/reference/workloads/RefTransposeWorkload.cpp
@@ -0,0 +1,35 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefTransposeWorkload.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include <armnnUtils/Transpose.hpp>
+
+#include <ResolveType.hpp>
+
+namespace armnn
+{
+
+template <armnn::DataType DataType>
+void RefTransposeWorkload<DataType>::Execute() const
+{
+ using T = ResolveType<DataType>;
+
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute");
+
+ const ITensorHandle* src = m_Data.m_Inputs[0];
+ ITensorHandle* dst = m_Data.m_Outputs[0];
+ const PermutationVector& mappings = m_Data.m_Parameters.m_DimMappings;
+
+ armnnUtils::Transpose(GetTensorInfo(src).GetShape(), mappings, src->Map(), dst->Map(), sizeof(T));
+}
+
+template class RefTransposeWorkload<DataType::Float16>;
+template class RefTransposeWorkload<DataType::Float32>;
+template class RefTransposeWorkload<DataType::QAsymmU8>;
+template class RefTransposeWorkload<DataType::QSymmS16>;
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.hpp b/src/backends/reference/workloads/RefTransposeWorkload.hpp
new file mode 100644
index 0000000000..4b1c3d303b
--- /dev/null
+++ b/src/backends/reference/workloads/RefTransposeWorkload.hpp
@@ -0,0 +1,35 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+namespace armnn
+{
+
+template <armnn::DataType DataType>
+class RefTransposeWorkload : public TypedWorkload<TransposeQueueDescriptor, DataType>
+{
+public:
+ static const std::string& GetName()
+ {
+ static const std::string name = std::string("RefTranspose") + GetDataTypeName(DataType) + "Workload";
+ return name;
+ }
+
+ using TypedWorkload<TransposeQueueDescriptor, DataType>::m_Data;
+ using TypedWorkload<TransposeQueueDescriptor, DataType>::TypedWorkload;
+ void Execute() const override;
+};
+
+using RefTransposeFloat16Workload = RefTransposeWorkload<DataType::Float16>;
+using RefTransposeFloat32Workload = RefTransposeWorkload<DataType::Float32>;
+using RefTransposeQAsymm8Workload = RefTransposeWorkload<DataType::QAsymmU8>;
+using RefTransposeQSymm16Workload = RefTransposeWorkload<DataType::QSymmS16>;
+
+} //namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 7034b67aa5..a0558ff06e 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -58,6 +58,7 @@
#include "RefStridedSliceWorkload.hpp"
#include "RefSpaceToDepthWorkload.hpp"
#include "RefTransposeConvolution2dWorkload.hpp"
+#include "RefTransposeWorkload.hpp"
#include "RefWorkloadUtils.hpp"
#include "Resize.hpp"
#include "Softmax.hpp"