ArmNN
 22.11
TosaRefPreCompiledWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
8 namespace armnn
9 {
10 
12  const WorkloadInfo& info)
13  : BaseWorkload<PreCompiledQueueDescriptor>(descriptor, info)
14  , m_workloadInfo(info)
15 {
16  // Check that the workload is holding a pointer to a valid pre-compiled object
17  if (m_Data.m_PreCompiledObject == nullptr)
18  {
20  "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21  }
22 }
23 
25 {
26  uint32_t numInputBuffers = static_cast<uint32_t>(m_Data.m_Inputs.size());
27  uint32_t numOutputBuffers = static_cast<uint32_t>(m_Data.m_Outputs.size());
28 
29  tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
30 
31  std::vector<std::string> input_names = handler->GetInputs();
32  std::vector<std::string> output_names = handler->GetOutputs();
33 
34  TosaReference::IModelRunner runner;
35  GraphStatus status;
36 
37  // Initialise the model runner with the TosaSerializationHandler
38  status = runner.initialize(*handler);
39  if(status != GraphStatus::TOSA_VALID)
40  {
41  throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
42  }
43 
44  // Set the inputs
45  for (uint32_t inputSlotIdx = 0; inputSlotIdx < numInputBuffers; ++inputSlotIdx)
46  {
47  DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
48  switch (dataType)
49  {
50  case DataType::Float32:
51  SetInput<float>(runner, input_names[inputSlotIdx], inputSlotIdx);
52  break;
53  default:
54  throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
55  }
56  }
57 
58  // Run the TOSA Reference Model
59  status = runner.run();
60  if(status != GraphStatus::TOSA_VALID)
61  {
62  throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
63  }
64 
65  // Gets the outputs
66  for (uint32_t outputSlotIdx = 0; outputSlotIdx < numOutputBuffers; ++outputSlotIdx)
67  {
68  DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
69  switch (dataType)
70  {
71  case DataType::Float32:
72  GetOutput<float>(runner, output_names[outputSlotIdx], outputSlotIdx);
73  break;
74  default:
75  throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
76  }
77  }
78 }
79 
80 template <typename T>
81 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
82  std::string inputName,
83  uint32_t inputIndex) const
84 {
85  std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
86  m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
87 
88  runner.setInput<T>(inputName, inputData);
89 }
90 
91 template <typename T>
92 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
93  std::string outputName,
94  uint32_t outputIndex) const
95 {
96  std::vector<T> actualOutputs = runner.getOutput<T>(outputName);
97 
98  m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
99 }
100 
102 {
103  return true;
104 }
105 
106 } //namespace armnn
Copyright (c) 2021 ARM Limited and Contributors.
std::vector< TensorInfo > m_InputTensorInfos
DataType
Definition: Types.hpp:48
std::vector< TensorInfo > m_OutputTensorInfos
bool TosaRefPreCompiledWorkloadValidate(std::string *)
std::vector< ITensorHandle * > m_Outputs
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info)
Contains information about TensorInfos of a layer.
std::vector< ITensorHandle * > m_Inputs