// // Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "RefFullyConnectedWorkload.hpp" #include "FullyConnected.hpp" #include "RefWorkloadUtils.hpp" #include "Profiling.hpp" namespace armnn { unsigned int GetNumActivations(const TensorInfo& inputInfo) { unsigned int numActivations = 1; // Total number of activations in the input. for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++) { numActivations *= inputInfo.GetShape()[i]; } return numActivations; } RefFullyConnectedWorkload::RefFullyConnectedWorkload( const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) : RefBaseWorkload(descriptor, info) , m_InputShape(info.m_InputTensorInfos[0].GetShape()) , m_WeightShape(info.m_InputTensorInfos[1].GetShape()) , m_OutputShape(info.m_OutputTensorInfos[0].GetShape()) , m_NumActivations(GetNumActivations(info.m_InputTensorInfos[0])) { } void RefFullyConnectedWorkload::Execute() const { Execute(m_Data.m_Inputs, m_Data.m_Outputs); } void RefFullyConnectedWorkload::ExecuteAsync(ExecutionData& executionData) { WorkingMemDescriptor* workingMemDescriptor = static_cast(executionData.m_Data); Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs); } void RefFullyConnectedWorkload::Execute(std::vector inputs, std::vector outputs) const { ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefFullyConnectedWorkload_Execute"); std::unique_ptr> inputDecoder = MakeDecoder(GetTensorInfo(inputs[0]), inputs[0]->Map()); std::unique_ptr> OutputEncoder = MakeEncoder(GetTensorInfo(outputs[0]), outputs[0]->Map()); std::unique_ptr> weightsDecoder = MakeDecoder(GetTensorInfo(inputs[1]), inputs[1]->Map()); std::unique_ptr> biasDecoder; if (m_Data.m_Parameters.m_BiasEnabled) { biasDecoder = MakeDecoder(GetTensorInfo(inputs[2]), inputs[2]->Map()); } FullyConnected(m_InputShape, *inputDecoder, m_OutputShape, *OutputEncoder, m_WeightShape, *weightsDecoder, biasDecoder.get(), m_Data.m_Parameters.m_BiasEnabled, m_NumActivations, m_Data.m_Parameters.m_TransposeWeightMatrix); } } //namespace armnn