aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp')
-rw-r--r--src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp41
1 files changed, 39 insertions, 2 deletions
diff --git a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
index 8b08f01b23..5e4103aed1 100644
--- a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
+++ b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
@@ -50,9 +50,15 @@ void TosaRefPreCompiledWorkload::Execute() const
SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
break;
case DataType::QAsymmU8:
+ SetInput<uint8_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
+ break;
case DataType::QAsymmS8:
case DataType::QSymmS8:
+ SetInput<int8_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
+ break;
case DataType::QSymmS16:
+ SetInput<int16_t, int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
+ break;
case DataType::Signed32:
SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
break;
@@ -87,9 +93,15 @@ void TosaRefPreCompiledWorkload::Execute() const
GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
break;
case DataType::QAsymmU8:
+ GetOutput<uint8_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
+ break;
case DataType::QAsymmS8:
case DataType::QSymmS8:
+ GetOutput<int8_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
+ break;
case DataType::QSymmS16:
+ GetOutput<int16_t, int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
+ break;
case DataType::Signed32:
GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
break;
@@ -110,10 +122,23 @@ void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
std::string inputName,
uint32_t inputIndex) const
{
+ SetInput<T, T>(runner, inputName, inputIndex);
+}
+
+template <typename T, typename Trunner>
+void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
+ std::string inputName,
+ uint32_t inputIndex) const
+{
std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
+ std::vector<Trunner> inputDataRunner(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
+
m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
- runner.setInput<T>(inputName, inputData);
+ std::transform(inputData.begin(), inputData.end(),
+ inputDataRunner.begin(), [](T x) { return static_cast<Trunner>(x);});
+
+ runner.setInput<Trunner>(inputName, inputDataRunner);
}
template <typename T>
@@ -121,7 +146,19 @@ void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
std::string outputName,
uint32_t outputIndex) const
{
- std::vector<T> actualOutputs = runner.getOutput<T>(outputName);
+ GetOutput<T, T>(runner, outputName, outputIndex);
+}
+
+template <typename T, typename Trunner>
+void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
+ std::string outputName,
+ uint32_t outputIndex) const
+{
+ std::vector<Trunner> actualOutputsRunner = runner.getOutput<Trunner>(outputName);
+ std::vector<T> actualOutputs (actualOutputsRunner.size());
+
+ std::transform(actualOutputsRunner.begin(), actualOutputsRunner.end(),
+ actualOutputs.begin(), [](Trunner x) { return static_cast<T>(x);});
m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
}