aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2021-03-25 07:46:55 +0000
committerSadik Armagan <sadik.armagan@arm.com>2021-03-25 07:46:55 +0000
commitf0a6dec75832604d5ab18242dc216852821a8279 (patch)
treeff25e64c62c63975a54abd16a8bff744be70d7c0 /src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
parent16fb1a2d9c1d3d80c0f0b6ab549919fbabd2a0b9 (diff)
downloadarmnn-f0a6dec75832604d5ab18242dc216852821a8279.tar.gz
IVGCVSW-5736 and IVGCVSW-5743 'NonConstWeights: Update front-end and TfLiteDelegate support for FullyConnected Operator'
* Added front-end support for non-const weights for FULLY_CONNECTED operator * Added FULLY_CONNECTED end-to-end test * Updated FULLY_CONNECTED operator support in TfLite Arm NN Delegate for non-const weights * Updated the version numbers Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: Iffa5b9aa9297aca4c02d923cce4636c88ac21faa
Diffstat (limited to 'src/backends/reference/workloads/RefFullyConnectedWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefFullyConnectedWorkload.cpp45
1 files changed, 35 insertions, 10 deletions
diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
index 9acca219b5..49e105f206 100644
--- a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
+++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
@@ -14,18 +14,21 @@ namespace armnn
{
RefFullyConnectedWorkload::RefFullyConnectedWorkload(
const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info),
- m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight)))
+ : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
{
- const TensorInfo& rWeightInfo = m_Weight->GetTensorInfo();
- m_WeightShape = rWeightInfo.GetShape();
- m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true));
-
- if (descriptor.m_Parameters.m_BiasEnabled)
+ if (descriptor.m_Parameters.m_ConstantWeights)
{
- m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias));
- const TensorInfo& biasInfo = m_Bias->GetTensorInfo();
- m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true));
+ m_Weight = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight));
+ const TensorInfo& rWeightInfo = m_Weight->GetTensorInfo();
+ m_WeightShape = rWeightInfo.GetShape();
+ m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true));
+
+ if (descriptor.m_Parameters.m_BiasEnabled)
+ {
+ m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias));
+ const TensorInfo& biasInfo = m_Bias->GetTensorInfo();
+ m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true));
+ }
}
}
@@ -36,6 +39,20 @@ void RefFullyConnectedWorkload::PostAllocationConfigure()
m_InputShape = inputInfo.GetShape();
m_InputDecoder = MakeDecoder<float>(inputInfo);
+ if (!m_Data.m_Parameters.m_ConstantWeights)
+ {
+ const TensorInfo& rWeightInfo = GetTensorInfo(m_Data.m_Inputs[1]);
+ ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
+ m_WeightShape = rWeightInfo.GetShape();
+ m_WeightDecoder = MakeDecoder<float>(rWeightInfo);
+
+ if (m_Data.m_Parameters.m_BiasEnabled)
+ {
+ const TensorInfo& biasInfo = GetTensorInfo(m_Data.m_Inputs[2]);
+ m_BiasDecoder = MakeDecoder<float>(biasInfo);
+ }
+ }
+
const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
m_OutputShape = outputInfo.GetShape();
m_OutputEncoder = MakeEncoder<float>(outputInfo);
@@ -52,6 +69,14 @@ void RefFullyConnectedWorkload::Execute() const
ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute");
m_InputDecoder->Reset(m_Data.m_Inputs[0]->Map());
+ if (!m_Data.m_Parameters.m_ConstantWeights)
+ {
+ m_WeightDecoder->Reset(m_Data.m_Inputs[1]->Map());
+ if (m_Data.m_Parameters.m_BiasEnabled)
+ {
+ m_BiasDecoder->Reset(m_Data.m_Inputs[2]->Map());
+ }
+ }
m_OutputEncoder->Reset(m_Data.m_Outputs[0]->Map());
FullyConnected(m_InputShape,