diff options
Diffstat (limited to 'src/backends/reference/workloads')
4 files changed, 73 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 6795204d59..b2d8938745 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -141,6 +141,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefStridedSliceWorkload.hpp RefTransposeConvolution2dWorkload.cpp RefTransposeConvolution2dWorkload.hpp + RefTransposeWorkload.cpp + RefTransposeWorkload.hpp RefWorkloads.hpp RefWorkloadUtils.hpp Resize.cpp diff --git a/src/backends/reference/workloads/RefTransposeWorkload.cpp b/src/backends/reference/workloads/RefTransposeWorkload.cpp new file mode 100644 index 0000000000..6bdfb2111d --- /dev/null +++ b/src/backends/reference/workloads/RefTransposeWorkload.cpp @@ -0,0 +1,35 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefTransposeWorkload.hpp" +#include "RefWorkloadUtils.hpp" + +#include <armnnUtils/Transpose.hpp> + +#include <ResolveType.hpp> + +namespace armnn +{ + +template <armnn::DataType DataType> +void RefTransposeWorkload<DataType>::Execute() const +{ + using T = ResolveType<DataType>; + + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute"); + + const ITensorHandle* src = m_Data.m_Inputs[0]; + ITensorHandle* dst = m_Data.m_Outputs[0]; + const PermutationVector& mappings = m_Data.m_Parameters.m_DimMappings; + + armnnUtils::Transpose(GetTensorInfo(src).GetShape(), mappings, src->Map(), dst->Map(), sizeof(T)); +} + +template class RefTransposeWorkload<DataType::Float16>; +template class RefTransposeWorkload<DataType::Float32>; +template class RefTransposeWorkload<DataType::QAsymmU8>; +template class RefTransposeWorkload<DataType::QSymmS16>; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefTransposeWorkload.hpp b/src/backends/reference/workloads/RefTransposeWorkload.hpp new file mode 100644 index 0000000000..4b1c3d303b --- /dev/null +++ b/src/backends/reference/workloads/RefTransposeWorkload.hpp @@ -0,0 +1,35 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/Workload.hpp> + +#include <armnn/TypesUtils.hpp> + +namespace armnn +{ + +template <armnn::DataType DataType> +class RefTransposeWorkload : public TypedWorkload<TransposeQueueDescriptor, DataType> +{ +public: + static const std::string& GetName() + { + static const std::string name = std::string("RefTranspose") + GetDataTypeName(DataType) + "Workload"; + return name; + } + + using TypedWorkload<TransposeQueueDescriptor, DataType>::m_Data; + using TypedWorkload<TransposeQueueDescriptor, DataType>::TypedWorkload; + void Execute() const override; +}; + +using RefTransposeFloat16Workload = RefTransposeWorkload<DataType::Float16>; +using RefTransposeFloat32Workload = RefTransposeWorkload<DataType::Float32>; +using RefTransposeQAsymm8Workload = RefTransposeWorkload<DataType::QAsymmU8>; +using RefTransposeQSymm16Workload = RefTransposeWorkload<DataType::QSymmS16>; + +} //namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 7034b67aa5..a0558ff06e 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -58,6 +58,7 @@ #include "RefStridedSliceWorkload.hpp" #include "RefSpaceToDepthWorkload.hpp" #include "RefTransposeConvolution2dWorkload.hpp" +#include "RefTransposeWorkload.hpp" #include "RefWorkloadUtils.hpp" #include "Resize.hpp" #include "Softmax.hpp" |