aboutsummaryrefslogtreecommitdiff
path: root/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp')
-rw-r--r--tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
index 6c74aaa6ed..00ed55caaf 100644
--- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
@@ -34,6 +34,15 @@ auto ParseDataArray<armnn::DataType::Float32>(std::istream& stream)
}
template<>
+auto ParseDataArray<armnn::DataType::Float16>(std::istream& stream)
+{
+ return ParseArrayImpl<armnn::Half>(stream, [](const std::string& s)
+ {
+ return armnn::Half(std::stof(s));
+ });
+}
+
+template<>
auto ParseDataArray<armnn::DataType::Signed32>(std::istream& stream)
{
return ParseArrayImpl<int>(stream, [](const std::string& s) { return std::stoi(s); });
@@ -139,6 +148,20 @@ void TensorPrinter::operator()(const std::vector<float>& values)
WriteToFile(values);
}
+void TensorPrinter::operator()(const std::vector<armnn::Half>& values)
+{
+ if (m_PrintToConsole)
+ {
+ std::cout << m_OutputBinding << ": ";
+ ForEachValue(values, [](armnn::Half value)
+ {
+ printf("%f ", static_cast<float>(value));
+ });
+ printf("\n");
+ }
+ WriteToFile(values);
+}
+
void TensorPrinter::operator()(const std::vector<uint8_t>& values)
{
if(m_DequantizeOutput)
@@ -261,6 +284,12 @@ void PopulateTensorWithData(armnnUtils::TContainer& tensorData,
GenerateDummyTensorData<armnn::DataType::Float32>(numElements);
}
}
+ else if (dataTypeStr.compare("float16") == 0)
+ {
+ tensorData = readFromFile ?
+ ParseDataArray<armnn::DataType::Float16>(inputTensorFile) :
+ GenerateDummyTensorData<armnn::DataType::Float16>(numElements);
+ }
else if (dataTypeStr.compare("int") == 0)
{
tensorData = readFromFile ?