From 5c54c3874ff5dc1656b9f28288e46086586f21b9 Mon Sep 17 00:00:00 2001 From: Matthew Sloyan Date: Wed, 9 Nov 2022 16:28:51 +0000 Subject: IVGCVSW-7165 Implement TosaRefPreCompiledWorkload::Execute() * Added FP32 support for TOSA Reference Backend. * Added main block creation to OptimizeSubgraphView, this will only occur once. Change-Id: I169dac50b78e2c693da6327962c9f1d3ae3bd712 Signed-off-by: James Conroy Signed-off-by: Matthew Sloyan --- .../workloads/TosaRefPreCompiledWorkload.cpp | 82 +++++++++++++++++++++- 1 file changed, 80 insertions(+), 2 deletions(-) (limited to 'src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp') diff --git a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp index c4af4d40b1..18d2900eff 100644 --- a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp +++ b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp @@ -11,13 +11,91 @@ namespace armnn TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor, const WorkloadInfo& info) : BaseWorkload(descriptor, info) + , m_workloadInfo(info) { - // Do nothing for now + // Check that the workload is holding a pointer to a valid pre-compiled object + if (m_Data.m_PreCompiledObject == nullptr) + { + throw InvalidArgumentException( + "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler)."); + } } void TosaRefPreCompiledWorkload::Execute() const { - // Do nothing for now + uint32_t numInputBuffers = static_cast(m_Data.m_Inputs.size()); + uint32_t numOutputBuffers = static_cast(m_Data.m_Outputs.size()); + + tosa::TosaSerializationHandler* handler = static_cast(m_Data.m_PreCompiledObject); + + std::vector input_names = handler->GetInputs(); + std::vector output_names = handler->GetOutputs(); + + TosaReference::IModelRunner runner; + GraphStatus status; + + // Initialise the model runner with the TosaSerializationHandler + status = runner.initialize(*handler); + if(status != GraphStatus::TOSA_VALID) + { + throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model."); + } + + // Set the inputs + for (uint32_t inputSlotIdx = 0; inputSlotIdx < numInputBuffers; ++inputSlotIdx) + { + DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType(); + switch (dataType) + { + case DataType::Float32: + SetInput(runner, input_names[inputSlotIdx], inputSlotIdx); + break; + default: + throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend."); + } + } + + // Run the TOSA Reference Model + status = runner.run(); + if(status != GraphStatus::TOSA_VALID) + { + throw armnn::Exception("An error has occurred while running the TOSA Reference Model."); + } + + // Gets the outputs + for (uint32_t outputSlotIdx = 0; outputSlotIdx < numOutputBuffers; ++outputSlotIdx) + { + DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType(); + switch (dataType) + { + case DataType::Float32: + GetOutput(runner, output_names[outputSlotIdx], outputSlotIdx); + break; + default: + throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend."); + } + } +} + +template +void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner, + std::string inputName, + uint32_t inputIndex) const +{ + std::vector inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements()); + m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data()); + + runner.setInput(inputName, inputData); +} + +template +void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner, + std::string outputName, + uint32_t outputIndex) const +{ + std::vector actualOutputs = runner.getOutput(outputName); + + m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data()); } bool TosaRefPreCompiledWorkloadValidate(std::string*) -- cgit v1.2.1