about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Matthew Bentham <matthew.bentham@arm.com> 2018-09-17 11:17:41 +0100
committer Matthew Bentham <matthew.bentham@arm.com> 2018-10-10 16:16:56 +0100
commit ab8cdc13443408d727fd38be42315cf942251940 (patch)
tree f1576377150b72f1a2404915adcf1affd5eab3e4
parent ca225f0ab9aac74ccc7c62cfcf46c95f7715b2ee (diff)
download armnn-ab8cdc13443408d727fd38be42315cf942251940.tar.gz
IVGCVSW-949 Add 8-bit fully connected support
Change-Id: I0953bb8dbc4b76001f207e37c8c2742a6ebd888b
-rw-r--r-- Android.mk 2
-rw-r--r-- CMakeLists.txt 4
-rw-r--r-- src/armnn/backends/ClLayerSupport.cpp 7
-rw-r--r-- src/armnn/backends/ClWorkloadFactory.cpp 4
-rw-r--r-- src/armnn/backends/ClWorkloads.hpp 2
-rw-r--r-- src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp (renamed from src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp) 32
-rw-r--r-- src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp (renamed from src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp) 10
-rw-r--r-- src/armnn/backends/test/ArmComputeCl.cpp 2
-rw-r--r-- src/armnn/backends/test/CreateWorkloadCl.cpp 4
9 files changed, 39 insertions, 28 deletions
diff --git a/Android.mk b/Android.mk
index c070b28f87..ad02db9c51 100644
--- a/Android.mk
+++ b/Android.mk
@@ -61,7 +61,7 @@ LOCAL_SRC_FILES := \
src/armnn/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp \
src/armnn/backends/ClWorkloads/ClDivisionFloatWorkload.cpp \
src/armnn/backends/ClWorkloads/ClFloorFloatWorkload.cpp \
- src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp \
+ src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp \
src/armnn/backends/ClWorkloads/ClL2NormalizationFloatWorkload.cpp \
src/armnn/backends/ClWorkloads/ClLstmFloatWorkload.cpp \
src/armnn/backends/ClWorkloads/ClMergerFloatWorkload.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 429046142f..a5dde68e56 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -506,8 +506,8 @@ if(ARMCOMPUTECL)
src/armnn/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.hpp
src/armnn/backends/ClWorkloads/ClFloorFloatWorkload.cpp
src/armnn/backends/ClWorkloads/ClFloorFloatWorkload.hpp
- src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp
- src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp
+ src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
+ src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp
src/armnn/backends/ClWorkloads/ClL2NormalizationFloatWorkload.cpp
src/armnn/backends/ClWorkloads/ClL2NormalizationFloatWorkload.hpp
src/armnn/backends/ClWorkloads/ClLstmFloatWorkload.cpp
diff --git a/src/armnn/backends/ClLayerSupport.cpp b/src/armnn/backends/ClLayerSupport.cpp
index 4664c2ee32..30a1330706 100644
--- a/src/armnn/backends/ClLayerSupport.cpp
+++ b/src/armnn/backends/ClLayerSupport.cpp
@@ -24,7 +24,7 @@
#include "ClWorkloads/ClDivisionFloatWorkload.hpp"
#include "ClWorkloads/ClL2NormalizationFloatWorkload.hpp"
#include "ClWorkloads/ClMultiplicationFloatWorkload.hpp"
-#include "ClWorkloads/ClFullyConnectedFloatWorkload.hpp"
+#include "ClWorkloads/ClFullyConnectedWorkload.hpp"
#include "ClWorkloads/ClPooling2dBaseWorkload.hpp"
#include "ClWorkloads/ClPermuteWorkload.hpp"
#include "ClWorkloads/ClNormalizationFloatWorkload.hpp"
@@ -269,11 +269,6 @@ bool IsFullyConnectedSupportedCl(const TensorInfo& input,
const FullyConnectedDescriptor& descriptor,
std::string* reasonIfUnsupported)
{
- // At the moment U8 is unsupported
- if (input.GetDataType() == DataType::QuantisedAsymm8)
- {
- return false;
- }
FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
reasonIfUnsupported,
input,
diff --git a/src/armnn/backends/ClWorkloadFactory.cpp b/src/armnn/backends/ClWorkloadFactory.cpp
index c35f044e9e..591fb85dbb 100644
--- a/src/armnn/backends/ClWorkloadFactory.cpp
+++ b/src/armnn/backends/ClWorkloadFactory.cpp
@@ -116,8 +116,8 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMerger(const MergerQu
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateFullyConnected(
const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<ClFullyConnectedFloatWorkload, NullWorkload>(descriptor, info,
- m_MemoryManager.GetIntraLayerManager());
+ return MakeWorkload<ClFullyConnectedWorkload, ClFullyConnectedWorkload>(descriptor, info,
+ m_MemoryManager.GetIntraLayerManager());
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
diff --git a/src/armnn/backends/ClWorkloads.hpp b/src/armnn/backends/ClWorkloads.hpp
index 3472bca45c..2bbda8a62b 100644
--- a/src/armnn/backends/ClWorkloads.hpp
+++ b/src/armnn/backends/ClWorkloads.hpp
@@ -18,7 +18,7 @@
#include "backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.hpp"
#include "backends/ClWorkloads/ClDivisionFloatWorkload.hpp"
#include "backends/ClWorkloads/ClFloorFloatWorkload.hpp"
-#include "backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp"
+#include "backends/ClWorkloads/ClFullyConnectedWorkload.hpp"
#include "backends/ClWorkloads/ClL2NormalizationFloatWorkload.hpp"
#include "backends/ClWorkloads/ClLstmFloatWorkload.hpp"
#include "backends/ClWorkloads/ClMergerFloatWorkload.hpp"
diff --git a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
index b1dab7c7b9..5307fab062 100644
--- a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp
+++ b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
@@ -3,7 +3,7 @@
// SPDX-License-Identifier: MIT
//
-#include "ClFullyConnectedFloatWorkload.hpp"
+#include "ClFullyConnectedWorkload.hpp"
#include "backends/ClTensorHandle.hpp"
#include "backends/CpuTensorHandle.hpp"
#include "backends/ArmComputeTensorUtils.hpp"
@@ -42,9 +42,9 @@ arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
fullyConnectedLayerInfo);
}
-ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnectedQueueDescriptor& descriptor,
+ClFullyConnectedWorkload::ClFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor,
const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
- : FloatWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
+ : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
, m_FullyConnectedLayer(memoryManager)
{
m_WeightsTensor = std::make_unique<arm_compute::CLTensor>();
@@ -56,7 +56,7 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
BuildArmComputeTensor(*m_BiasesTensor, m_Data.m_Bias->GetTensorInfo());
}
- m_Data.ValidateInputsOutputs("ClFullyConnectedFloatWorkload", 1, 1);
+ m_Data.ValidateInputsOutputs("ClFullyConnectedWorkload", 1, 1);
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
@@ -67,11 +67,25 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
m_FullyConnectedLayer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
// Allocate
- InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
+ if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QuantisedAsymm8)
+ {
+ InitialiseArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
+ }
+ else
+ {
+ InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
+ }
if (m_BiasesTensor)
{
- InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
+ if (m_Data.m_Bias->GetTensorInfo().GetDataType() == DataType::Signed32)
+ {
+ InitialiseArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias->GetConstTensor<int32_t>());
+ }
+ else
+ {
+ InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
+ }
}
// Force Compute Library to perform the necessary copying and reshaping, after which
@@ -80,13 +94,13 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
FreeUnusedTensors();
}
-void ClFullyConnectedFloatWorkload::Execute() const
+void ClFullyConnectedWorkload::Execute() const
{
- ARMNN_SCOPED_PROFILING_EVENT_CL("ClFullyConnectedFloatWorkload_Execute");
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClFullyConnectedWorkload_Execute");
m_FullyConnectedLayer.run();
}
-void ClFullyConnectedFloatWorkload::FreeUnusedTensors()
+void ClFullyConnectedWorkload::FreeUnusedTensors()
{
FreeTensorIfUnused(m_WeightsTensor);
FreeTensorIfUnused(m_BiasesTensor);
diff --git a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp
index e8d6a7897d..7aa9b86e15 100644
--- a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp
+++ b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp
@@ -20,14 +20,14 @@ arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
const TensorInfo& biases,
const FullyConnectedDescriptor& descriptor);
-class ClFullyConnectedFloatWorkload : public armnn::FloatWorkload<armnn::FullyConnectedQueueDescriptor>
+class ClFullyConnectedWorkload : public armnn::BaseWorkload<armnn::FullyConnectedQueueDescriptor>
{
public:
- ClFullyConnectedFloatWorkload(const armnn::FullyConnectedQueueDescriptor& descriptor,
- const armnn::WorkloadInfo& info,
- std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
+ ClFullyConnectedWorkload(const armnn::FullyConnectedQueueDescriptor& descriptor,
+ const armnn::WorkloadInfo& info,
+ std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
- using armnn::FloatWorkload<armnn::FullyConnectedQueueDescriptor>::m_Data;
+ using armnn::BaseWorkload<armnn::FullyConnectedQueueDescriptor>::m_Data;
void Execute() const override;
private:
diff --git a/src/armnn/backends/test/ArmComputeCl.cpp b/src/armnn/backends/test/ArmComputeCl.cpp
index 2c1d8b66cf..d8a70f03c0 100644
--- a/src/armnn/backends/test/ArmComputeCl.cpp
+++ b/src/armnn/backends/test/ArmComputeCl.cpp
@@ -42,6 +42,8 @@ ARMNN_AUTO_TEST_CASE(ReLu6Uint8, BoundedReLuUint8UpperBoundOnlyTest)
ARMNN_AUTO_TEST_CASE(SimpleFullyConnected, FullyConnectedFloat32Test, false, false)
ARMNN_AUTO_TEST_CASE(SimpleFullyConnectedWithBias, FullyConnectedFloat32Test, true, false)
ARMNN_AUTO_TEST_CASE(SimpleFullyConnectedWithTranspose, FullyConnectedFloat32Test, false, true)
+ARMNN_AUTO_TEST_CASE(FullyConnectedUint8, FullyConnectedUint8Test, false)
+ARMNN_AUTO_TEST_CASE(FullyConnectedBiasedUint8, FullyConnectedUint8Test, true)
ARMNN_AUTO_TEST_CASE(FullyConnectedLarge, FullyConnectedLargeTest, false)
ARMNN_AUTO_TEST_CASE(FullyConnectedLargeTransposed, FullyConnectedLargeTest, true)
diff --git a/src/armnn/backends/test/CreateWorkloadCl.cpp b/src/armnn/backends/test/CreateWorkloadCl.cpp
index a273582e53..bce265c7d0 100644
--- a/src/armnn/backends/test/CreateWorkloadCl.cpp
+++ b/src/armnn/backends/test/CreateWorkloadCl.cpp
@@ -268,12 +268,12 @@ static void ClCreateFullyConnectedWorkloadTest()
BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloatWorkloadTest)
{
- ClCreateFullyConnectedWorkloadTest<ClFullyConnectedFloatWorkload, armnn::DataType::Float32>();
+ ClCreateFullyConnectedWorkloadTest<ClFullyConnectedWorkload, armnn::DataType::Float32>();
}
BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloat16WorkloadTest)
{
- ClCreateFullyConnectedWorkloadTest<ClFullyConnectedFloatWorkload, armnn::DataType::Float16>();
+ ClCreateFullyConnectedWorkloadTest<ClFullyConnectedWorkload, armnn::DataType::Float16>();
}
template <typename NormalizationWorkloadType, typename armnn::DataType DataType>