diff options
Diffstat (limited to 'tests/ExecuteNetwork/TfliteExecutor.cpp')
-rw-r--r-- | tests/ExecuteNetwork/TfliteExecutor.cpp | 44 |
1 files changed, 4 insertions, 40 deletions
diff --git a/tests/ExecuteNetwork/TfliteExecutor.cpp b/tests/ExecuteNetwork/TfliteExecutor.cpp index fc9c21a559..f365623d62 100644 --- a/tests/ExecuteNetwork/TfliteExecutor.cpp +++ b/tests/ExecuteNetwork/TfliteExecutor.cpp @@ -230,45 +230,9 @@ void TfLiteExecutor::CompareAndPrintResult(std::vector<const void*> otherOutput) for (unsigned int outputIndex = 0; outputIndex < m_TfLiteInterpreter->outputs().size(); ++outputIndex) { auto tfLiteDelegateOutputId = m_TfLiteInterpreter->outputs()[outputIndex]; - float result = 0; - switch (m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->type) - { - case kTfLiteFloat32: - { - result = ComputeRMSE<float>(m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->allocation, - otherOutput[outputIndex], - m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->bytes); - - break; - } - case kTfLiteInt32: - { - result = ComputeRMSE<int32_t>(m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->allocation, - otherOutput[outputIndex], - m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->bytes); - break; - } - case kTfLiteUInt8: - { - result = ComputeRMSE<uint8_t>(m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->allocation, - otherOutput[outputIndex], - m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->bytes); - break; - } - case kTfLiteInt8: - { - result = ComputeRMSE<int8_t>(m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->allocation, - otherOutput[outputIndex], - m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->bytes); - break; - } - default: - { - } - } - - std::cout << "RMSE of " - << m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->name - << ": " << result << std::endl; + size_t size = m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->bytes; + double result = ComputeByteLevelRMSE(m_TfLiteInterpreter->tensor(tfLiteDelegateOutputId)->allocation, + otherOutput[outputIndex], size); + std::cout << "Byte level root mean square error: " << result << "\n"; } }; |