diff options
Diffstat (limited to 'src/backends/reference/workloads/RefFullyConnectedWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefFullyConnectedWorkload.cpp | 57 |
1 files changed, 20 insertions, 37 deletions
diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp index c6ea147043..087fc9da68 100644 --- a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp +++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp @@ -12,43 +12,26 @@ namespace armnn { -RefFullyConnectedWorkload::RefFullyConnectedWorkload( - const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) - : RefBaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info) -{ -} -void RefFullyConnectedWorkload::PostAllocationConfigure() +unsigned int GetNumActivations(const TensorInfo& inputInfo) { - PostAllocationConfigure(m_Data.m_Inputs, m_Data.m_Outputs); -} - -void RefFullyConnectedWorkload::PostAllocationConfigure(std::vector<ITensorHandle*> inputs, - std::vector<ITensorHandle*> outputs) -{ - const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); - ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1); - m_InputShape = inputInfo.GetShape(); - - const TensorInfo& rWeightInfo = GetTensorInfo(inputs[1]); - ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1); - m_WeightShape = rWeightInfo.GetShape(); - m_WeightDecoder = MakeDecoder<float>(rWeightInfo); - - if (m_Data.m_Parameters.m_BiasEnabled) + unsigned int numActivations = 1; // Total number of activations in the input. + for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++) { - const TensorInfo& biasInfo = GetTensorInfo(inputs[2]); - m_BiasDecoder = MakeDecoder<float>(biasInfo); + numActivations *= inputInfo.GetShape()[i]; } + return numActivations; +} - const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); - m_OutputShape = outputInfo.GetShape(); - m_NumActivations = 1; // Total number of activations in the input. - for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++) - { - m_NumActivations *= inputInfo.GetShape()[i]; - } +RefFullyConnectedWorkload::RefFullyConnectedWorkload( + const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) + : RefBaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info) + , m_InputShape(info.m_InputTensorInfos[0].GetShape()) + , m_WeightShape(info.m_InputTensorInfos[1].GetShape()) + , m_OutputShape(info.m_OutputTensorInfos[0].GetShape()) + , m_NumActivations(GetNumActivations(info.m_InputTensorInfos[0])) +{ } void RefFullyConnectedWorkload::Execute() const @@ -58,8 +41,6 @@ void RefFullyConnectedWorkload::Execute() const void RefFullyConnectedWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) { - PostAllocationConfigure(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); - Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); } @@ -70,10 +51,12 @@ void RefFullyConnectedWorkload::Execute(std::vector<ITensorHandle*> inputs, std: std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), inputs[0]->Map()); std::unique_ptr<Encoder<float>> OutputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), outputs[0]->Map()); - m_WeightDecoder->Reset(inputs[1]->Map()); + std::unique_ptr<Decoder<float>> weightsDecoder = MakeDecoder<float>(GetTensorInfo(inputs[1]), inputs[1]->Map()); + std::unique_ptr<Decoder<float>> biasDecoder; + if (m_Data.m_Parameters.m_BiasEnabled) { - m_BiasDecoder->Reset(inputs[2]->Map()); + biasDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]), inputs[2]->Map()); } FullyConnected(m_InputShape, @@ -81,8 +64,8 @@ void RefFullyConnectedWorkload::Execute(std::vector<ITensorHandle*> inputs, std: m_OutputShape, *OutputEncoder, m_WeightShape, - *m_WeightDecoder, - m_BiasDecoder.get(), + *weightsDecoder, + biasDecoder.get(), m_Data.m_Parameters.m_BiasEnabled, m_NumActivations, m_Data.m_Parameters.m_TransposeWeightMatrix); |