// // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "TosaRefPreCompiledWorkload.hpp" namespace armnn { TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor, const WorkloadInfo& info) : BaseWorkload(descriptor, info) , m_workloadInfo(info) { // Check that the workload is holding a pointer to a valid pre-compiled object if (m_Data.m_PreCompiledObject == nullptr) { throw InvalidArgumentException( "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler)."); } } void TosaRefPreCompiledWorkload::Execute() const { tosa::TosaSerializationHandler* handler = static_cast(m_Data.m_PreCompiledObject); std::vector inputNames = handler->GetInputs(); std::vector outputNames = handler->GetOutputs(); TosaReference::IModelRunner runner; GraphStatus status; // Initialise the model runner with the TosaSerializationHandler status = runner.initialize(*handler); if(status != GraphStatus::TOSA_VALID) { throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model."); } // Set the inputs for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx) { DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType(); switch (dataType) { case DataType::Float16: SetInput(runner, inputNames[inputSlotIdx], inputSlotIdx); break; case DataType::Float32: SetInput(runner, inputNames[inputSlotIdx], inputSlotIdx); break; case DataType::QAsymmU8: case DataType::QAsymmS8: case DataType::QSymmS8: case DataType::QSymmS16: case DataType::Signed32: SetInput(runner, inputNames[inputSlotIdx], inputSlotIdx); break; case DataType::Signed64: SetInput(runner, inputNames[inputSlotIdx], inputSlotIdx); break; case DataType::Boolean: SetInput(runner, inputNames[inputSlotIdx], inputSlotIdx); break; default: throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend."); } } // Run the TOSA Reference Model status = runner.run(); if(status != GraphStatus::TOSA_VALID) { throw armnn::Exception("An error has occurred while running the TOSA Reference Model."); } // Gets the outputs for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx) { DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType(); switch (dataType) { case DataType::Float16: GetOutput(runner, outputNames[outputSlotIdx], outputSlotIdx); break; case DataType::Float32: GetOutput(runner, outputNames[outputSlotIdx], outputSlotIdx); break; case DataType::QAsymmU8: case DataType::QAsymmS8: case DataType::QSymmS8: case DataType::QSymmS16: case DataType::Signed32: GetOutput(runner, outputNames[outputSlotIdx], outputSlotIdx); break; case DataType::Signed64: GetOutput(runner, outputNames[outputSlotIdx], outputSlotIdx); break; case DataType::Boolean: GetOutput(runner, outputNames[outputSlotIdx], outputSlotIdx); break; default: throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend."); } } } template void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner, std::string inputName, uint32_t inputIndex) const { std::vector inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements()); m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data()); runner.setInput(inputName, inputData); } template void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner, std::string outputName, uint32_t outputIndex) const { std::vector actualOutputs = runner.getOutput(outputName); m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data()); } bool TosaRefPreCompiledWorkloadValidate(std::string*) { return true; } } //namespace armnn