ArmNN
 23.11
TosaRefPreCompiledWorkload Class Reference

#include <TosaRefPreCompiledWorkload.hpp>

Inheritance diagram for TosaRefPreCompiledWorkload:
[legend]
Collaboration diagram for TosaRefPreCompiledWorkload:
[legend]

Public Member Functions

 TosaRefPreCompiledWorkload (const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info)
 
void Execute () const override
 
- Public Member Functions inherited from BaseWorkload< PreCompiledQueueDescriptor >
 BaseWorkload (const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info)
 
virtual const std::string & GetName () const override
 
void ExecuteAsync (ExecutionData &executionData) override
 
void PostAllocationConfigure () override
 
const PreCompiledQueueDescriptor & GetData () const
 
arm::pipe::ProfilingGuid GetGuid () const final
 
virtual bool SupportsTensorHandleReplacement () const override
 
void ReplaceInputTensorHandle (ITensorHandle *tensorHandle, unsigned int slot) override
 
void ReplaceOutputTensorHandle (ITensorHandle *tensorHandle, unsigned int slot) override
 
- Public Member Functions inherited from IWorkload
virtual ~IWorkload ()
 
virtual arm::pipe::ProfilingGuid GetGuid () const =0
 
virtual const std::string & GetName () const =0
 
virtual void RegisterDebugCallback (const DebugCallbackFunction &)
 
virtual armnn::Optional< armnn::MemoryRequirements > GetMemoryRequirements ()
 

Additional Inherited Members

- Protected Attributes inherited from BaseWorkload< PreCompiledQueueDescriptor >
PreCompiledQueueDescriptor m_Data
 
const arm::pipe::ProfilingGuid m_Guid
 
const std::string m_Name
 

Detailed Description

Definition at line 22 of file TosaRefPreCompiledWorkload.hpp.

Constructor & Destructor Documentation

◆ TosaRefPreCompiledWorkload()

TosaRefPreCompiledWorkload ( const PreCompiledQueueDescriptor & descriptor,
const WorkloadInfo & info 
)

Definition at line 11 of file TosaRefPreCompiledWorkload.cpp.

13  : BaseWorkload<PreCompiledQueueDescriptor>(descriptor, info)
14  , m_workloadInfo(info)
15 {
16  // Check that the workload is holding a pointer to a valid pre-compiled object
17  if (m_Data.m_PreCompiledObject == nullptr)
18  {
19  throw InvalidArgumentException(
20  "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21  }
22 }

References armnn::info, BaseWorkload< PreCompiledQueueDescriptor >::m_Data, and PreCompiledQueueDescriptor::m_PreCompiledObject.

Member Function Documentation

◆ Execute()

void Execute ( ) const
override virtual

Implements IWorkload.

Definition at line 23 of file TosaRefPreCompiledWorkload.cpp.

24 {
25  tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
26 
27  std::vector<std::string> inputNames = handler->GetMainRegion()->GetBlocks()[0]->GetInputs();
28  std::vector<std::string> outputNames = handler->GetMainRegion()->GetBlocks()[0]->GetOutputs();
29 
30  TosaReference::IModelRunner runner;
31  GraphStatus status;
32 
33  // Initialise the model runner with the TosaSerializationHandler
34  status = runner.initialize(*handler);
35  if(status != GraphStatus::TOSA_VALID)
36  {
37  throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
38  }
39 
40  // Set the inputs
41  for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
42  {
43  DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
44  switch (dataType)
45  {
46  case DataType::Float16:
47  SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
48  break;
49  case DataType::Float32:
50  SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
51  break;
52  case DataType::QAsymmU8:
53  case DataType::QAsymmS8:
54  case DataType::QSymmS8:
55  case DataType::QSymmS16:
56  case DataType::Signed32:
57  SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
58  break;
59  case DataType::Signed64:
60  SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
61  break;
62  case DataType::Boolean:
63  SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
64  break;
65  default:
66  throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
67  }
68  }
69 
70  // Run the TOSA Reference Model
71  status = runner.run();
72  if(status != GraphStatus::TOSA_VALID)
73  {
74  throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
75  }
76 
77  // Gets the outputs
78  for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
79  {
80  DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
81  switch (dataType)
82  {
83  case DataType::Float16:
84  GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
85  break;
86  case DataType::Float32:
87  GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
88  break;
89  case DataType::QAsymmU8:
90  case DataType::QAsymmS8:
91  case DataType::QSymmS8:
92  case DataType::QSymmS16:
93  case DataType::Signed32:
94  GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
95  break;
96  case DataType::Signed64:
97  GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
98  break;
99  case DataType::Boolean:
100  GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
101  break;
102  default:
103  throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
104  }
105  }
106 }

References armnn::Boolean, armnn::Float16, armnn::Float32, BaseWorkload< PreCompiledQueueDescriptor >::m_Data, WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, PreCompiledQueueDescriptor::m_PreCompiledObject, armnn::QAsymmS8, armnn::QAsymmU8, armnn::QSymmS16, armnn::QSymmS8, armnn::Signed32, and armnn::Signed64.


The documentation for this class was generated from the following files:
TosaRefPreCompiledWorkload.hpp
TosaRefPreCompiledWorkload.cpp
armnn::PreCompiledQueueDescriptor::m_PreCompiledObject
void * m_PreCompiledObject
Definition: WorkloadData.hpp:519
armnn::DataType::Boolean
@ Boolean
armnn::DataType::Float32
@ Float32
armnn::DataType::QAsymmU8
@ QAsymmU8
armnn::DataType::QSymmS8
@ QSymmS8
armnn::DataType::QSymmS16
@ QSymmS16
armnn::WorkloadInfo::m_OutputTensorInfos
std::vector< TensorInfo > m_OutputTensorInfos
Definition: WorkloadInfo.hpp:19
armnn::DataType::Float16
@ Float16
armnn::DataType
DataType
Definition: Types.hpp:48
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
armnn::DataType::Signed32
@ Signed32
armnn::DataType::QAsymmS8
@ QAsymmS8
armnn::BaseWorkload< PreCompiledQueueDescriptor >::m_Data
PreCompiledQueueDescriptor m_Data
Definition: Workload.hpp:89
armnn::WorkloadInfo::m_InputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos
Definition: WorkloadInfo.hpp:18
armnn::DataType::Signed64
@ Signed64