aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp30
-rw-r--r--src/backends/reference/RefLayerSupport.hpp4
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp6
-rw-r--r--src/backends/reference/RefWorkloadFactory.hpp3
-rw-r--r--src/backends/reference/backend.mk1
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp14
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt2
-rw-r--r--src/backends/reference/workloads/RefCastWorkload.cpp40
-rw-r--r--src/backends/reference/workloads/RefCastWorkload.hpp24
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp1
10 files changed, 125 insertions, 0 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 2e0a8f2faa..e278e45f01 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -303,6 +303,36 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
return supported;
}
+bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ std::array<DataType, 9> supportedInputTypes =
+ {
+ DataType::BFloat16,
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QSymmS8,
+ DataType::QAsymmS8,
+ DataType::QAsymmU8,
+ DataType::QSymmS16,
+ DataType::Signed32
+ };
+
+ bool supported = true;
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
+ "Reference cast: input is not a supported type");
+
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
+ "Reference cast: output is not a supported type");
+
+ supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
+ "Reference cast: input and output shapes have different number of total elements");
+
+ return supported;
+}
+
bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index b75b778f7a..7a95bb06ac 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -46,6 +46,10 @@ public:
const BatchToSpaceNdDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsCastSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsComparisonSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index fde6c863c3..c1e3d58bd2 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -178,6 +178,12 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchT
return std::make_unique<RefBatchToSpaceNdWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateCast(const CastQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::make_unique<RefCastWorkload>(descriptor, info);
+}
+
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index c22d87fa43..734c5e49c3 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -85,6 +85,9 @@ public:
std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateCast(const CastQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 96765097e4..bf18284143 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -47,6 +47,7 @@ BACKEND_SOURCES := \
workloads/RefArgMinMaxWorkload.cpp \
workloads/RefBatchNormalizationWorkload.cpp \
workloads/RefBatchToSpaceNdWorkload.cpp \
+ workloads/RefCastWorkload.cpp \
workloads/RefComparisonWorkload.cpp \
workloads/RefConcatWorkload.cpp \
workloads/RefConstantWorkload.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 7371692d0e..228df0946f 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1471,6 +1471,20 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(QLstm, QLstmTest)
ARMNN_AUTO_TEST_CASE_WITH_THF(QLstm1, QLstmTest1)
ARMNN_AUTO_TEST_CASE_WITH_THF(QLstm2, QLstmTest2)
+// Cast
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastInt32ToFloat, CastInt32ToFloat2dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastInt16ToFloat, CastInt16ToFloat2dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastInt8ToFloat, CastInt8ToFloat2dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastInt8AsymmToFloat, CastInt8AsymmToFloat2dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastUIntToFloat, CastUInt8ToFloat2dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastInt8ToUInt, CastInt8ToUInt82dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastInt8AsymmToUInt, CastInt8AsymmToUInt82dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastFloat16ToFloat32, CastFloat16ToFloat322dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastBFloat16ToFloat32, CastBFloat16ToFloat322dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastFloatToFloat16, CastFloat32ToFloat162dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastFloatToIn8, CastFloat32ToInt82dTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(CastFloatToUInt8, CastFloat32ToUInt82dTest)
+
// Convert from BFloat16 to Float32
ARMNN_AUTO_TEST_CASE_WITH_THF(ConvertBf16ToFp32, ConvertBf16ToFp32Test)
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 1f4298be5d..dadedf995a 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -63,6 +63,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefBatchNormalizationWorkload.hpp
RefBatchToSpaceNdWorkload.cpp
RefBatchToSpaceNdWorkload.hpp
+ RefCastWorkload.cpp
+ RefCastWorkload.hpp
RefComparisonWorkload.cpp
RefComparisonWorkload.hpp
RefConcatWorkload.cpp
diff --git a/src/backends/reference/workloads/RefCastWorkload.cpp b/src/backends/reference/workloads/RefCastWorkload.cpp
new file mode 100644
index 0000000000..7080415e5d
--- /dev/null
+++ b/src/backends/reference/workloads/RefCastWorkload.cpp
@@ -0,0 +1,40 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefCastWorkload.hpp"
+#include "RefWorkloadUtils.hpp"
+#include <armnnUtils/FloatingPointConverter.hpp>
+#include <ResolveType.hpp>
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+
+namespace
+{
+ void Cast(armnn::Decoder<float>& in, armnn::Encoder<float>& out, const uint32_t numElements )
+ {
+ for (unsigned int i = 0; i < numElements; i++)
+ {
+ out.Set(in.Get());
+ ++in;
+ ++out;
+ }
+ }
+}
+
+namespace armnn
+{
+
+ void RefCastWorkload::Execute() const
+ {
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefCastWorkload_Execute");
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ Cast(*MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map()),
+ *MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()),
+ inputInfo.GetNumElements());
+ }
+
+} //namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefCastWorkload.hpp b/src/backends/reference/workloads/RefCastWorkload.hpp
new file mode 100644
index 0000000000..6742ef08ca
--- /dev/null
+++ b/src/backends/reference/workloads/RefCastWorkload.hpp
@@ -0,0 +1,24 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+#include "RefWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+
+class RefCastWorkload : public BaseWorkload<CastQueueDescriptor>
+{
+public:
+ using BaseWorkload<CastQueueDescriptor>::BaseWorkload;
+ void Execute() const override;
+};
+
+} //namespace armnn
+
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 989644f633..d3995f2b82 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -18,6 +18,7 @@
#include "RefArgMinMaxWorkload.hpp"
#include "RefBatchNormalizationWorkload.hpp"
#include "RefBatchToSpaceNdWorkload.hpp"
+#include "RefCastWorkload.hpp"
#include "RefComparisonWorkload.hpp"
#include "RefConvolution2dWorkload.hpp"
#include "RefConstantWorkload.hpp"