aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp')
-rw-r--r--src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp32
1 files changed, 32 insertions, 0 deletions
diff --git a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
index 18d2900eff..ffdbf6f49b 100644
--- a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
+++ b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
@@ -47,9 +47,25 @@ void TosaRefPreCompiledWorkload::Execute() const
DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
switch (dataType)
{
+ case DataType::Float16:
+ SetInput<half_float::half>(runner, input_names[inputSlotIdx], inputSlotIdx);
+ break;
case DataType::Float32:
SetInput<float>(runner, input_names[inputSlotIdx], inputSlotIdx);
break;
+ case DataType::QAsymmU8:
+ case DataType::QAsymmS8:
+ case DataType::QSymmS8:
+ case DataType::QSymmS16:
+ case DataType::Signed32:
+ SetInput<int32_t>(runner, input_names[inputSlotIdx], inputSlotIdx);
+ break;
+ case DataType::Signed64:
+ SetInput<int64_t>(runner, input_names[inputSlotIdx], inputSlotIdx);
+ break;
+ case DataType::Boolean:
+ SetInput<unsigned char>(runner, input_names[inputSlotIdx], inputSlotIdx);
+ break;
default:
throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
}
@@ -68,9 +84,25 @@ void TosaRefPreCompiledWorkload::Execute() const
DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
switch (dataType)
{
+ case DataType::Float16:
+ GetOutput<half_float::half>(runner, output_names[outputSlotIdx], outputSlotIdx);
+ break;
case DataType::Float32:
GetOutput<float>(runner, output_names[outputSlotIdx], outputSlotIdx);
break;
+ case DataType::QAsymmU8:
+ case DataType::QAsymmS8:
+ case DataType::QSymmS8:
+ case DataType::QSymmS16:
+ case DataType::Signed32:
+ GetOutput<int32_t>(runner, output_names[outputSlotIdx], outputSlotIdx);
+ break;
+ case DataType::Signed64:
+ GetOutput<int64_t>(runner, output_names[outputSlotIdx], outputSlotIdx);
+ break;
+ case DataType::Boolean:
+ GetOutput<unsigned char>(runner, output_names[outputSlotIdx], outputSlotIdx);
+ break;
default:
throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
}