aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancis Murtagh <francis.murtagh@arm.com>2019-05-27 12:14:10 +0100
committerFrancis Murtagh <francis.murtagh@arm.com>2019-05-27 12:14:10 +0100
commit43aec5886449c1b024b740fd6f4500e827bde221 (patch)
treec12a128dcc6895a0663a4e4dd27c4110e492c6dd
parent7f2c35a82ec11be50b3478bd15207320bbf3bd57 (diff)
downloadarmnn-43aec5886449c1b024b740fd6f4500e827bde221.tar.gz
IVGCVSW-3134 Refactor FullyConnected workloads into single workload
* Refactor FullyConnected workloads into single workload.
* Refactor FullyConnected ref implementation to use Encoders and Decoders to support all DataTypes.
* Deleted RefFullyConnectedFloat32Workload and RefFullyConnectedUint8Workload.

Change-Id: Iad30fb0287ab7491e1297997e7d61f1d00785541
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.hpp2
-rw-r--r--src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp6
-rw-r--r--src/backends/backendsCommon/test/WorkloadDataValidation.cpp2
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp2
-rw-r--r--src/backends/reference/backend.mk3
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp8
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp11
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt6
-rw-r--r--src/backends/reference/workloads/Decoders.hpp2
-rw-r--r--src/backends/reference/workloads/FullyConnected.cpp50
-rw-r--r--src/backends/reference/workloads/FullyConnected.hpp20
-rw-r--r--src/backends/reference/workloads/RefFullyConnectedFloat32Workload.cpp43
-rw-r--r--src/backends/reference/workloads/RefFullyConnectedFloat32Workload.hpp26
-rw-r--r--src/backends/reference/workloads/RefFullyConnectedUint8Workload.cpp66
-rw-r--r--src/backends/reference/workloads/RefFullyConnectedUint8Workload.hpp26
-rw-r--r--src/backends/reference/workloads/RefFullyConnectedWorkload.cpp65
-rw-r--r--src/backends/reference/workloads/RefFullyConnectedWorkload.hpp43
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp4
18 files changed, 176 insertions, 209 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index 927d7e78bb..0b0ba7ddf1 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -100,7 +100,7 @@ public:
virtual std::unique_ptr<IWorkload> CreateFloor(const FloorQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
- virtual std::unique_ptr<IWorkload> CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
+ virtual std::unique_ptr<IWorkload> CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
virtual std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp b/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp
index ac2595b6bf..3e6223ab79 100644
--- a/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp
@@ -49,6 +49,12 @@ LayerTestResult<T, 2> SimpleFullyConnectedTestImpl(
ExecuteWorkload(*workload, memoryManager);
+ if (workloadFactory.GetBackendId() == armnn::Compute::CpuRef)
+ {
+ workload->PostAllocationConfigure();
+ workload->Execute();
+ }
+
CopyDataFromITensorHandle(&result.output[0][0], outputHandle.get());
return result;
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 3793ecf70f..119eb7df90 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -125,7 +125,7 @@ BOOST_AUTO_TEST_CASE(FullyConnectedQueueDescriptor_Validate_RequiredDataMissing)
//Invalid argument exception is expected, because not all required fields have been provided.
//In particular inputsData[0], outputsData[0] and weightsData can not be null.
- BOOST_CHECK_THROW(RefFullyConnectedFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefFullyConnectedWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 2266fcd2f2..6abcf9cd08 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -141,7 +141,7 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMerger(const MergerQ
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateFullyConnected(
const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<RefFullyConnectedFloat32Workload, RefFullyConnectedUint8Workload>(descriptor, info);
+ return std::make_unique<RefFullyConnectedWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index b7a35ab9cb..50cfbf68cc 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -42,8 +42,7 @@ BACKEND_SOURCES := \
workloads/RefElementwiseWorkload.cpp \
workloads/RefFakeQuantizationFloat32Workload.cpp \
workloads/RefFloorFloat32Workload.cpp \
- workloads/RefFullyConnectedFloat32Workload.cpp \
- workloads/RefFullyConnectedUint8Workload.cpp \
+ workloads/RefFullyConnectedWorkload.cpp \
workloads/RefGatherWorkload.cpp \
workloads/RefL2NormalizationFloat32Workload.cpp \
workloads/RefLstmWorkload.cpp \
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 7226bd0fe7..48b85cb9de 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -317,14 +317,14 @@ static void RefCreateFullyConnectedWorkloadTest()
TensorInfo({ 3, 7 }, DataType, outputQScale));
}
-BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloat32Workload)
+BOOST_AUTO_TEST_CASE(CreateFullyConnectedWorkloadFloat32)
{
- RefCreateFullyConnectedWorkloadTest<RefFullyConnectedFloat32Workload, armnn::DataType::Float32>();
+ RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::Float32>();
}
-BOOST_AUTO_TEST_CASE(CreateFullyConnectedUint8Workload)
+BOOST_AUTO_TEST_CASE(CreateFullyConnectedWorkloadQuantisedAsymm8)
{
- RefCreateFullyConnectedWorkloadTest<RefFullyConnectedUint8Workload, armnn::DataType::QuantisedAsymm8>();
+ RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::QuantisedAsymm8>();
}
template <typename NormalizationWorkloadType, armnn::DataType DataType>
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index ab6de2b37f..26b0179e71 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -23,6 +23,8 @@ public:
virtual BaseIterator& operator+=(const unsigned int increment) = 0;
virtual BaseIterator& operator-=(const unsigned int increment) = 0;
+
+ virtual BaseIterator& operator[](const unsigned int index) = 0;
};
template<typename IType>
@@ -54,7 +56,7 @@ class TypedIterator : public Base
{
public:
TypedIterator(T* data)
- : m_Iterator(data)
+ : m_Iterator(data), m_Start(data)
{}
TypedIterator& operator++() override
@@ -75,8 +77,15 @@ public:
return *this;
}
+ TypedIterator& operator[](const unsigned int index) override
+ {
+ m_Iterator = m_Start + index;
+ return *this;
+ }
+
protected:
T* m_Iterator;
+ T* m_Start;
};
class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder<float>>
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index ef5e46a3cc..7f26d78c7e 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -69,10 +69,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefFakeQuantizationFloat32Workload.hpp
RefFloorFloat32Workload.cpp
RefFloorFloat32Workload.hpp
- RefFullyConnectedFloat32Workload.cpp
- RefFullyConnectedFloat32Workload.hpp
- RefFullyConnectedUint8Workload.cpp
- RefFullyConnectedUint8Workload.hpp
+ RefFullyConnectedWorkload.cpp
+ RefFullyConnectedWorkload.hpp
RefGatherWorkload.cpp
RefGatherWorkload.hpp
RefL2NormalizationFloat32Workload.cpp
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
index 57c19a2a58..f5ec90662a 100644
--- a/src/backends/reference/workloads/Decoders.hpp
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -7,6 +7,8 @@
#include "BaseIterator.hpp"
+#include <boost/assert.hpp>
+
namespace armnn
{
diff --git a/src/backends/reference/workloads/FullyConnected.cpp b/src/backends/reference/workloads/FullyConnected.cpp
index bf5814d2ad..02d9b060ef 100644
--- a/src/backends/reference/workloads/FullyConnected.cpp
+++ b/src/backends/reference/workloads/FullyConnected.cpp
@@ -5,32 +5,29 @@
#include "FullyConnected.hpp"
+#include "RefWorkloadUtils.hpp"
+
#include <boost/assert.hpp>
namespace armnn
{
-void FullyConnected(const float* inputData,
- float* outputData,
- const TensorInfo& inputTensorInfo,
- const TensorInfo& outputTensorInfo,
- const float* weightData,
- const float* biasData,
- bool transposeWeights)
+void FullyConnected(const TensorShape& rInputShape,
+ Decoder<float>& rInputDecoder,
+ const TensorShape& rOutputShape,
+ Encoder<float>& rOutputEncoder,
+ Decoder<float>& rWeightDecoder,
+ Decoder<float>& rBiasDecoder,
+ const bool biasEnabled,
+ const unsigned int K,
+ const bool transposeWeights)
{
- unsigned int N = outputTensorInfo.GetShape()[1]; // Outputs Vector Size.
-
- BOOST_ASSERT(inputTensorInfo.GetNumDimensions() > 1); // Needs some data.
-
- unsigned int K = 1; // Total number of activations in the input.
- for (unsigned int i = 1; i < inputTensorInfo.GetNumDimensions(); i++)
- {
- K *= inputTensorInfo.GetShape()[i];
- }
+ // Perform FullyConnected implementation
+ unsigned int outputSize = rOutputShape[1];
- for (unsigned int n = 0; n < inputTensorInfo.GetShape()[0]; n++)
+ for (unsigned int n = 0; n < rInputShape[0]; n++)
{
- for (unsigned int channelOutput = 0; channelOutput < N; channelOutput++)
+ for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
{
float outval = 0.f;
@@ -39,22 +36,27 @@ void FullyConnected(const float* inputData,
float weight;
if (transposeWeights)
{
- weight = weightData[channelOutput * K + channelInput];
+ rWeightDecoder[channelOutput * K + channelInput];
+ weight = rWeightDecoder.Get();
}
else
{
- weight = weightData[channelInput * N + channelOutput];
+ rWeightDecoder[channelInput * outputSize + channelOutput];
+ weight = rWeightDecoder.Get();
}
- outval += weight * inputData[n * K + channelInput];
+ rInputDecoder[n * K + channelInput];
+ outval += weight * rInputDecoder.Get();
}
- if (biasData)
+ if (biasEnabled)
{
- outval += biasData[channelOutput];
+ rBiasDecoder[channelOutput];
+ outval += rBiasDecoder.Get();
}
- outputData[n * N + channelOutput] = outval;
+ rOutputEncoder[n * outputSize + channelOutput];
+ rOutputEncoder.Set(outval);
}
}
}
diff --git a/src/backends/reference/workloads/FullyConnected.hpp b/src/backends/reference/workloads/FullyConnected.hpp
index 623259f8f8..78fa055086 100644
--- a/src/backends/reference/workloads/FullyConnected.hpp
+++ b/src/backends/reference/workloads/FullyConnected.hpp
@@ -5,18 +5,24 @@
#pragma once
+#include "BaseIterator.hpp"
+#include "Decoders.hpp"
+#include "Encoders.hpp"
#include <armnn/Tensor.hpp>
+#include <backendsCommon/WorkloadData.hpp>
namespace armnn
{
/// Performs a matrix multiplication and optionally adds a bias.
-void FullyConnected(const float* inputData,
- float* outputData,
- const TensorInfo& inputTensorInfo,
- const TensorInfo& outputTensorInfo,
- const float* weightData,
- const float* biasData,
- bool transposeWeights);
+void FullyConnected(const TensorShape& rInputShape,
+ Decoder<float>& rInputDecoder,
+ const TensorShape& rOutputShape,
+ Encoder<float>& rOutputEncoder,
+ Decoder<float>& rWeightDecoder,
+ Decoder<float>& rBiasDecoder,
+ bool biasEnabled,
+ unsigned int K,
+ bool transposeWeights);
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefFullyConnectedFloat32Workload.cpp b/src/backends/reference/workloads/RefFullyConnectedFloat32Workload.cpp
deleted file mode 100644
index ccaf4cd87b..0000000000
--- a/src/backends/reference/workloads/RefFullyConnectedFloat32Workload.cpp
+++ /dev/null
@@ -1,43 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefFullyConnectedFloat32Workload.hpp"
-
-#include "FullyConnected.hpp"
-#include "RefWorkloadUtils.hpp"
-
-#include "Profiling.hpp"
-
-namespace armnn
-{
-RefFullyConnectedFloat32Workload::RefFullyConnectedFloat32Workload(
- const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
- : Float32Workload<FullyConnectedQueueDescriptor>(descriptor, info),
- m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight))),
- m_Bias(descriptor.m_Parameters.m_BiasEnabled
- ? std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias)) : nullptr) {}
-
-void RefFullyConnectedFloat32Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedFloat32Workload_Execute");
-
- const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
- const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
- float* outputData = GetOutputTensorDataFloat(0, m_Data);
- const float* inputData = GetInputTensorDataFloat(0, m_Data);
- const float* weightData = m_Weight->GetConstTensor<float>();
- const float* biasData = m_Data.m_Parameters.m_BiasEnabled ? m_Bias->GetConstTensor<float>() : nullptr;
-
- FullyConnected(inputData,
- outputData,
- inputInfo,
- outputInfo,
- weightData,
- biasData,
- m_Data.m_Parameters.m_TransposeWeightMatrix);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefFullyConnectedFloat32Workload.hpp b/src/backends/reference/workloads/RefFullyConnectedFloat32Workload.hpp
deleted file mode 100644
index 6a05024ca3..0000000000
--- a/src/backends/reference/workloads/RefFullyConnectedFloat32Workload.hpp
+++ /dev/null
@@ -1,26 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefFullyConnectedFloat32Workload : public Float32Workload<FullyConnectedQueueDescriptor>
-{
-public:
- explicit RefFullyConnectedFloat32Workload(const FullyConnectedQueueDescriptor& descriptor,
- const WorkloadInfo& info);
- virtual void Execute() const override;
-
-private:
- std::unique_ptr<ScopedCpuTensorHandle> m_Weight;
- std::unique_ptr<ScopedCpuTensorHandle> m_Bias;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefFullyConnectedUint8Workload.cpp b/src/backends/reference/workloads/RefFullyConnectedUint8Workload.cpp
deleted file mode 100644
index cd785d786c..0000000000
--- a/src/backends/reference/workloads/RefFullyConnectedUint8Workload.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefFullyConnectedUint8Workload.hpp"
-
-#include "FullyConnected.hpp"
-#include "RefWorkloadUtils.hpp"
-
-#include "Profiling.hpp"
-
-#include <vector>
-
-namespace armnn
-{
-RefFullyConnectedUint8Workload::RefFullyConnectedUint8Workload(
- const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
- : Uint8Workload<FullyConnectedQueueDescriptor>(descriptor, info),
- m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight))),
- m_Bias(descriptor.m_Parameters.m_BiasEnabled
- ? std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias)) : nullptr) {}
-
-void RefFullyConnectedUint8Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedUint8Workload_Execute");
-
- const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
- const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
- const uint8_t* weightData = m_Weight->GetConstTensor<uint8_t>();
-
- auto dequant = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo);
-
- auto weight = Dequantize(weightData, m_Weight->GetTensorInfo());
-
- std::vector<float> results(outputInfo.GetNumElements());
-
- if (m_Data.m_Parameters.m_BiasEnabled)
- {
- const int32_t* biasData = m_Bias->GetConstTensor<int32_t>();
- auto bias = Dequantize(biasData, m_Bias->GetTensorInfo());
-
- FullyConnected(dequant.data(),
- results.data(),
- inputInfo,
- outputInfo,
- weight.data(),
- bias.data(),
- m_Data.m_Parameters.m_TransposeWeightMatrix);
- }
- else
- {
- FullyConnected(dequant.data(),
- results.data(),
- inputInfo,
- outputInfo,
- weight.data(),
- nullptr,
- m_Data.m_Parameters.m_TransposeWeightMatrix);
- }
-
- Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefFullyConnectedUint8Workload.hpp b/src/backends/reference/workloads/RefFullyConnectedUint8Workload.hpp
deleted file mode 100644
index 679ad8626a..0000000000
--- a/src/backends/reference/workloads/RefFullyConnectedUint8Workload.hpp
+++ /dev/null
@@ -1,26 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefFullyConnectedUint8Workload : public Uint8Workload<FullyConnectedQueueDescriptor>
-{
-public:
- explicit RefFullyConnectedUint8Workload(const FullyConnectedQueueDescriptor& descriptor,
- const WorkloadInfo& info);
- virtual void Execute() const override;
-
-private:
- std::unique_ptr<ScopedCpuTensorHandle> m_Weight;
- std::unique_ptr<ScopedCpuTensorHandle> m_Bias;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
new file mode 100644
index 0000000000..dc7030ef81
--- /dev/null
+++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
@@ -0,0 +1,65 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefFullyConnectedWorkload.hpp"
+
+#include "FullyConnected.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include "Profiling.hpp"
+
+namespace armnn
+{
+RefFullyConnectedWorkload::RefFullyConnectedWorkload(
+ const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
+ : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info),
+ m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight)))
+{
+ const TensorInfo& rWeightInfo = GetTensorInfo(m_Weight.get());
+ m_WeightShape = rWeightInfo.GetShape();
+ m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true));
+
+ if (descriptor.m_Parameters.m_BiasEnabled)
+ {
+ m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias));
+ const TensorInfo& biasInfo = GetTensorInfo(m_Bias.get());
+ m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true));
+ }
+}
+
+void RefFullyConnectedWorkload::PostAllocationConfigure()
+{
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ BOOST_ASSERT(inputInfo.GetNumDimensions() > 1);
+ m_InputShape = inputInfo.GetShape();
+ m_InputDecoder = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
+
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+ m_OutputShape = outputInfo.GetShape();
+ m_OutputEncoder = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+ m_NumActivations = 1; // Total number of activations in the input.
+ for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++)
+ {
+ m_NumActivations *= inputInfo.GetShape()[i];
+ }
+}
+
+void RefFullyConnectedWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute");
+
+ FullyConnected(m_InputShape,
+ *m_InputDecoder,
+ m_OutputShape,
+ *m_OutputEncoder,
+ *m_WeightDecoder,
+ *m_BiasDecoder,
+ m_Data.m_Parameters.m_BiasEnabled,
+ m_NumActivations,
+ m_Data.m_Parameters.m_TransposeWeightMatrix);
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp
new file mode 100644
index 0000000000..d4a63d23ae
--- /dev/null
+++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp
@@ -0,0 +1,43 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+#include "BaseIterator.hpp"
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+
+
+namespace armnn
+{
+
+class RefFullyConnectedWorkload : public BaseWorkload<FullyConnectedQueueDescriptor>
+{
+public:
+ explicit RefFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor,
+ const WorkloadInfo& info);
+
+ void PostAllocationConfigure() override;
+
+ virtual void Execute() const override;
+
+private:
+ std::unique_ptr<ScopedCpuTensorHandle> m_Weight;
+ std::unique_ptr<ScopedCpuTensorHandle> m_Bias;
+
+ std::unique_ptr<Decoder<float>> m_InputDecoder;
+ std::unique_ptr<Encoder<float>> m_OutputEncoder;
+ std::unique_ptr<Decoder<float>> m_WeightDecoder;
+ std::unique_ptr<Decoder<float>> m_BiasDecoder;
+
+ TensorShape m_InputShape;
+ TensorShape m_OutputShape;
+ TensorShape m_WeightShape;
+ unsigned int m_NumActivations;
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 291f991f1e..54bc5c7f01 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -17,13 +17,12 @@
#include "RefPooling2dFloat32Workload.hpp"
#include "RefWorkloadUtils.hpp"
#include "RefConcatWorkload.hpp"
-#include "RefFullyConnectedFloat32Workload.hpp"
+#include "RefFullyConnectedWorkload.hpp"
#include "RefGatherWorkload.hpp"
#include "Softmax.hpp"
#include "TensorBufferArrayView.hpp"
#include "RefBatchNormalizationFloat32Workload.hpp"
#include "Splitter.hpp"
-#include "RefFullyConnectedUint8Workload.hpp"
#include "RefReshapeFloat32Workload.hpp"
#include "RefDepthwiseConvolution2dWorkload.hpp"
#include "FullyConnected.hpp"
@@ -59,5 +58,4 @@
#include "RefDebugWorkload.hpp"
#include "RefRsqrtFloat32Workload.hpp"
#include "RefDequantizeWorkload.hpp"
-
#include "RefQuantizeWorkload.hpp"