aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt2
-rw-r--r--src/backends/reference/workloads/RefTransposeWorkload.cpp35
-rw-r--r--src/backends/reference/workloads/RefTransposeWorkload.hpp35
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp1
4 files changed, 73 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 6795204d59..b2d8938745 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -141,6 +141,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefStridedSliceWorkload.hpp
RefTransposeConvolution2dWorkload.cpp
RefTransposeConvolution2dWorkload.hpp
+ RefTransposeWorkload.cpp
+ RefTransposeWorkload.hpp
RefWorkloads.hpp
RefWorkloadUtils.hpp
Resize.cpp
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.cpp b/src/backends/reference/workloads/RefTransposeWorkload.cpp
new file mode 100644
index 0000000000..6bdfb2111d
--- /dev/null
+++ b/src/backends/reference/workloads/RefTransposeWorkload.cpp
@@ -0,0 +1,35 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefTransposeWorkload.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include <armnnUtils/Transpose.hpp>
+
+#include <ResolveType.hpp>
+
+namespace armnn
+{
+
+template <armnn::DataType DataType>
+void RefTransposeWorkload<DataType>::Execute() const
+{
+ using T = ResolveType<DataType>;
+
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute");
+
+ const ITensorHandle* src = m_Data.m_Inputs[0];
+ ITensorHandle* dst = m_Data.m_Outputs[0];
+ const PermutationVector& mappings = m_Data.m_Parameters.m_DimMappings;
+
+ armnnUtils::Transpose(GetTensorInfo(src).GetShape(), mappings, src->Map(), dst->Map(), sizeof(T));
+}
+
+template class RefTransposeWorkload<DataType::Float16>;
+template class RefTransposeWorkload<DataType::Float32>;
+template class RefTransposeWorkload<DataType::QAsymmU8>;
+template class RefTransposeWorkload<DataType::QSymmS16>;
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.hpp b/src/backends/reference/workloads/RefTransposeWorkload.hpp
new file mode 100644
index 0000000000..4b1c3d303b
--- /dev/null
+++ b/src/backends/reference/workloads/RefTransposeWorkload.hpp
@@ -0,0 +1,35 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+namespace armnn
+{
+
+template <armnn::DataType DataType>
+class RefTransposeWorkload : public TypedWorkload<TransposeQueueDescriptor, DataType>
+{
+public:
+ static const std::string& GetName()
+ {
+ static const std::string name = std::string("RefTranspose") + GetDataTypeName(DataType) + "Workload";
+ return name;
+ }
+
+ using TypedWorkload<TransposeQueueDescriptor, DataType>::m_Data;
+ using TypedWorkload<TransposeQueueDescriptor, DataType>::TypedWorkload;
+ void Execute() const override;
+};
+
+using RefTransposeFloat16Workload = RefTransposeWorkload<DataType::Float16>;
+using RefTransposeFloat32Workload = RefTransposeWorkload<DataType::Float32>;
+using RefTransposeQAsymm8Workload = RefTransposeWorkload<DataType::QAsymmU8>;
+using RefTransposeQSymm16Workload = RefTransposeWorkload<DataType::QSymmS16>;
+
+} //namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 7034b67aa5..a0558ff06e 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -58,6 +58,7 @@
#include "RefStridedSliceWorkload.hpp"
#include "RefSpaceToDepthWorkload.hpp"
#include "RefTransposeConvolution2dWorkload.hpp"
+#include "RefTransposeWorkload.hpp"
#include "RefWorkloadUtils.hpp"
#include "Resize.hpp"
#include "Softmax.hpp"