ArmNN
 24.02
TosaRefPreCompiledWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
8 namespace armnn
9 {
10 
12  const WorkloadInfo& info)
14  , m_workloadInfo(info)
15 {
16  // Check that the workload is holding a pointer to a valid pre-compiled object
17  if (m_Data.m_PreCompiledObject == nullptr)
18  {
20  "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21  }
22 }
24 {
25  tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
26 
27  std::vector<std::string> inputNames = handler->GetMainRegion()->GetBlocks()[0]->GetInputs();
28  std::vector<std::string> outputNames = handler->GetMainRegion()->GetBlocks()[0]->GetOutputs();
29 
30  TosaReference::IModelRunner runner;
31  GraphStatus status;
32 
33  // Initialise the model runner with the TosaSerializationHandler
34  status = runner.initialize(*handler);
35  if(status != GraphStatus::TOSA_VALID)
36  {
37  throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
38  }
39 
40  // Set the inputs
41  for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
42  {
43  DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
44  switch (dataType)
45  {
46  case DataType::Float16:
47  SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
48  break;
49  case DataType::Float32:
50  SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
51  break;
52  case DataType::QAsymmU8:
53  SetInput<uint8_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
54  break;
55  case DataType::QAsymmS8:
56  case DataType::QSymmS8:
57  SetInput<int8_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
58  break;
59  case DataType::QSymmS16:
60  SetInput<int16_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
61  break;
62  case DataType::Signed32:
63  SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
64  break;
65  case DataType::Signed64:
66  SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
67  break;
68  case DataType::Boolean:
69  SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
70  break;
71  default:
72  throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
73  }
74  }
75 
76  // Run the TOSA Reference Model
77  status = runner.run();
78  if(status != GraphStatus::TOSA_VALID)
79  {
80  throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
81  }
82 
83  // Gets the outputs
84  for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
85  {
86  DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
87  switch (dataType)
88  {
89  case DataType::Float16:
90  GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
91  break;
92  case DataType::Float32:
93  GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
94  break;
95  case DataType::QAsymmU8:
96  GetOutput<uint8_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
97  break;
98  case DataType::QAsymmS8:
99  case DataType::QSymmS8:
100  GetOutput<int8_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
101  break;
102  case DataType::QSymmS16:
103  GetOutput<int16_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
104  break;
105  case DataType::Signed32:
106  GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
107  break;
108  case DataType::Signed64:
109  GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
110  break;
111  case DataType::Boolean:
112  GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
113  break;
114  default:
115  throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
116  }
117  }
118 }
119 
120 template <typename T>
121 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
122  std::string inputName,
123  uint32_t inputIndex) const
124 {
125  SetInput<T, T>(runner, inputName, inputIndex);
126 }
127 
128 template <typename T, typename Trunner>
129 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
130  std::string inputName,
131  uint32_t inputIndex) const
132 {
133  std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
134  std::vector<Trunner> inputDataRunner(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
135 
136  m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
137 
138  std::transform(inputData.begin(), inputData.end(),
139  inputDataRunner.begin(), [](T x) { return static_cast<Trunner>(x);});
140 
141  runner.setInput<Trunner>(inputName, inputDataRunner);
142 }
143 
144 template <typename T>
145 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
146  std::string outputName,
147  uint32_t outputIndex) const
148 {
149  GetOutput<T, T>(runner, outputName, outputIndex);
150 }
151 
152 template <typename T, typename Trunner>
153 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
154  std::string outputName,
155  uint32_t outputIndex) const
156 {
157  std::vector<Trunner> actualOutputsRunner = runner.getOutput<Trunner>(outputName);
158  std::vector<T> actualOutputs (actualOutputsRunner.size());
159 
160  std::transform(actualOutputsRunner.begin(), actualOutputsRunner.end(),
161  actualOutputs.begin(), [](Trunner x) { return static_cast<T>(x);});
162 
163  m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
164 }
165 
/// Reports whether a pre-compiled workload can be run by the TOSA Reference backend.
///
/// Always returns true. The string pointer is neither read nor written; presumably it is an
/// out-parameter for an "unsupported" reason — TODO confirm against callers.
bool TosaRefPreCompiledWorkloadValidate(std::string*)
{
    return true;
}
170 
171 } //namespace armnn
armnn::PreCompiledQueueDescriptor::m_PreCompiledObject
void * m_PreCompiledObject
Definition: WorkloadData.hpp:519
armnn::DataType::Boolean
@ Boolean
armnn::DataType::Float32
@ Float32
armnn::DataType::QAsymmU8
@ QAsymmU8
armnn::PreCompiledQueueDescriptor
Definition: WorkloadData.hpp:512
armnn::DataType::QSymmS8
@ QSymmS8
armnn::DataType::QSymmS16
@ QSymmS16
armnn::WorkloadInfo::m_OutputTensorInfos
std::vector< TensorInfo > m_OutputTensorInfos
Definition: WorkloadInfo.hpp:19
armnn::DataType::Float16
@ Float16
armnn::WorkloadInfo
Contains information about TensorInfos of a layer.
Definition: WorkloadInfo.hpp:16
armnn::DataType
DataType
Definition: Types.hpp:48
armnn::InvalidArgumentException
Definition: Exceptions.hpp:80
armnn::TosaRefPreCompiledWorkload::Execute
void Execute() const override
Definition: TosaRefPreCompiledWorkload.cpp:23
TosaRefPreCompiledWorkload.hpp
armnn::TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload
TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: TosaRefPreCompiledWorkload.cpp:11
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
armnn::BoostLogSeverityMapping::info
@ info
armnn::QueueDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition: WorkloadData.hpp:27
armnn::DataType::Signed32
@ Signed32
armnn::BaseWorkload
Definition: Workload.hpp:33
armnn::DataType::QAsymmS8
@ QAsymmS8
armnn::BaseWorkload< PreCompiledQueueDescriptor >::m_Data
PreCompiledQueueDescriptor m_Data
Definition: Workload.hpp:89
armnn::WorkloadInfo::m_InputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos
Definition: WorkloadInfo.hpp:18
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::TosaRefPreCompiledWorkloadValidate
bool TosaRefPreCompiledWorkloadValidate(std::string *)
Definition: TosaRefPreCompiledWorkload.cpp:166
armnn::DataType::Signed64
@ Signed64
armnn::QueueDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition: WorkloadData.hpp:26