about summary refs log tree commit diff
path: root/src/armnn/backends/ClWorkloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/ClWorkloads')
-rw-r--r-- src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp (renamed from src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp) | 32
-rw-r--r-- src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp (renamed from src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp) | 10
2 files changed, 28 insertions(+), 14 deletions(-)
diff --git a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
index b1dab7c7b9..5307fab062 100644
--- a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp
+++ b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
@@ -3,7 +3,7 @@
// SPDX-License-Identifier: MIT
//
-#include "ClFullyConnectedFloatWorkload.hpp"
+#include "ClFullyConnectedWorkload.hpp"
#include "backends/ClTensorHandle.hpp"
#include "backends/CpuTensorHandle.hpp"
#include "backends/ArmComputeTensorUtils.hpp"
@@ -42,9 +42,9 @@ arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
fullyConnectedLayerInfo);
}
-ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnectedQueueDescriptor& descriptor,
+ClFullyConnectedWorkload::ClFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor,
const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
- : FloatWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
+ : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
, m_FullyConnectedLayer(memoryManager)
{
m_WeightsTensor = std::make_unique<arm_compute::CLTensor>();
@@ -56,7 +56,7 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
BuildArmComputeTensor(*m_BiasesTensor, m_Data.m_Bias->GetTensorInfo());
}
- m_Data.ValidateInputsOutputs("ClFullyConnectedFloatWorkload", 1, 1);
+ m_Data.ValidateInputsOutputs("ClFullyConnectedWorkload", 1, 1);
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
@@ -67,11 +67,25 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
m_FullyConnectedLayer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
// Allocate
- InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
+ if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QuantisedAsymm8)
+ {
+ InitialiseArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
+ }
+ else
+ {
+ InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
+ }
if (m_BiasesTensor)
{
- InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
+ if (m_Data.m_Bias->GetTensorInfo().GetDataType() == DataType::Signed32)
+ {
+ InitialiseArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias->GetConstTensor<int32_t>());
+ }
+ else
+ {
+ InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
+ }
}
// Force Compute Library to perform the necessary copying and reshaping, after which
@@ -80,13 +94,13 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
FreeUnusedTensors();
}
-void ClFullyConnectedFloatWorkload::Execute() const
+void ClFullyConnectedWorkload::Execute() const
{
- ARMNN_SCOPED_PROFILING_EVENT_CL("ClFullyConnectedFloatWorkload_Execute");
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClFullyConnectedWorkload_Execute");
m_FullyConnectedLayer.run();
}
-void ClFullyConnectedFloatWorkload::FreeUnusedTensors()
+void ClFullyConnectedWorkload::FreeUnusedTensors()
{
FreeTensorIfUnused(m_WeightsTensor);
FreeTensorIfUnused(m_BiasesTensor);
diff --git a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp
index e8d6a7897d..7aa9b86e15 100644
--- a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp
+++ b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp
@@ -20,14 +20,14 @@ arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
const TensorInfo& biases,
const FullyConnectedDescriptor& descriptor);
-class ClFullyConnectedFloatWorkload : public armnn::FloatWorkload<armnn::FullyConnectedQueueDescriptor>
+class ClFullyConnectedWorkload : public armnn::BaseWorkload<armnn::FullyConnectedQueueDescriptor>
{
public:
- ClFullyConnectedFloatWorkload(const armnn::FullyConnectedQueueDescriptor& descriptor,
- const armnn::WorkloadInfo& info,
- std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
+ ClFullyConnectedWorkload(const armnn::FullyConnectedQueueDescriptor& descriptor,
+ const armnn::WorkloadInfo& info,
+ std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
- using armnn::FloatWorkload<armnn::FullyConnectedQueueDescriptor>::m_Data;
+ using armnn::BaseWorkload<armnn::FullyConnectedQueueDescriptor>::m_Data;
void Execute() const override;
private: