aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2019-01-23 15:23:11 +0000
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-01-23 17:13:15 +0000
commitdb2b160bf9e7759d0157dfa57ee940290f5170e3 (patch)
tree536fa36ebc9eb8442b96b486a10cadab28d32647
parentc625f000198218fc8d03130ee5658f73b94b2683 (diff)
downloadarmnn-db2b160bf9e7759d0157dfa57ee940290f5170e3.tar.gz
IVGCVSW-2511 Add end to end Gather layer test
* Add end to end test for Gather operator
* Add Support for int32 to Constant layer for Ref
* Add Int32Workload
* Add RefConstantWorkload as template for float, uint8, int32
* Remove unused RefBaseConstantWorkload
* Remove unused RefConstantFloat32Workload
* Remove unused RefConstantUint8Workload
* Add support check for int32 in LayerSupport functions

Change-Id: Ic970588a49ebe2aafb12be8adef52371feacaa7b
-rw-r--r--src/armnn/LayerSupportCommon.hpp15
-rw-r--r--src/backends/backendsCommon/MakeWorkloadHelper.hpp9
-rw-r--r--src/backends/backendsCommon/Workload.hpp3
-rw-r--r--src/backends/backendsCommon/test/CMakeLists.txt2
-rw-r--r--src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp124
-rw-r--r--src/backends/cl/ClLayerSupport.cpp4
-rw-r--r--src/backends/neon/NeonLayerSupport.cpp4
-rw-r--r--src/backends/reference/RefLayerSupport.cpp23
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp9
-rw-r--r--src/backends/reference/backend.mk4
-rw-r--r--src/backends/reference/test/RefEndToEndTests.cpp21
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt8
-rw-r--r--src/backends/reference/workloads/RefConstantFloat32Workload.cpp19
-rw-r--r--src/backends/reference/workloads/RefConstantFloat32Workload.hpp20
-rw-r--r--src/backends/reference/workloads/RefConstantUint8Workload.cpp19
-rw-r--r--src/backends/reference/workloads/RefConstantUint8Workload.hpp20
-rw-r--r--src/backends/reference/workloads/RefConstantWorkload.cpp (renamed from src/backends/reference/workloads/RefBaseConstantWorkload.cpp)13
-rw-r--r--src/backends/reference/workloads/RefConstantWorkload.hpp (renamed from src/backends/reference/workloads/RefBaseConstantWorkload.hpp)13
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp4
19 files changed, 217 insertions, 117 deletions
diff --git a/src/armnn/LayerSupportCommon.hpp b/src/armnn/LayerSupportCommon.hpp
index c309f8c6c7..109728cd81 100644
--- a/src/armnn/LayerSupportCommon.hpp
+++ b/src/armnn/LayerSupportCommon.hpp
@@ -12,12 +12,13 @@
namespace armnn
{
-template<typename Float16Func, typename Float32Func, typename Uint8Func, typename ... Params>
+template<typename Float16Func, typename Float32Func, typename Uint8Func, typename Int32Func, typename ... Params>
bool IsSupportedForDataTypeGeneric(Optional<std::string&> reasonIfUnsupported,
DataType dataType,
Float16Func float16FuncPtr,
Float32Func float32FuncPtr,
Uint8Func uint8FuncPtr,
+ Int32Func int32FuncPtr,
Params&&... params)
{
switch(dataType)
@@ -28,6 +29,8 @@ bool IsSupportedForDataTypeGeneric(Optional<std::string&> reasonIfUnsupported,
return float32FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
case DataType::QuantisedAsymm8:
return uint8FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
+ case DataType::Signed32:
+ return int32FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
default:
return false;
}
@@ -76,6 +79,16 @@ bool FalseFuncU8(Optional<std::string&> reasonIfUnsupported, Params&&... params)
}
template<typename ... Params>
+bool FalseFuncI32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
+{
+ if (reasonIfUnsupported)
+ {
+ reasonIfUnsupported.value() = "Layer is not supported with int32 data type";
+ }
+ return false;
+}
+
+template<typename ... Params>
bool FalseInputFuncF32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
{
if (reasonIfUnsupported)
diff --git a/src/backends/backendsCommon/MakeWorkloadHelper.hpp b/src/backends/backendsCommon/MakeWorkloadHelper.hpp
index 78a9669530..7784cc6d4d 100644
--- a/src/backends/backendsCommon/MakeWorkloadHelper.hpp
+++ b/src/backends/backendsCommon/MakeWorkloadHelper.hpp
@@ -37,8 +37,8 @@ struct MakeWorkloadForType<NullWorkload>
// Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
// Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
-template <typename Float16Workload, typename Float32Workload, typename Uint8Workload, typename QueueDescriptorType,
- typename... Args>
+template <typename Float16Workload, typename Float32Workload, typename Uint8Workload, typename Int32Workload,
+ typename QueueDescriptorType, typename... Args>
std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descriptor,
const WorkloadInfo& info,
Args&&... args)
@@ -58,6 +58,8 @@ std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descrip
return MakeWorkloadForType<Float32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::QuantisedAsymm8:
return MakeWorkloadForType<Uint8Workload>::Func(descriptor, info, std::forward<Args>(args)...);
+ case DataType::Signed32:
+ return MakeWorkloadForType<Int32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
default:
BOOST_ASSERT_MSG(false, "Unknown DataType.");
return nullptr;
@@ -73,10 +75,9 @@ std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descrip
const WorkloadInfo& info,
Args&&... args)
{
- return MakeWorkloadHelper<FloatWorkload, FloatWorkload, Uint8Workload>(descriptor, info,
+ return MakeWorkloadHelper<FloatWorkload, FloatWorkload, Uint8Workload, NullWorkload>(descriptor, info,
std::forward<Args>(args)...);
}
-
} //namespace
} //namespace armnn
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp
index 65392194a2..34d13635ba 100644
--- a/src/backends/backendsCommon/Workload.hpp
+++ b/src/backends/backendsCommon/Workload.hpp
@@ -162,6 +162,9 @@ template <typename QueueDescriptor>
using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
template <typename QueueDescriptor>
+using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
+
+template <typename QueueDescriptor>
using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
armnn::DataType::Float16,
armnn::DataType::Float32>;
diff --git a/src/backends/backendsCommon/test/CMakeLists.txt b/src/backends/backendsCommon/test/CMakeLists.txt
index 8107176210..80a9cfeaa9 100644
--- a/src/backends/backendsCommon/test/CMakeLists.txt
+++ b/src/backends/backendsCommon/test/CMakeLists.txt
@@ -16,6 +16,8 @@ list(APPEND armnnBackendsCommonUnitTests_sources
DebugTestImpl.hpp
EndToEndTestImpl.hpp
FullyConnectedTestImpl.hpp
+ GatherTestImpl.hpp
+ GatherEndToEndTestImpl.hpp
IsLayerSupportedTestImpl.hpp
JsonPrinterTestImpl.cpp
JsonPrinterTestImpl.hpp
diff --git a/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp
new file mode 100644
index 0000000000..d30da549df
--- /dev/null
+++ b/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp
@@ -0,0 +1,124 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/INetwork.hpp>
+#include <backendsCommon/test/CommonTestUtils.hpp>
+#include <TypeUtils.hpp>
+
+namespace{
+
+armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
+ const armnn::TensorInfo& indicesInfo,
+ const armnn::TensorInfo& outputInfo,
+ const std::vector<int32_t>& indicesData)
+{
+ armnn::INetworkPtr net(armnn::INetwork::Create());
+
+ armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
+ armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
+ armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer("gather");
+ armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
+ Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
+ Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
+ Connect(gatherLayer, outputLayer, outputInfo, 0, 0);
+
+ return net;
+}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+void GatherEndToEnd(const std::vector<BackendId>& backends)
+{
+ armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
+ armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
+ armnn::TensorInfo outputInfo({ 3 }, ArmnnType);
+
+ paramsInfo.SetQuantizationScale(1.0f);
+ paramsInfo.SetQuantizationOffset(0);
+ outputInfo.SetQuantizationScale(1.0f);
+ outputInfo.SetQuantizationOffset(0);
+
+ // Creates structures for input & output.
+ std::vector<T> paramsData{
+ 1, 2, 3, 4, 5, 6, 7, 8
+ };
+
+ std::vector<int32_t> indicesData{
+ 7, 6, 5
+ };
+
+ std::vector<T> expectedOutput{
+ 8, 7, 6
+ };
+
+ // Builds up the structure of the network
+ armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
+
+ BOOST_TEST_CHECKPOINT("create a network");
+
+ std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
+ std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
+
+ EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
+}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+void GatherMultiDimEndToEnd(const std::vector<BackendId>& backends)
+{
+ armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
+ armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
+ armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
+
+ paramsInfo.SetQuantizationScale(1.0f);
+ paramsInfo.SetQuantizationOffset(0);
+ outputInfo.SetQuantizationScale(1.0f);
+ outputInfo.SetQuantizationOffset(0);
+
+ // Creates structures for input & output.
+ std::vector<T> paramsData{
+ 1, 2, 3,
+ 4, 5, 6,
+
+ 7, 8, 9,
+ 10, 11, 12,
+
+ 13, 14, 15,
+ 16, 17, 18
+ };
+
+ std::vector<int32_t> indicesData{
+ 1, 2, 1,
+ 2, 1, 0
+ };
+
+ std::vector<T> expectedOutput{
+ 7, 8, 9,
+ 10, 11, 12,
+ 13, 14, 15,
+ 16, 17, 18,
+ 7, 8, 9,
+ 10, 11, 12,
+
+ 13, 14, 15,
+ 16, 17, 18,
+ 7, 8, 9,
+ 10, 11, 12,
+ 1, 2, 3,
+ 4, 5, 6
+ };
+
+ // Builds up the structure of the network
+ armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
+
+ BOOST_TEST_CHECKPOINT("create a network");
+
+ std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
+ std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
+
+ EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
+}
+
+} // anonymous namespace \ No newline at end of file
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index cb03e8b5ae..3e35f9d52d 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -121,6 +121,7 @@ bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
floatFuncPtr,
floatFuncPtr,
uint8FuncPtr,
+ &FalseFunc<>,
std::forward<Params>(params)...);
}
@@ -265,7 +266,8 @@ bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
input.GetDataType(),
&FalseFuncF16<>,
&TrueFunc<>,
- &FalseFuncU8<>);
+ &FalseFuncU8<>,
+ &FalseFuncI32<>);
}
bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index 76cdf140d2..2f83c8f82a 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -71,6 +71,7 @@ bool IsSupportedForDataTypeNeon(Optional<std::string&> reasonIfUnsupported,
floatFuncPtr,
floatFuncPtr,
uint8FuncPtr,
+ &FalseFunc<>,
std::forward<Params>(params)...);
}
@@ -212,7 +213,8 @@ bool NeonLayerSupport::IsFloorSupported(const TensorInfo& input,
input.GetDataType(),
&FalseFuncF16<>,
&TrueFunc<>,
- &FalseFuncU8<>);
+ &FalseFuncU8<>,
+ &FalseFuncI32<>);
}
bool NeonLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 25c2bafe2f..45f108c2f8 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -34,6 +34,7 @@ bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
&FalseFunc<Params...>,
floatFuncPtr,
uint8FuncPtr,
+ &FalseFunc<Params...>,
std::forward<Params>(params)...);
}
@@ -105,10 +106,12 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- output.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+ output.GetDataType(),
+ &FalseFunc<>,
+ &TrueFunc<>,
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
@@ -119,12 +122,14 @@ bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
input.GetDataType(),
&TrueFunc<>,
&FalseInputFuncF32<>,
- &FalseFuncU8<>) &&
+ &FalseFuncU8<>,
+ &FalseFuncI32<>) &&
IsSupportedForDataTypeGeneric(reasonIfUnsupported,
output.GetDataType(),
&FalseOutputFuncF16<>,
&TrueFunc<>,
- &FalseFuncU8<>));
+ &FalseFuncU8<>,
+ &FalseFuncI32<>));
}
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
@@ -135,12 +140,14 @@ bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
input.GetDataType(),
&FalseInputFuncF16<>,
&TrueFunc<>,
- &FalseFuncU8<>) &&
+ &FalseFuncU8<>,
+ &FalseFuncI32<>) &&
IsSupportedForDataTypeGeneric(reasonIfUnsupported,
output.GetDataType(),
&TrueFunc<>,
&FalseOutputFuncF32<>,
- &FalseFuncU8<>));
+ &FalseFuncU8<>,
+ &FalseFuncI32<>));
}
bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 9bdda9d128..b112e9dd6a 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -24,7 +24,7 @@ template <typename F32Workload, typename U8Workload, typename QueueDescriptorTyp
std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
const WorkloadInfo& info) const
{
- return armnn::MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload>(descriptor, info);
+ return armnn::MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload>(descriptor, info);
}
RefWorkloadFactory::RefWorkloadFactory()
@@ -126,8 +126,8 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateFullyConnected(
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteUint8Workload>
- (descriptor, info);
+ return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteUint8Workload,
+ NullWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
@@ -205,7 +205,8 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2Nor
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefConstantFloat32Workload, RefConstantUint8Workload>(descriptor, info);
+ return MakeWorkloadHelper<NullWorkload, RefConstantFloat32Workload, RefConstantUint8Workload,
+ RefConstantInt32Workload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 8dd6a51139..763f26e18c 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -24,13 +24,11 @@ BACKEND_SOURCES := \
workloads/Pooling2d.cpp \
workloads/RefActivationFloat32Workload.cpp \
workloads/RefActivationUint8Workload.cpp \
- workloads/RefBaseConstantWorkload.cpp \
workloads/RefBatchNormalizationFloat32Workload.cpp \
workloads/RefBatchNormalizationUint8Workload.cpp \
workloads/RefBatchToSpaceNdFloat32Workload.cpp \
workloads/RefBatchToSpaceNdUint8Workload.cpp \
- workloads/RefConstantFloat32Workload.cpp \
- workloads/RefConstantUint8Workload.cpp \
+ workloads/RefConstantWorkload.cpp \
workloads/RefConvertFp16ToFp32Workload.cpp \
workloads/RefConvertFp32ToFp16Workload.cpp \
workloads/RefConvolution2dFloat32Workload.cpp \
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index 4f4a161509..330f406265 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -4,6 +4,7 @@
//
#include <backendsCommon/test/EndToEndTestImpl.hpp>
+#include <backendsCommon/test/GatherEndToEndTestImpl.hpp>
#include <backendsCommon/test/MergerTestImpl.hpp>
#include <backendsCommon/test/ArithmeticTestImpl.hpp>
@@ -416,4 +417,24 @@ BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim3Uint8Test)
MergerDim3EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends);
}
+BOOST_AUTO_TEST_CASE(RefGatherFloatTest)
+{
+ GatherEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+BOOST_AUTO_TEST_CASE(RefGatherUint8Test)
+{
+ GatherEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends);
+}
+
+BOOST_AUTO_TEST_CASE(RefGatherMultiDimFloatTest)
+{
+ GatherMultiDimEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+BOOST_AUTO_TEST_CASE(RefGatherMultiDimUint8Test)
+{
+ GatherMultiDimEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends);
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 583c89a5b4..f95fda08d1 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -32,8 +32,6 @@ list(APPEND armnnRefBackendWorkloads_sources
RefActivationFloat32Workload.hpp
RefActivationUint8Workload.cpp
RefActivationUint8Workload.hpp
- RefBaseConstantWorkload.cpp
- RefBaseConstantWorkload.hpp
RefBatchNormalizationFloat32Workload.cpp
RefBatchNormalizationFloat32Workload.hpp
RefBatchNormalizationUint8Workload.cpp
@@ -42,10 +40,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefBatchToSpaceNdFloat32Workload.hpp
RefBatchToSpaceNdUint8Workload.cpp
RefBatchToSpaceNdUint8Workload.hpp
- RefConstantFloat32Workload.cpp
- RefConstantFloat32Workload.hpp
- RefConstantUint8Workload.cpp
- RefConstantUint8Workload.hpp
+ RefConstantWorkload.cpp
+ RefConstantWorkload.hpp
RefConvertFp16ToFp32Workload.cpp
RefConvertFp16ToFp32Workload.hpp
RefConvertFp32ToFp16Workload.cpp
diff --git a/src/backends/reference/workloads/RefConstantFloat32Workload.cpp b/src/backends/reference/workloads/RefConstantFloat32Workload.cpp
deleted file mode 100644
index 074e8ccaae..0000000000
--- a/src/backends/reference/workloads/RefConstantFloat32Workload.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefConstantFloat32Workload.hpp"
-
-#include "Profiling.hpp"
-
-namespace armnn
-{
-
-void RefConstantFloat32Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantFloat32Workload_Execute");
- RefBaseConstantWorkload::Execute();
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefConstantFloat32Workload.hpp b/src/backends/reference/workloads/RefConstantFloat32Workload.hpp
deleted file mode 100644
index 76e3a42026..0000000000
--- a/src/backends/reference/workloads/RefConstantFloat32Workload.hpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "RefBaseConstantWorkload.hpp"
-
-namespace armnn
-{
-
-class RefConstantFloat32Workload : public RefBaseConstantWorkload<DataType::Float32>
-{
-public:
- using RefBaseConstantWorkload<DataType::Float32>::RefBaseConstantWorkload;
- virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefConstantUint8Workload.cpp b/src/backends/reference/workloads/RefConstantUint8Workload.cpp
deleted file mode 100644
index 07e4719d54..0000000000
--- a/src/backends/reference/workloads/RefConstantUint8Workload.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefConstantUint8Workload.hpp"
-
-#include "Profiling.hpp"
-
-namespace armnn
-{
-
-void RefConstantUint8Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantUint8Workload_Execute");
- RefBaseConstantWorkload::Execute();
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefConstantUint8Workload.hpp b/src/backends/reference/workloads/RefConstantUint8Workload.hpp
deleted file mode 100644
index 02552ac80b..0000000000
--- a/src/backends/reference/workloads/RefConstantUint8Workload.hpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "RefBaseConstantWorkload.hpp"
-
-namespace armnn
-{
-
-class RefConstantUint8Workload : public RefBaseConstantWorkload<DataType::QuantisedAsymm8>
-{
-public:
- using RefBaseConstantWorkload<DataType::QuantisedAsymm8>::RefBaseConstantWorkload;
- virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefBaseConstantWorkload.cpp b/src/backends/reference/workloads/RefConstantWorkload.cpp
index 647677b4fb..e074c6fb04 100644
--- a/src/backends/reference/workloads/RefBaseConstantWorkload.cpp
+++ b/src/backends/reference/workloads/RefConstantWorkload.cpp
@@ -1,9 +1,9 @@
-//
+//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
-#include "RefBaseConstantWorkload.hpp"
+#include "RefConstantWorkload.hpp"
#include "RefWorkloadUtils.hpp"
@@ -17,7 +17,7 @@ namespace armnn
{
template <armnn::DataType DataType>
-void RefBaseConstantWorkload<DataType>::Execute() const
+void RefConstantWorkload<DataType>::Execute() const
{
// Considering the reference backend independently, it could be possible to initialise the intermediate tensor
// created by the layer output handler at workload construction time, rather than at workload execution time.
@@ -27,6 +27,8 @@ void RefBaseConstantWorkload<DataType>::Execute() const
// could have a non-owning reference to the layer output tensor managed by the const input layer); again, this is
// not an option for other backends, and the extra complexity required to make this work for the reference backend
// may not be worth the effort (skipping a memory copy in the first inference).
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantWorkload_Execute");
+
if (!m_RanOnce)
{
const ConstantQueueDescriptor& data = this->m_Data;
@@ -43,7 +45,8 @@ void RefBaseConstantWorkload<DataType>::Execute() const
}
}
-template class RefBaseConstantWorkload<DataType::Float32>;
-template class RefBaseConstantWorkload<DataType::QuantisedAsymm8>;
+template class RefConstantWorkload<DataType::Float32>;
+template class RefConstantWorkload<DataType::QuantisedAsymm8>;
+template class RefConstantWorkload<DataType::Signed32>;
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefBaseConstantWorkload.hpp b/src/backends/reference/workloads/RefConstantWorkload.hpp
index 82ee11f6ec..75d7ecce26 100644
--- a/src/backends/reference/workloads/RefBaseConstantWorkload.hpp
+++ b/src/backends/reference/workloads/RefConstantWorkload.hpp
@@ -1,4 +1,4 @@
-//
+//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -15,19 +15,26 @@ namespace armnn
// Base class template providing an implementation of the Constant layer common to all data types.
template <armnn::DataType DataType>
-class RefBaseConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataType>
+class RefConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataType>
{
public:
- RefBaseConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
+ RefConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
: TypedWorkload<ConstantQueueDescriptor, DataType>(descriptor, info)
, m_RanOnce(false)
{
}
+ using TypedWorkload<ConstantQueueDescriptor, DataType>::m_Data;
+ using TypedWorkload<ConstantQueueDescriptor, DataType>::TypedWorkload;
+
virtual void Execute() const override;
private:
mutable bool m_RanOnce;
};
+using RefConstantFloat32Workload = RefConstantWorkload<DataType::Float32>;
+using RefConstantUint8Workload = RefConstantWorkload<DataType::QuantisedAsymm8>;
+using RefConstantInt32Workload = RefConstantWorkload<DataType::Signed32>;
+
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 8550ee583e..1cbceb366b 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -5,11 +5,10 @@
#pragma once
-#include "RefConstantUint8Workload.hpp"
#include "ElementwiseFunction.hpp"
#include "RefElementwiseWorkload.hpp"
#include "ConvImpl.hpp"
-#include "RefBaseConstantWorkload.hpp"
+#include "RefConstantWorkload.hpp"
#include "RefConvolution2dUint8Workload.hpp"
#include "RefSplitterUint8Workload.hpp"
#include "RefResizeBilinearUint8Workload.hpp"
@@ -46,7 +45,6 @@
#include "RefSpaceToBatchNdWorkload.hpp"
#include "RefSplitterFloat32Workload.hpp"
#include "RefStridedSliceWorkload.hpp"
-#include "RefConstantFloat32Workload.hpp"
#include "RefActivationFloat32Workload.hpp"
#include "RefConvolution2dFloat32Workload.hpp"
#include "Pooling2d.hpp"