aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp')
-rw-r--r--src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp17
1 files changed, 11 insertions, 6 deletions
diff --git a/src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp b/src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp
index 4c453cc799..778d7b0814 100644
--- a/src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp
+++ b/src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -25,33 +25,38 @@ TEST_CASE("ConvertConstantsHalfToFloatTest")
std::vector<uint16_t> halfWeights(4);
armnnUtils::FloatingPointConverter::ConvertFloat32To16(convWeightsData.data(), convWeightsData.size(),
halfWeights.data());
- armnn::ConstTensor weights(armnn::TensorInfo(4, dims, armnn::DataType::Float16, 0.0f, 0, true), halfWeights);
+ armnn::TensorInfo weightInfo = armnn::TensorInfo(4, dims, armnn::DataType::Float16, 0.0f, 0, true);
+ armnn::ConstTensor weights(weightInfo, halfWeights);
//Create the simple test network
auto input = graph.AddLayer<armnn::InputLayer>(0, "input");
input->GetOutputSlot().SetTensorInfo(info);
auto fc = graph.AddLayer<armnn::FullyConnectedLayer>(armnn::FullyConnectedDescriptor(), "fc");
- fc->m_Weight = std::make_unique<armnn::ScopedTensorHandle>(weights);
fc->GetOutputSlot().SetTensorInfo(info);
+ auto weightsLayer = graph.AddLayer<armnn::ConstantLayer>("weights");
+ weightsLayer->m_LayerOutput = std::make_unique<armnn::ScopedTensorHandle>(weights);
+ weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
+
auto output = graph.AddLayer<armnn::OutputLayer>(1, "output");
//Connect up the layers
input->GetOutputSlot().Connect(fc->GetInputSlot(0));
+ weightsLayer->GetOutputSlot().Connect(fc->GetInputSlot(1));
fc->GetOutputSlot().Connect(output->GetInputSlot(0));
//Test the tensor info is correct.
- CHECK(fc->m_Weight->GetTensorInfo().GetDataType() == armnn::DataType::Float16);
+ CHECK(weightsLayer->m_LayerOutput->GetTensorInfo().GetDataType() == armnn::DataType::Float16);
// Run the optimizer
armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(ConvertConstantsHalfToFloat()));
//Test the tensor info is correct.
- CHECK(fc->m_Weight->GetTensorInfo().GetDataType() == armnn::DataType::Float32);
+ CHECK(weightsLayer->m_LayerOutput->GetTensorInfo().GetDataType() == armnn::DataType::Float32);
// Now test the data matches float32 data
- const float* data = fc->m_Weight->GetConstTensor<float>();
+ const float* data = weightsLayer->m_LayerOutput->GetConstTensor<float>();
CHECK(1.0f == data[0]);
CHECK(2.0f == data[1]);
CHECK(3.0f == data[2]);