diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2021-03-25 07:46:55 +0000 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2021-03-25 07:46:55 +0000 |
commit | f0a6dec75832604d5ab18242dc216852821a8279 (patch) | |
tree | ff25e64c62c63975a54abd16a8bff744be70d7c0 /src/backends/reference | |
parent | 16fb1a2d9c1d3d80c0f0b6ab549919fbabd2a0b9 (diff) | |
download | armnn-f0a6dec75832604d5ab18242dc216852821a8279.tar.gz |
IVGCVSW-5736 and IVGCVSW-5743 'NonConstWeights: Update front-end and TfLiteDelegate support for FullyConnected Operator'
* Added front-end support for non-const weights for FULLY_CONNECTED operator
* Added FULLY_CONNECTED end-to-end test
* Updated FULLY_CONNECTED operator support in TfLite Arm NN Delegate for non-const weights
* Updated the version numbers
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Iffa5b9aa9297aca4c02d923cce4636c88ac21faa
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefBackend.cpp | 10 | ||||
-rw-r--r-- | src/backends/reference/RefBackend.hpp | 6 | ||||
-rw-r--r-- | src/backends/reference/test/RefEndToEndTests.cpp | 6 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 14 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefFullyConnectedWorkload.cpp | 45 |
5 files changed, 67 insertions, 14 deletions
diff --git a/src/backends/reference/RefBackend.cpp b/src/backends/reference/RefBackend.cpp index e93b317dce..53c55ab26a 100644 --- a/src/backends/reference/RefBackend.cpp +++ b/src/backends/reference/RefBackend.cpp @@ -69,6 +69,16 @@ IBackendInternal::ILayerSupportSharedPtr RefBackend::GetLayerSupport() const return layerSupport; } +bool RefBackend::HasCapability(BackendCapability capabilityClass) const +{ + auto search = backendCapabilities.find(capabilityClass); + if (search != backendCapabilities.end()) + { + return true; + } + return false; +} + OptimizationViews RefBackend::OptimizeSubgraphView(const SubgraphView& subgraph) const { OptimizationViews optimizationViews; diff --git a/src/backends/reference/RefBackend.hpp b/src/backends/reference/RefBackend.hpp index 92d392dde6..c92936ca0c 100644 --- a/src/backends/reference/RefBackend.hpp +++ b/src/backends/reference/RefBackend.hpp @@ -9,6 +9,10 @@ namespace armnn { +const std::set<armnn::BackendCapability> backendCapabilities { + armnn::BackendCapability::NonConstWeights, +}; + class RefBackend : public IBackendInternal { public: @@ -39,6 +43,8 @@ public: std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override; void RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry) override; + + bool HasCapability(BackendCapability capabilityClass) const override; }; } // namespace armnn diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 4598568070..b6974811ef 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -15,6 +15,7 @@ #include <backendsCommon/test/DetectionPostProcessEndToEndTestImpl.hpp> #include <backendsCommon/test/ElementwiseUnaryEndToEndTestImpl.hpp> #include <backendsCommon/test/FillEndToEndTestImpl.hpp> +#include <backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp> #include <backendsCommon/test/GatherEndToEndTestImpl.hpp> #include <backendsCommon/test/InstanceNormalizationEndToEndTestImpl.hpp> #include <backendsCommon/test/LogSoftmaxEndToEndTestImpl.hpp> @@ -599,6 +600,11 @@ BOOST_AUTO_TEST_CASE(RefFillEndToEndTestInt32) FillEndToEnd<armnn::DataType::Signed32>(defaultBackends); } +BOOST_AUTO_TEST_CASE(RefFullyConnectedEndToEndTestInt32) +{ + FullyConnectedWithDynamicWeightsEndToEnd<armnn::DataType::Float32>(defaultBackends); +} + BOOST_AUTO_TEST_CASE(RefGatherFloatTest) { GatherEndToEnd<armnn::DataType::Float32>(defaultBackends); diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 161476ed98..7371692d0e 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -579,16 +579,22 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(HardSwishInt16, HardSwishInt16Test) // Fully Connected ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleFullyConnected, FullyConnectedFloat32Test, false, false) -ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedUint8, FullyConnectedTest<DataType::QAsymmU8>, false) -ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedQSymm16, FullyConnectedTest<DataType::QSymmS16>, false) +ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedUint8, FullyConnectedTest<DataType::QAsymmU8>, false, true) +ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedQSymm16, FullyConnectedTest<DataType::QSymmS16>, false, true) ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleFullyConnectedWithBias, FullyConnectedFloat32Test, true, false) -ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedBiasedUint8, FullyConnectedTest<DataType::QAsymmU8>, true) -ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedBiasedQSymm16, FullyConnectedTest<DataType::QSymmS16>, true) +ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedBiasedUint8, FullyConnectedTest<DataType::QAsymmU8>, true, true) +ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedBiasedQSymm16, FullyConnectedTest<DataType::QSymmS16>, true, true) ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleFullyConnectedWithTranspose, FullyConnectedFloat32Test, false, true) ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedLarge, FullyConnectedLargeTest, false) ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedLargeTransposed, FullyConnectedLargeTest, true) + +ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedWeightsAsInputsUint8, + FullyConnectedTest<DataType::QAsymmU8>, + false, + false) + // Splitter ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleSplitterFloat32, SplitterFloat32Test) ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleSplitterFloat16, SplitterFloat16Test) diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp index 9acca219b5..49e105f206 100644 --- a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp +++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp @@ -14,18 +14,21 @@ namespace armnn { RefFullyConnectedWorkload::RefFullyConnectedWorkload( const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info), - m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight))) + : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info) { - const TensorInfo& rWeightInfo = m_Weight->GetTensorInfo(); - m_WeightShape = rWeightInfo.GetShape(); - m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true)); - - if (descriptor.m_Parameters.m_BiasEnabled) + if (descriptor.m_Parameters.m_ConstantWeights) { - m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias)); - const TensorInfo& biasInfo = m_Bias->GetTensorInfo(); - m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true)); + m_Weight = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight)); + const TensorInfo& rWeightInfo = m_Weight->GetTensorInfo(); + m_WeightShape = rWeightInfo.GetShape(); + m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true)); + + if (descriptor.m_Parameters.m_BiasEnabled) + { + m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias)); + const TensorInfo& biasInfo = m_Bias->GetTensorInfo(); + m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true)); + } } } @@ -36,6 +39,20 @@ void RefFullyConnectedWorkload::PostAllocationConfigure() m_InputShape = inputInfo.GetShape(); m_InputDecoder = MakeDecoder<float>(inputInfo); + if (!m_Data.m_Parameters.m_ConstantWeights) + { + const TensorInfo& rWeightInfo = GetTensorInfo(m_Data.m_Inputs[1]); + ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1); + m_WeightShape = rWeightInfo.GetShape(); + m_WeightDecoder = MakeDecoder<float>(rWeightInfo); + + if (m_Data.m_Parameters.m_BiasEnabled) + { + const TensorInfo& biasInfo = GetTensorInfo(m_Data.m_Inputs[2]); + m_BiasDecoder = MakeDecoder<float>(biasInfo); + } + } + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); m_OutputShape = outputInfo.GetShape(); m_OutputEncoder = MakeEncoder<float>(outputInfo); @@ -52,6 +69,14 @@ void RefFullyConnectedWorkload::Execute() const ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute"); m_InputDecoder->Reset(m_Data.m_Inputs[0]->Map()); + if (!m_Data.m_Parameters.m_ConstantWeights) + { + m_WeightDecoder->Reset(m_Data.m_Inputs[1]->Map()); + if (m_Data.m_Parameters.m_BiasEnabled) + { + m_BiasDecoder->Reset(m_Data.m_Inputs[2]->Map()); + } + } m_OutputEncoder->Reset(m_Data.m_Outputs[0]->Map()); FullyConnected(m_InputShape, |