14 , m_workloadInfo(info)
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
26 uint32_t numInputBuffers =
static_cast<uint32_t
>(
m_Data.
m_Inputs.size());
27 uint32_t numOutputBuffers =
static_cast<uint32_t
>(
m_Data.
m_Outputs.size());
31 std::vector<std::string> input_names = handler->GetInputs();
32 std::vector<std::string> output_names = handler->GetOutputs();
34 TosaReference::IModelRunner runner;
38 status = runner.initialize(*handler);
39 if(status != GraphStatus::TOSA_VALID)
41 throw armnn::Exception(
"An error has occurred while initialising the TOSA Reference Model.");
45 for (uint32_t inputSlotIdx = 0; inputSlotIdx < numInputBuffers; ++inputSlotIdx)
51 SetInput<float>(runner, input_names[inputSlotIdx], inputSlotIdx);
54 throw armnn::Exception(
"Input data type is unsupported in TOSA Reference Backend.");
59 status = runner.run();
60 if(status != GraphStatus::TOSA_VALID)
62 throw armnn::Exception(
"An error has occurred while running the TOSA Reference Model.");
66 for (uint32_t outputSlotIdx = 0; outputSlotIdx < numOutputBuffers; ++outputSlotIdx)
72 GetOutput<float>(runner, output_names[outputSlotIdx], outputSlotIdx);
75 throw armnn::Exception(
"Output data type is unsupported in TOSA Reference Backend.");
81 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
82 std::string inputName,
83 uint32_t inputIndex)
const 85 std::vector<T> inputData(
m_Data.
m_Inputs[inputIndex]->GetShape().GetNumElements());
88 runner.setInput<T>(inputName, inputData);
92 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
93 std::string outputName,
94 uint32_t outputIndex)
const 96 std::vector<T> actualOutputs = runner.getOutput<T>(outputName);
void Execute() const override
Copyright (c) 2021 ARM Limited and Contributors.
std::vector< TensorInfo > m_InputTensorInfos
PreCompiledQueueDescriptor m_Data
void * m_PreCompiledObject
std::vector< TensorInfo > m_OutputTensorInfos
bool TosaRefPreCompiledWorkloadValidate(std::string *)
std::vector< ITensorHandle * > m_Outputs
Base class for all ArmNN exceptions so that users can filter to just those.
TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info)
Contains information about TensorInfos of a layer.
std::vector< ITensorHandle * > m_Inputs