Diffstat (limited to 'src/backends/reference')
 -rw-r--r--  src/backends/reference/RefBackend.cpp                              10
 -rw-r--r--  src/backends/reference/RefBackend.hpp                               6
 -rw-r--r--  src/backends/reference/test/RefEndToEndTests.cpp                    6
 -rw-r--r--  src/backends/reference/test/RefLayerTests.cpp                      14
 -rw-r--r--  src/backends/reference/workloads/RefFullyConnectedWorkload.cpp     45
5 files changed, 67 insertions, 14 deletions
diff --git a/src/backends/reference/RefBackend.cpp b/src/backends/reference/RefBackend.cpp
index e93b317dce..53c55ab26a 100644
--- a/src/backends/reference/RefBackend.cpp
+++ b/src/backends/reference/RefBackend.cpp
@@ -69,6 +69,16 @@ IBackendInternal::ILayerSupportSharedPtr RefBackend::GetLayerSupport() const
return layerSupport;
}
+bool RefBackend::HasCapability(BackendCapability capabilityClass) const
+{
+ auto search = backendCapabilities.find(capabilityClass);
+ if (search != backendCapabilities.end())
+ {
+ return true;
+ }
+ return false;
+}
+
OptimizationViews RefBackend::OptimizeSubgraphView(const SubgraphView& subgraph) const
{
OptimizationViews optimizationViews;
diff --git a/src/backends/reference/RefBackend.hpp b/src/backends/reference/RefBackend.hpp
index 92d392dde6..c92936ca0c 100644
--- a/src/backends/reference/RefBackend.hpp
+++ b/src/backends/reference/RefBackend.hpp
@@ -9,6 +9,10 @@
namespace armnn
{
+const std::set<armnn::BackendCapability> backendCapabilities {
+ armnn::BackendCapability::NonConstWeights,
+};
+
class RefBackend : public IBackendInternal
{
public:
@@ -39,6 +43,8 @@ public:
std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override;
void RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry) override;
+
+ bool HasCapability(BackendCapability capabilityClass) const override;
};
} // namespace armnn
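
The two RefBackend changes above register NonConstWeights in a file-scope capability set and report it through the new HasCapability() override. A minimal sketch of how a caller could query that capability (not part of the patch; the registry and factory calls are the public Arm NN API, but the exact include paths are an assumption):

    #include <armnn/BackendRegistry.hpp>            // assumed include paths
    #include <armnn/backends/IBackendInternal.hpp>

    // Returns true when the CpuRef backend advertises support for
    // non-constant (runtime-supplied) weights.
    bool RefSupportsNonConstWeights()
    {
        auto factory = armnn::BackendRegistryInstance().GetFactory("CpuRef");
        armnn::IBackendInternalUniquePtr backend = factory();
        return backend->HasCapability(armnn::BackendCapability::NonConstWeights);
    }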
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index 4598568070..b6974811ef 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -15,6 +15,7 @@
#include <backendsCommon/test/DetectionPostProcessEndToEndTestImpl.hpp>
#include <backendsCommon/test/ElementwiseUnaryEndToEndTestImpl.hpp>
#include <backendsCommon/test/FillEndToEndTestImpl.hpp>
+#include <backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp>
#include <backendsCommon/test/GatherEndToEndTestImpl.hpp>
#include <backendsCommon/test/InstanceNormalizationEndToEndTestImpl.hpp>
#include <backendsCommon/test/LogSoftmaxEndToEndTestImpl.hpp>
@@ -599,6 +600,11 @@ BOOST_AUTO_TEST_CASE(RefFillEndToEndTestInt32)
FillEndToEnd<armnn::DataType::Signed32>(defaultBackends);
}
+BOOST_AUTO_TEST_CASE(RefFullyConnectedEndToEndTestFloat32)
+{
+    FullyConnectedWithDynamicWeightsEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
BOOST_AUTO_TEST_CASE(RefGatherFloatTest)
{
GatherEndToEnd<armnn::DataType::Float32>(defaultBackends);
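
The new end-to-end case drives the weights-as-inputs path through the public graph API. A minimal sketch of the kind of network such a test builds (not taken from FullyConnectedEndToEndTestImpl.hpp; the AddFullyConnectedLayer overload with empty Optional weights and the descriptor field usage are assumptions that may differ between Arm NN releases):

    #include <armnn/ArmNN.hpp>

    // Builds a network whose FullyConnected weights arrive through a second
    // input layer instead of a constant tensor (hypothetical helper).
    armnn::INetworkPtr BuildFcWithRuntimeWeights(const armnn::TensorInfo& inputInfo,
                                                 const armnn::TensorInfo& weightsInfo,
                                                 const armnn::TensorInfo& outputInfo)
    {
        armnn::INetworkPtr network = armnn::INetwork::Create();

        armnn::FullyConnectedDescriptor fcDesc;
        fcDesc.m_BiasEnabled     = false;
        fcDesc.m_ConstantWeights = false;   // weights are a network input, not a constant

        armnn::IConnectableLayer* input   = network->AddInputLayer(0);
        armnn::IConnectableLayer* weights = network->AddInputLayer(1);
        armnn::IConnectableLayer* fc      = network->AddFullyConnectedLayer(fcDesc,
                                                                            armnn::EmptyOptional(),
                                                                            armnn::EmptyOptional(),
                                                                            "fully_connected");
        armnn::IConnectableLayer* output  = network->AddOutputLayer(0);

        input->GetOutputSlot(0).Connect(fc->GetInputSlot(0));
        weights->GetOutputSlot(0).Connect(fc->GetInputSlot(1));
        fc->GetOutputSlot(0).Connect(output->GetInputSlot(0));

        input->GetOutputSlot(0).SetTensorInfo(inputInfo);
        weights->GetOutputSlot(0).SetTensorInfo(weightsInfo);
        fc->GetOutputSlot(0).SetTensorInfo(outputInfo);

        return network;
    }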
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 161476ed98..7371692d0e 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -579,16 +579,22 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(HardSwishInt16, HardSwishInt16Test)
// Fully Connected
ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleFullyConnected, FullyConnectedFloat32Test, false, false)
-ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedUint8, FullyConnectedTest<DataType::QAsymmU8>, false)
-ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedQSymm16, FullyConnectedTest<DataType::QSymmS16>, false)
+ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedUint8, FullyConnectedTest<DataType::QAsymmU8>, false, true)
+ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedQSymm16, FullyConnectedTest<DataType::QSymmS16>, false, true)
ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleFullyConnectedWithBias, FullyConnectedFloat32Test, true, false)
-ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedBiasedUint8, FullyConnectedTest<DataType::QAsymmU8>, true)
-ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedBiasedQSymm16, FullyConnectedTest<DataType::QSymmS16>, true)
+ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedBiasedUint8, FullyConnectedTest<DataType::QAsymmU8>, true, true)
+ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedBiasedQSymm16, FullyConnectedTest<DataType::QSymmS16>, true, true)
ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleFullyConnectedWithTranspose, FullyConnectedFloat32Test, false, true)
ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedLarge, FullyConnectedLargeTest, false)
ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedLargeTransposed, FullyConnectedLargeTest, true)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedWeightsAsInputsUint8,
+ FullyConnectedTest<DataType::QAsymmU8>,
+ false,
+ false)
+
// Splitter
ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleSplitterFloat32, SplitterFloat32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleSplitterFloat16, SplitterFloat16Test)
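
In the RefLayerTests hunk above, the extra trailing boolean passed to FullyConnectedTest is presumably the new constantWeights flag: the pre-existing cases keep constant weights (true), while the added FullyConnectedWeightsAsInputsUint8 case exercises the weights-as-inputs path (false). A hedged sketch of the test-helper signature these calls assume (parameter names are illustrative; the first three arguments are injected by the ARMNN_AUTO_TEST_CASE_WITH_THF macro):

    template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
    LayerTestResult<T, 2> FullyConnectedTest(
        armnn::IWorkloadFactory& workloadFactory,
        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
        const armnn::ITensorHandleFactory& tensorHandleFactory,
        bool biasEnabled,        // adds a bias tensor to the layer
        bool constantWeights);   // false => weights supplied as a workload input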
diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
index 9acca219b5..49e105f206 100644
--- a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
+++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
@@ -14,18 +14,21 @@ namespace armnn
{
RefFullyConnectedWorkload::RefFullyConnectedWorkload(
const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info),
- m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight)))
+ : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
{
- const TensorInfo& rWeightInfo = m_Weight->GetTensorInfo();
- m_WeightShape = rWeightInfo.GetShape();
- m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true));
-
- if (descriptor.m_Parameters.m_BiasEnabled)
+ if (descriptor.m_Parameters.m_ConstantWeights)
{
- m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias));
- const TensorInfo& biasInfo = m_Bias->GetTensorInfo();
- m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true));
+ m_Weight = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight));
+ const TensorInfo& rWeightInfo = m_Weight->GetTensorInfo();
+ m_WeightShape = rWeightInfo.GetShape();
+ m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true));
+
+ if (descriptor.m_Parameters.m_BiasEnabled)
+ {
+ m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias));
+ const TensorInfo& biasInfo = m_Bias->GetTensorInfo();
+ m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true));
+ }
}
}
@@ -36,6 +39,20 @@ void RefFullyConnectedWorkload::PostAllocationConfigure()
m_InputShape = inputInfo.GetShape();
m_InputDecoder = MakeDecoder<float>(inputInfo);
+ if (!m_Data.m_Parameters.m_ConstantWeights)
+ {
+ const TensorInfo& rWeightInfo = GetTensorInfo(m_Data.m_Inputs[1]);
+ ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
+ m_WeightShape = rWeightInfo.GetShape();
+ m_WeightDecoder = MakeDecoder<float>(rWeightInfo);
+
+ if (m_Data.m_Parameters.m_BiasEnabled)
+ {
+ const TensorInfo& biasInfo = GetTensorInfo(m_Data.m_Inputs[2]);
+ m_BiasDecoder = MakeDecoder<float>(biasInfo);
+ }
+ }
+
const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
m_OutputShape = outputInfo.GetShape();
m_OutputEncoder = MakeEncoder<float>(outputInfo);
@@ -52,6 +69,14 @@ void RefFullyConnectedWorkload::Execute() const
ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute");
m_InputDecoder->Reset(m_Data.m_Inputs[0]->Map());
+ if (!m_Data.m_Parameters.m_ConstantWeights)
+ {
+ m_WeightDecoder->Reset(m_Data.m_Inputs[1]->Map());
+ if (m_Data.m_Parameters.m_BiasEnabled)
+ {
+ m_BiasDecoder->Reset(m_Data.m_Inputs[2]->Map());
+ }
+ }
m_OutputEncoder->Reset(m_Data.m_Outputs[0]->Map());
FullyConnected(m_InputShape,