author      Matthew Bentham <matthew.bentham@arm.com>    2018-09-17 11:17:41 +0100
committer   Matthew Bentham <matthew.bentham@arm.com>    2018-10-10 16:16:56 +0100
commit      ab8cdc13443408d727fd38be42315cf942251940 (patch)
tree        f1576377150b72f1a2404915adcf1affd5eab3e4 /src/armnn/backends
parent      ca225f0ab9aac74ccc7c62cfcf46c95f7715b2ee (diff)
download    armnn-ab8cdc13443408d727fd38be42315cf942251940.tar.gz
IVGCVSW-949 Add 8-bit fully connected support
Change-Id: I0953bb8dbc4b76001f207e37c8c2742a6ebd888b
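
This change folds 8-bit support into a single ClFullyConnectedWorkload: the class now derives from BaseWorkload instead of FloatWorkload, and its constant tensors are uploaded according to their data type (uint8_t weights for QuantisedAsymm8, int32_t biases for Signed32, the existing float path otherwise). Below is a minimal standalone sketch of that dispatch pattern; every name in it is an illustrative stand-in, not the actual Arm NN helper.

    #include <cstdint>
    #include <iostream>

    // Stand-ins for armnn::DataType and a constant tensor handle (hypothetical).
    enum class DataType { Float32, Float16, QuantisedAsymm8, Signed32 };

    struct ConstTensorSketch
    {
        DataType    type;
        const void* data;
    };

    void UploadUint8(const void*)  { std::cout << "upload uint8 weights\n"; }
    void UploadInt32(const void*)  { std::cout << "upload int32 biases\n"; }
    void UploadFloats(const void*) { std::cout << "upload float data\n"; }

    // Mirrors the shape of the new constructor logic: quantised weights and
    // Signed32 biases get dedicated upload paths; everything else keeps the
    // old float-only path.
    void InitialiseConstTensor(const ConstTensorSketch& tensor)
    {
        switch (tensor.type)
        {
            case DataType::QuantisedAsymm8: UploadUint8(tensor.data);  break;
            case DataType::Signed32:        UploadInt32(tensor.data);  break;
            default:                        UploadFloats(tensor.data); break;
        }
    }

    int main()
    {
        uint8_t weights[4] = {};
        int32_t biases[1]  = {};
        InitialiseConstTensor({DataType::QuantisedAsymm8, weights});
        InitialiseConstTensor({DataType::Signed32, biases});
    }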
Diffstat (limited to 'src/armnn/backends')
-rw-r--r--   src/armnn/backends/ClLayerSupport.cpp                                                                                         7
-rw-r--r--   src/armnn/backends/ClWorkloadFactory.cpp                                                                                      4
-rw-r--r--   src/armnn/backends/ClWorkloads.hpp                                                                                            2
-rw-r--r--   src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp (renamed from src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp)  32
-rw-r--r--   src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp (renamed from src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp)  10
-rw-r--r--   src/armnn/backends/test/ArmComputeCl.cpp                                                                                      2
-rw-r--r--   src/armnn/backends/test/CreateWorkloadCl.cpp                                                                                  4
7 files changed, 36 insertions, 25 deletions
diff --git a/src/armnn/backends/ClLayerSupport.cpp b/src/armnn/backends/ClLayerSupport.cpp
index 4664c2ee32..30a1330706 100644
--- a/src/armnn/backends/ClLayerSupport.cpp
+++ b/src/armnn/backends/ClLayerSupport.cpp
@@ -24,7 +24,7 @@
 #include "ClWorkloads/ClDivisionFloatWorkload.hpp"
 #include "ClWorkloads/ClL2NormalizationFloatWorkload.hpp"
 #include "ClWorkloads/ClMultiplicationFloatWorkload.hpp"
-#include "ClWorkloads/ClFullyConnectedFloatWorkload.hpp"
+#include "ClWorkloads/ClFullyConnectedWorkload.hpp"
 #include "ClWorkloads/ClPooling2dBaseWorkload.hpp"
 #include "ClWorkloads/ClPermuteWorkload.hpp"
 #include "ClWorkloads/ClNormalizationFloatWorkload.hpp"
@@ -269,11 +269,6 @@ bool IsFullyConnectedSupportedCl(const TensorInfo& input,
                                  const FullyConnectedDescriptor& descriptor,
                                  std::string* reasonIfUnsupported)
 {
-    // At the moment U8 is unsupported
-    if (input.GetDataType() == DataType::QuantisedAsymm8)
-    {
-        return false;
-    }
     FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
                                    reasonIfUnsupported,
                                    input,
diff --git a/src/armnn/backends/ClWorkloadFactory.cpp b/src/armnn/backends/ClWorkloadFactory.cpp
index c35f044e9e..591fb85dbb 100644
--- a/src/armnn/backends/ClWorkloadFactory.cpp
+++ b/src/armnn/backends/ClWorkloadFactory.cpp
@@ -116,8 +116,8 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMerger(const MergerQu
 std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateFullyConnected(
     const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const
 {
-    return MakeWorkload<ClFullyConnectedFloatWorkload, NullWorkload>(descriptor, info,
-                                                                     m_MemoryManager.GetIntraLayerManager());
+    return MakeWorkload<ClFullyConnectedWorkload, ClFullyConnectedWorkload>(descriptor, info,
+                                                                            m_MemoryManager.GetIntraLayerManager());
 }
 
 std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
diff --git a/src/armnn/backends/ClWorkloads.hpp b/src/armnn/backends/ClWorkloads.hpp
index 3472bca45c..2bbda8a62b 100644
--- a/src/armnn/backends/ClWorkloads.hpp
+++ b/src/armnn/backends/ClWorkloads.hpp
@@ -18,7 +18,7 @@
 #include "backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.hpp"
 #include "backends/ClWorkloads/ClDivisionFloatWorkload.hpp"
 #include "backends/ClWorkloads/ClFloorFloatWorkload.hpp"
-#include "backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp"
+#include "backends/ClWorkloads/ClFullyConnectedWorkload.hpp"
 #include "backends/ClWorkloads/ClL2NormalizationFloatWorkload.hpp"
 #include "backends/ClWorkloads/ClLstmFloatWorkload.hpp"
 #include "backends/ClWorkloads/ClMergerFloatWorkload.hpp"
diff --git a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
index b1dab7c7b9..5307fab062 100644
--- a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp
+++ b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
@@ -3,7 +3,7 @@
 // SPDX-License-Identifier: MIT
 //
 
-#include "ClFullyConnectedFloatWorkload.hpp"
+#include "ClFullyConnectedWorkload.hpp"
 #include "backends/ClTensorHandle.hpp"
 #include "backends/CpuTensorHandle.hpp"
 #include "backends/ArmComputeTensorUtils.hpp"
@@ -42,9 +42,9 @@ arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
                                 fullyConnectedLayerInfo);
 }
 
-ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnectedQueueDescriptor& descriptor,
+ClFullyConnectedWorkload::ClFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor,
     const WorkloadInfo& info,
     std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
-    : FloatWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
+    : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
     , m_FullyConnectedLayer(memoryManager)
 {
     m_WeightsTensor = std::make_unique<arm_compute::CLTensor>();
@@ -56,7 +56,7 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
         BuildArmComputeTensor(*m_BiasesTensor, m_Data.m_Bias->GetTensorInfo());
     }
 
-    m_Data.ValidateInputsOutputs("ClFullyConnectedFloatWorkload", 1, 1);
+    m_Data.ValidateInputsOutputs("ClFullyConnectedWorkload", 1, 1);
 
     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
@@ -67,11 +67,25 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
     m_FullyConnectedLayer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
 
     // Allocate
-    InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
+    if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QuantisedAsymm8)
+    {
+        InitialiseArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
+    }
+    else
+    {
+        InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
+    }
 
     if (m_BiasesTensor)
     {
-        InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
+        if (m_Data.m_Bias->GetTensorInfo().GetDataType() == DataType::Signed32)
+        {
+            InitialiseArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias->GetConstTensor<int32_t>());
+        }
+        else
+        {
+            InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
+        }
     }
 
     // Force Compute Library to perform the necessary copying and reshaping, after which
@@ -80,13 +94,13 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
     FreeUnusedTensors();
 }
 
-void ClFullyConnectedFloatWorkload::Execute() const
+void ClFullyConnectedWorkload::Execute() const
 {
-    ARMNN_SCOPED_PROFILING_EVENT_CL("ClFullyConnectedFloatWorkload_Execute");
+    ARMNN_SCOPED_PROFILING_EVENT_CL("ClFullyConnectedWorkload_Execute");
     m_FullyConnectedLayer.run();
 }
 
-void ClFullyConnectedFloatWorkload::FreeUnusedTensors()
+void ClFullyConnectedWorkload::FreeUnusedTensors()
 {
     FreeTensorIfUnused(m_WeightsTensor);
     FreeTensorIfUnused(m_BiasesTensor);
diff --git a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp
index e8d6a7897d..7aa9b86e15 100644
--- a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp
+++ b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp
@@ -20,14 +20,14 @@ arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
                                                      const TensorInfo& biases,
                                                      const FullyConnectedDescriptor& descriptor);
 
-class ClFullyConnectedFloatWorkload : public armnn::FloatWorkload<armnn::FullyConnectedQueueDescriptor>
+class ClFullyConnectedWorkload : public armnn::BaseWorkload<armnn::FullyConnectedQueueDescriptor>
 {
 public:
-    ClFullyConnectedFloatWorkload(const armnn::FullyConnectedQueueDescriptor& descriptor,
-                                  const armnn::WorkloadInfo& info,
-                                  std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
+    ClFullyConnectedWorkload(const armnn::FullyConnectedQueueDescriptor& descriptor,
+                             const armnn::WorkloadInfo& info,
+                             std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
 
-    using armnn::FloatWorkload<armnn::FullyConnectedQueueDescriptor>::m_Data;
+    using armnn::BaseWorkload<armnn::FullyConnectedQueueDescriptor>::m_Data;
     void Execute() const override;
 
 private:
diff --git a/src/armnn/backends/test/ArmComputeCl.cpp b/src/armnn/backends/test/ArmComputeCl.cpp
index 2c1d8b66cf..d8a70f03c0 100644
--- a/src/armnn/backends/test/ArmComputeCl.cpp
+++ b/src/armnn/backends/test/ArmComputeCl.cpp
@@ -42,6 +42,8 @@ ARMNN_AUTO_TEST_CASE(ReLu6Uint8, BoundedReLuUint8UpperBoundOnlyTest)
 ARMNN_AUTO_TEST_CASE(SimpleFullyConnected, FullyConnectedFloat32Test, false, false)
 ARMNN_AUTO_TEST_CASE(SimpleFullyConnectedWithBias, FullyConnectedFloat32Test, true, false)
 ARMNN_AUTO_TEST_CASE(SimpleFullyConnectedWithTranspose, FullyConnectedFloat32Test, false, true)
+ARMNN_AUTO_TEST_CASE(FullyConnectedUint8, FullyConnectedUint8Test, false)
+ARMNN_AUTO_TEST_CASE(FullyConnectedBiasedUint8, FullyConnectedUint8Test, true)
 ARMNN_AUTO_TEST_CASE(FullyConnectedLarge, FullyConnectedLargeTest, false)
 ARMNN_AUTO_TEST_CASE(FullyConnectedLargeTransposed, FullyConnectedLargeTest, true)
diff --git a/src/armnn/backends/test/CreateWorkloadCl.cpp b/src/armnn/backends/test/CreateWorkloadCl.cpp
index a273582e53..bce265c7d0 100644
--- a/src/armnn/backends/test/CreateWorkloadCl.cpp
+++ b/src/armnn/backends/test/CreateWorkloadCl.cpp
@@ -268,12 +268,12 @@ static void ClCreateFullyConnectedWorkloadTest()
 
 BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloatWorkloadTest)
 {
-    ClCreateFullyConnectedWorkloadTest<ClFullyConnectedFloatWorkload, armnn::DataType::Float32>();
+    ClCreateFullyConnectedWorkloadTest<ClFullyConnectedWorkload, armnn::DataType::Float32>();
 }
 
 BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloat16WorkloadTest)
 {
-    ClCreateFullyConnectedWorkloadTest<ClFullyConnectedFloatWorkload, armnn::DataType::Float16>();
+    ClCreateFullyConnectedWorkloadTest<ClFullyConnectedWorkload, armnn::DataType::Float16>();
 }
 
 template <typename NormalizationWorkloadType, typename armnn::DataType DataType>
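
The factory change above is the other half of the story: CreateFullyConnected previously passed NullWorkload as the second template argument, so quantised networks never got a CL fully connected workload; it now passes ClFullyConnectedWorkload for both. The real MakeWorkload lives elsewhere in the tree and its implementation is not shown in this diff, so the following is only a plausible sketch, under the assumption that the helper selects the first type for float data and the second for quantised data.

    #include <iostream>
    #include <memory>

    enum class DataType { Float32, QuantisedAsymm8 };

    struct IWorkload
    {
        virtual ~IWorkload() = default;
        virtual void Execute() const = 0;
    };

    // Assumed shape of the helper (hypothetical): pick FloatWorkloadType for
    // float data, Uint8WorkloadType for quantised data. Passing the same
    // class for both, as the diff now does, routes every data type to it.
    template <typename FloatWorkloadType, typename Uint8WorkloadType>
    std::unique_ptr<IWorkload> MakeWorkloadSketch(DataType type)
    {
        if (type == DataType::QuantisedAsymm8)
        {
            return std::make_unique<Uint8WorkloadType>();
        }
        return std::make_unique<FloatWorkloadType>();
    }

    struct FullyConnectedSketch : IWorkload
    {
        void Execute() const override { std::cout << "run fully connected\n"; }
    };

    int main()
    {
        // Both data types now resolve to the same workload class.
        auto floatWl = MakeWorkloadSketch<FullyConnectedSketch, FullyConnectedSketch>(DataType::Float32);
        auto uint8Wl = MakeWorkloadSketch<FullyConnectedSketch, FullyConnectedSketch>(DataType::QuantisedAsymm8);
        floatWl->Execute();
        uint8Wl->Execute();
    }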