aboutsummaryrefslogtreecommitdiff
path: root/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp')
-rw-r--r--tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp14
1 files changed, 9 insertions, 5 deletions
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
index 2136c446fb..0df3bf5ef5 100644
--- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -184,13 +184,14 @@ template<typename T>
void WriteToFile(const std::string& outputTensorFileName,
const std::string& outputName,
const T* const array,
- const unsigned int numElements)
+ const unsigned int numElements,
+ armnn::DataType dataType)
{
std::ofstream outputTensorFile;
outputTensorFile.open(outputTensorFileName, std::ofstream::out | std::ofstream::trunc);
if (outputTensorFile.is_open())
{
- outputTensorFile << outputName << ": ";
+ outputTensorFile << outputName << ", "<< GetDataTypeName(dataType) << " : ";
for (std::size_t i = 0; i < numElements; ++i)
{
outputTensorFile << +array[i] << " ";
@@ -209,6 +210,7 @@ struct OutputWriteInfo
const std::string& m_OutputName;
const armnn::Tensor& m_Tensor;
const bool m_PrintTensor;
+ const armnn::DataType m_DataType;
};
template <typename T>
@@ -221,7 +223,8 @@ void PrintTensor(OutputWriteInfo& info, const char* formatString)
WriteToFile(info.m_OutputTensorFile.value(),
info.m_OutputName,
array,
- info.m_Tensor.GetNumElements());
+ info.m_Tensor.GetNumElements(),
+ info.m_DataType);
}
if (info.m_PrintTensor)
@@ -248,7 +251,8 @@ void PrintQuantizedTensor(OutputWriteInfo& info)
WriteToFile(info.m_OutputTensorFile.value(),
info.m_OutputName,
dequantizedValues.data(),
- tensor.GetNumElements());
+ tensor.GetNumElements(),
+ info.m_DataType);
}
if (info.m_PrintTensor)