aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/RefFullyConnectedWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefFullyConnectedWorkload.cpp57
1 files changed, 20 insertions, 37 deletions
diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
index c6ea147043..087fc9da68 100644
--- a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
+++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
@@ -12,43 +12,26 @@
namespace armnn
{
-RefFullyConnectedWorkload::RefFullyConnectedWorkload(
- const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
- : RefBaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
-{
-}
-void RefFullyConnectedWorkload::PostAllocationConfigure()
+unsigned int GetNumActivations(const TensorInfo& inputInfo)
{
- PostAllocationConfigure(m_Data.m_Inputs, m_Data.m_Outputs);
-}
-
-void RefFullyConnectedWorkload::PostAllocationConfigure(std::vector<ITensorHandle*> inputs,
- std::vector<ITensorHandle*> outputs)
-{
- const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
- ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
- m_InputShape = inputInfo.GetShape();
-
- const TensorInfo& rWeightInfo = GetTensorInfo(inputs[1]);
- ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
- m_WeightShape = rWeightInfo.GetShape();
- m_WeightDecoder = MakeDecoder<float>(rWeightInfo);
-
- if (m_Data.m_Parameters.m_BiasEnabled)
+ unsigned int numActivations = 1; // Total number of activations in the input.
+ for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++)
{
- const TensorInfo& biasInfo = GetTensorInfo(inputs[2]);
- m_BiasDecoder = MakeDecoder<float>(biasInfo);
+ numActivations *= inputInfo.GetShape()[i];
}
+ return numActivations;
+}
- const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
- m_OutputShape = outputInfo.GetShape();
- m_NumActivations = 1; // Total number of activations in the input.
- for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++)
- {
- m_NumActivations *= inputInfo.GetShape()[i];
- }
+RefFullyConnectedWorkload::RefFullyConnectedWorkload(
+ const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
+ : RefBaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
+ , m_InputShape(info.m_InputTensorInfos[0].GetShape())
+ , m_WeightShape(info.m_InputTensorInfos[1].GetShape())
+ , m_OutputShape(info.m_OutputTensorInfos[0].GetShape())
+ , m_NumActivations(GetNumActivations(info.m_InputTensorInfos[0]))
+{
}
void RefFullyConnectedWorkload::Execute() const
@@ -58,8 +41,6 @@ void RefFullyConnectedWorkload::Execute() const
void RefFullyConnectedWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor)
{
- PostAllocationConfigure(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
-
Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
}
@@ -70,10 +51,12 @@ void RefFullyConnectedWorkload::Execute(std::vector<ITensorHandle*> inputs, std:
std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), inputs[0]->Map());
std::unique_ptr<Encoder<float>> OutputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), outputs[0]->Map());
- m_WeightDecoder->Reset(inputs[1]->Map());
+ std::unique_ptr<Decoder<float>> weightsDecoder = MakeDecoder<float>(GetTensorInfo(inputs[1]), inputs[1]->Map());
+ std::unique_ptr<Decoder<float>> biasDecoder;
+
if (m_Data.m_Parameters.m_BiasEnabled)
{
- m_BiasDecoder->Reset(inputs[2]->Map());
+ biasDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]), inputs[2]->Map());
}
FullyConnected(m_InputShape,
@@ -81,8 +64,8 @@ void RefFullyConnectedWorkload::Execute(std::vector<ITensorHandle*> inputs, std:
m_OutputShape,
*OutputEncoder,
m_WeightShape,
- *m_WeightDecoder,
- m_BiasDecoder.get(),
+ *weightsDecoder,
+ biasDecoder.get(),
m_Data.m_Parameters.m_BiasEnabled,
m_NumActivations,
m_Data.m_Parameters.m_TransposeWeightMatrix);