diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 23 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 9 | ||||
-rw-r--r-- | src/backends/reference/backend.mk | 4 | ||||
-rw-r--r-- | src/backends/reference/test/RefEndToEndTests.cpp | 21 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 8 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefConstantFloat32Workload.cpp | 19 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefConstantFloat32Workload.hpp | 20 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefConstantUint8Workload.cpp | 19 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefConstantUint8Workload.hpp | 20 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefConstantWorkload.cpp (renamed from src/backends/reference/workloads/RefBaseConstantWorkload.cpp) | 13 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefConstantWorkload.hpp (renamed from src/backends/reference/workloads/RefBaseConstantWorkload.hpp) | 13 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 4 |
12 files changed, 63 insertions, 110 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 25c2bafe2f..45f108c2f8 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -34,6 +34,7 @@ bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported, &FalseFunc<Params...>, floatFuncPtr, uint8FuncPtr, + &FalseFunc<Params...>, std::forward<Params>(params)...); } @@ -105,10 +106,12 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, bool RefLayerSupport::IsConstantSupported(const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { - return IsSupportedForDataTypeRef(reasonIfUnsupported, - output.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + return IsSupportedForDataTypeGeneric(reasonIfUnsupported, + output.GetDataType(), + &FalseFunc<>, + &TrueFunc<>, + &TrueFunc<>, + &TrueFunc<>); } bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, @@ -119,12 +122,14 @@ bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, input.GetDataType(), &TrueFunc<>, &FalseInputFuncF32<>, - &FalseFuncU8<>) && + &FalseFuncU8<>, + &FalseFuncI32<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &FalseOutputFuncF16<>, &TrueFunc<>, - &FalseFuncU8<>)); + &FalseFuncU8<>, + &FalseFuncI32<>)); } bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, @@ -135,12 +140,14 @@ bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, input.GetDataType(), &FalseInputFuncF16<>, &TrueFunc<>, - &FalseFuncU8<>) && + &FalseFuncU8<>, + &FalseFuncI32<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &TrueFunc<>, &FalseOutputFuncF32<>, - &FalseFuncU8<>)); + &FalseFuncU8<>, + &FalseFuncI32<>)); } bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 9bdda9d128..b112e9dd6a 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -24,7 +24,7 @@ template <typename F32Workload, typename U8Workload, typename QueueDescriptorTyp std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const { - return armnn::MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload>(descriptor, info); + return armnn::MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload>(descriptor, info); } RefWorkloadFactory::RefWorkloadFactory() @@ -126,8 +126,8 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateFullyConnected( std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteUint8Workload> - (descriptor, info); + return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteUint8Workload, + NullWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, @@ -205,7 +205,8 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2Nor std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefConstantFloat32Workload, RefConstantUint8Workload>(descriptor, info); + return MakeWorkloadHelper<NullWorkload, RefConstantFloat32Workload, RefConstantUint8Workload, + RefConstantInt32Workload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 8dd6a51139..763f26e18c 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -24,13 +24,11 @@ BACKEND_SOURCES := \ workloads/Pooling2d.cpp \ workloads/RefActivationFloat32Workload.cpp \ workloads/RefActivationUint8Workload.cpp \ - workloads/RefBaseConstantWorkload.cpp \ workloads/RefBatchNormalizationFloat32Workload.cpp \ workloads/RefBatchNormalizationUint8Workload.cpp \ workloads/RefBatchToSpaceNdFloat32Workload.cpp \ workloads/RefBatchToSpaceNdUint8Workload.cpp \ - workloads/RefConstantFloat32Workload.cpp \ - workloads/RefConstantUint8Workload.cpp \ + workloads/RefConstantWorkload.cpp \ workloads/RefConvertFp16ToFp32Workload.cpp \ workloads/RefConvertFp32ToFp16Workload.cpp \ workloads/RefConvolution2dFloat32Workload.cpp \ diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 4f4a161509..330f406265 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -4,6 +4,7 @@ // #include <backendsCommon/test/EndToEndTestImpl.hpp> +#include <backendsCommon/test/GatherEndToEndTestImpl.hpp> #include <backendsCommon/test/MergerTestImpl.hpp> #include <backendsCommon/test/ArithmeticTestImpl.hpp> @@ -416,4 +417,24 @@ BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim3Uint8Test) MergerDim3EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); } +BOOST_AUTO_TEST_CASE(RefGatherFloatTest) +{ + GatherEndToEnd<armnn::DataType::Float32>(defaultBackends); +} + +BOOST_AUTO_TEST_CASE(RefGatherUint8Test) +{ + GatherEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); +} + +BOOST_AUTO_TEST_CASE(RefGatherMultiDimFloatTest) +{ + GatherMultiDimEndToEnd<armnn::DataType::Float32>(defaultBackends); +} + +BOOST_AUTO_TEST_CASE(RefGatherMultiDimUint8Test) +{ + GatherMultiDimEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 583c89a5b4..f95fda08d1 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -32,8 +32,6 @@ list(APPEND armnnRefBackendWorkloads_sources RefActivationFloat32Workload.hpp RefActivationUint8Workload.cpp RefActivationUint8Workload.hpp - RefBaseConstantWorkload.cpp - RefBaseConstantWorkload.hpp RefBatchNormalizationFloat32Workload.cpp RefBatchNormalizationFloat32Workload.hpp RefBatchNormalizationUint8Workload.cpp @@ -42,10 +40,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefBatchToSpaceNdFloat32Workload.hpp RefBatchToSpaceNdUint8Workload.cpp RefBatchToSpaceNdUint8Workload.hpp - RefConstantFloat32Workload.cpp - RefConstantFloat32Workload.hpp - RefConstantUint8Workload.cpp - RefConstantUint8Workload.hpp + RefConstantWorkload.cpp + RefConstantWorkload.hpp RefConvertFp16ToFp32Workload.cpp RefConvertFp16ToFp32Workload.hpp RefConvertFp32ToFp16Workload.cpp diff --git a/src/backends/reference/workloads/RefConstantFloat32Workload.cpp b/src/backends/reference/workloads/RefConstantFloat32Workload.cpp deleted file mode 100644 index 074e8ccaae..0000000000 --- a/src/backends/reference/workloads/RefConstantFloat32Workload.cpp +++ /dev/null @@ -1,19 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefConstantFloat32Workload.hpp" - -#include "Profiling.hpp" - -namespace armnn -{ - -void RefConstantFloat32Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantFloat32Workload_Execute"); - RefBaseConstantWorkload::Execute(); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefConstantFloat32Workload.hpp b/src/backends/reference/workloads/RefConstantFloat32Workload.hpp deleted file mode 100644 index 76e3a42026..0000000000 --- a/src/backends/reference/workloads/RefConstantFloat32Workload.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "RefBaseConstantWorkload.hpp" - -namespace armnn -{ - -class RefConstantFloat32Workload : public RefBaseConstantWorkload<DataType::Float32> -{ -public: - using RefBaseConstantWorkload<DataType::Float32>::RefBaseConstantWorkload; - virtual void Execute() const override; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefConstantUint8Workload.cpp b/src/backends/reference/workloads/RefConstantUint8Workload.cpp deleted file mode 100644 index 07e4719d54..0000000000 --- a/src/backends/reference/workloads/RefConstantUint8Workload.cpp +++ /dev/null @@ -1,19 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefConstantUint8Workload.hpp" - -#include "Profiling.hpp" - -namespace armnn -{ - -void RefConstantUint8Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantUint8Workload_Execute"); - RefBaseConstantWorkload::Execute(); -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefConstantUint8Workload.hpp b/src/backends/reference/workloads/RefConstantUint8Workload.hpp deleted file mode 100644 index 02552ac80b..0000000000 --- a/src/backends/reference/workloads/RefConstantUint8Workload.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "RefBaseConstantWorkload.hpp" - -namespace armnn -{ - -class RefConstantUint8Workload : public RefBaseConstantWorkload<DataType::QuantisedAsymm8> -{ -public: - using RefBaseConstantWorkload<DataType::QuantisedAsymm8>::RefBaseConstantWorkload; - virtual void Execute() const override; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefBaseConstantWorkload.cpp b/src/backends/reference/workloads/RefConstantWorkload.cpp index 647677b4fb..e074c6fb04 100644 --- a/src/backends/reference/workloads/RefBaseConstantWorkload.cpp +++ b/src/backends/reference/workloads/RefConstantWorkload.cpp @@ -1,9 +1,9 @@ -// +// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // -#include "RefBaseConstantWorkload.hpp" +#include "RefConstantWorkload.hpp" #include "RefWorkloadUtils.hpp" @@ -17,7 +17,7 @@ namespace armnn { template <armnn::DataType DataType> -void RefBaseConstantWorkload<DataType>::Execute() const +void RefConstantWorkload<DataType>::Execute() const { // Considering the reference backend independently, it could be possible to initialise the intermediate tensor // created by the layer output handler at workload construction time, rather than at workload execution time. @@ -27,6 +27,8 @@ void RefBaseConstantWorkload<DataType>::Execute() const // could have a non-owning reference to the layer output tensor managed by the const input layer); again, this is // not an option for other backends, and the extra complexity required to make this work for the reference backend // may not be worth the effort (skipping a memory copy in the first inference). + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantWorkload_Execute"); + if (!m_RanOnce) { const ConstantQueueDescriptor& data = this->m_Data; @@ -43,7 +45,8 @@ void RefBaseConstantWorkload<DataType>::Execute() const } } -template class RefBaseConstantWorkload<DataType::Float32>; -template class RefBaseConstantWorkload<DataType::QuantisedAsymm8>; +template class RefConstantWorkload<DataType::Float32>; +template class RefConstantWorkload<DataType::QuantisedAsymm8>; +template class RefConstantWorkload<DataType::Signed32>; } //namespace armnn diff --git a/src/backends/reference/workloads/RefBaseConstantWorkload.hpp b/src/backends/reference/workloads/RefConstantWorkload.hpp index 82ee11f6ec..75d7ecce26 100644 --- a/src/backends/reference/workloads/RefBaseConstantWorkload.hpp +++ b/src/backends/reference/workloads/RefConstantWorkload.hpp @@ -1,4 +1,4 @@ -// +// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // @@ -15,19 +15,26 @@ namespace armnn // Base class template providing an implementation of the Constant layer common to all data types. template <armnn::DataType DataType> -class RefBaseConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataType> +class RefConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataType> { public: - RefBaseConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info) + RefConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info) : TypedWorkload<ConstantQueueDescriptor, DataType>(descriptor, info) , m_RanOnce(false) { } + using TypedWorkload<ConstantQueueDescriptor, DataType>::m_Data; + using TypedWorkload<ConstantQueueDescriptor, DataType>::TypedWorkload; + virtual void Execute() const override; private: mutable bool m_RanOnce; }; +using RefConstantFloat32Workload = RefConstantWorkload<DataType::Float32>; +using RefConstantUint8Workload = RefConstantWorkload<DataType::QuantisedAsymm8>; +using RefConstantInt32Workload = RefConstantWorkload<DataType::Signed32>; + } //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 8550ee583e..1cbceb366b 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -5,11 +5,10 @@ #pragma once -#include "RefConstantUint8Workload.hpp" #include "ElementwiseFunction.hpp" #include "RefElementwiseWorkload.hpp" #include "ConvImpl.hpp" -#include "RefBaseConstantWorkload.hpp" +#include "RefConstantWorkload.hpp" #include "RefConvolution2dUint8Workload.hpp" #include "RefSplitterUint8Workload.hpp" #include "RefResizeBilinearUint8Workload.hpp" @@ -46,7 +45,6 @@ #include "RefSpaceToBatchNdWorkload.hpp" #include "RefSplitterFloat32Workload.hpp" #include "RefStridedSliceWorkload.hpp" -#include "RefConstantFloat32Workload.hpp" #include "RefActivationFloat32Workload.hpp" #include "RefConvolution2dFloat32Workload.hpp" #include "Pooling2d.hpp" |