ArmNN
 23.05
TosaRefPreCompiledWorkload.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
8 namespace armnn
9 {
10 
12  const WorkloadInfo& info)
14  , m_workloadInfo(info)
15 {
16  // Check that the workload is holding a pointer to a valid pre-compiled object
17  if (m_Data.m_PreCompiledObject == nullptr)
18  {
20  "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21  }
22 }
23 
25 {
26  tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
27 
28  std::vector<std::string> inputNames = handler->GetInputs();
29  std::vector<std::string> outputNames = handler->GetOutputs();
30 
31  TosaReference::IModelRunner runner;
32  GraphStatus status;
33 
34  // Initialise the model runner with the TosaSerializationHandler
35  status = runner.initialize(*handler);
36  if(status != GraphStatus::TOSA_VALID)
37  {
38  throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
39  }
40 
41  // Set the inputs
42  for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
43  {
44  DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
45  switch (dataType)
46  {
47  case DataType::Float16:
48  SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
49  break;
50  case DataType::Float32:
51  SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
52  break;
53  case DataType::QAsymmU8:
54  case DataType::QAsymmS8:
55  case DataType::QSymmS8:
56  case DataType::QSymmS16:
57  case DataType::Signed32:
58  SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
59  break;
60  case DataType::Signed64:
61  SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
62  break;
63  case DataType::Boolean:
64  SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
65  break;
66  default:
67  throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
68  }
69  }
70 
71  // Run the TOSA Reference Model
72  status = runner.run();
73  if(status != GraphStatus::TOSA_VALID)
74  {
75  throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
76  }
77 
78  // Gets the outputs
79  for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
80  {
81  DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
82  switch (dataType)
83  {
84  case DataType::Float16:
85  GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
86  break;
87  case DataType::Float32:
88  GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
89  break;
90  case DataType::QAsymmU8:
91  case DataType::QAsymmS8:
92  case DataType::QSymmS8:
93  case DataType::QSymmS16:
94  case DataType::Signed32:
95  GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
96  break;
97  case DataType::Signed64:
98  GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
99  break;
100  case DataType::Boolean:
101  GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
102  break;
103  default:
104  throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
105  }
106  }
107 }
108 
109 template <typename T>
110 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
111  std::string inputName,
112  uint32_t inputIndex) const
113 {
114  std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
115  m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
116 
117  runner.setInput<T>(inputName, inputData);
118 }
119 
120 template <typename T>
121 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
122  std::string outputName,
123  uint32_t outputIndex) const
124 {
125  std::vector<T> actualOutputs = runner.getOutput<T>(outputName);
126 
127  m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
128 }
129 
/// Validation hook for the TOSA reference pre-compiled workload. It accepts every configuration
/// unconditionally; the string pointer (presumably meant for a failure reason — confirm with callers)
/// is neither read nor written.
bool TosaRefPreCompiledWorkloadValidate(std::string*)
{
    constexpr bool alwaysSupported = true;
    return alwaysSupported;
}
134 
135 } //namespace armnn
armnn::DataType::QAsymmU8
@ QAsymmU8
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
armnn::DataType::Float16
@ Float16
armnn::TosaRefPreCompiledWorkload::Execute
void Execute() const override
Definition: TosaRefPreCompiledWorkload.cpp:24
armnn::DataType::Signed32
@ Signed32
armnn::BaseWorkload< PreCompiledQueueDescriptor >::m_Data
PreCompiledQueueDescriptor m_Data
Definition: Workload.hpp:83
armnn::DataType::QAsymmS8
@ QAsymmS8
armnn::TosaRefPreCompiledWorkloadValidate
bool TosaRefPreCompiledWorkloadValidate(std::string *)
Definition: TosaRefPreCompiledWorkload.cpp:130
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::BaseWorkload
Definition: Workload.hpp:33
armnn::WorkloadInfo::m_OutputTensorInfos
std::vector< TensorInfo > m_OutputTensorInfos
Definition: WorkloadInfo.hpp:19
armnn::DataType::Float32
@ Float32
armnn::DataType::Signed64
@ Signed64
armnn::PreCompiledQueueDescriptor
Definition: WorkloadData.hpp:507
armnn::WorkloadInfo
Contains information about TensorInfos of a layer.
Definition: WorkloadInfo.hpp:16
TosaRefPreCompiledWorkload.hpp
armnn::DataType
DataType
Definition: Types.hpp:48
armnn::PreCompiledQueueDescriptor::m_PreCompiledObject
void * m_PreCompiledObject
Definition: WorkloadData.hpp:514
armnn::TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload
TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info)
Definition: TosaRefPreCompiledWorkload.cpp:11
armnn::DataType::QSymmS8
@ QSymmS8
armnn::QueueDescriptor::m_Outputs
std::vector< ITensorHandle * > m_Outputs
Definition: WorkloadData.hpp:27
armnn::DataType::QSymmS16
@ QSymmS16
armnn::DataType::Boolean
@ Boolean
armnn::InvalidArgumentException
Definition: Exceptions.hpp:80
armnn::WorkloadInfo::m_InputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos
Definition: WorkloadInfo.hpp:18
armnn::QueueDescriptor::m_Inputs
std::vector< ITensorHandle * > m_Inputs
Definition: WorkloadData.hpp:26
armnn::BoostLogSeverityMapping::info
@ info