14 , m_workloadInfo(
info)
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
28 std::vector<std::string> inputNames = handler->GetInputs();
29 std::vector<std::string> outputNames = handler->GetOutputs();
31 TosaReference::IModelRunner runner;
35 status = runner.initialize(*handler);
36 if(status != GraphStatus::TOSA_VALID)
38 throw armnn::Exception(
"An error has occurred while initialising the TOSA Reference Model.");
42 for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
48 SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
51 SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
58 SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
61 SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
64 SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
67 throw armnn::Exception(
"Input data type is unsupported in TOSA Reference Backend.");
72 status = runner.run();
73 if(status != GraphStatus::TOSA_VALID)
75 throw armnn::Exception(
"An error has occurred while running the TOSA Reference Model.");
79 for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
85 GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
88 GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
95 GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
98 GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
101 GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
104 throw armnn::Exception(
"Output data type is unsupported in TOSA Reference Backend.");
109 template <
typename T>
110 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
111 std::string inputName,
112 uint32_t inputIndex)
const
114 std::vector<T> inputData(
m_Data.
m_Inputs[inputIndex]->GetShape().GetNumElements());
117 runner.setInput<T>(inputName, inputData);
120 template <
typename T>
121 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
122 std::string outputName,
123 uint32_t outputIndex)
const
125 std::vector<T> actualOutputs = runner.getOutput<T>(outputName);