diff options
Diffstat (limited to 'src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp')
-rw-r--r-- | src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp b/src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp index 4c453cc799..778d7b0814 100644 --- a/src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp +++ b/src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -25,33 +25,38 @@ TEST_CASE("ConvertConstantsHalfToFloatTest") std::vector<uint16_t> halfWeights(4); armnnUtils::FloatingPointConverter::ConvertFloat32To16(convWeightsData.data(), convWeightsData.size(), halfWeights.data()); - armnn::ConstTensor weights(armnn::TensorInfo(4, dims, armnn::DataType::Float16, 0.0f, 0, true), halfWeights); + armnn::TensorInfo weightInfo = armnn::TensorInfo(4, dims, armnn::DataType::Float16, 0.0f, 0, true); + armnn::ConstTensor weights(weightInfo, halfWeights); //Create the simple test network auto input = graph.AddLayer<armnn::InputLayer>(0, "input"); input->GetOutputSlot().SetTensorInfo(info); auto fc = graph.AddLayer<armnn::FullyConnectedLayer>(armnn::FullyConnectedDescriptor(), "fc"); - fc->m_Weight = std::make_unique<armnn::ScopedTensorHandle>(weights); fc->GetOutputSlot().SetTensorInfo(info); + auto weightsLayer = graph.AddLayer<armnn::ConstantLayer>("weights"); + weightsLayer->m_LayerOutput = std::make_unique<armnn::ScopedTensorHandle>(weights); + weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo); + auto output = graph.AddLayer<armnn::OutputLayer>(1, "output"); //Connect up the layers input->GetOutputSlot().Connect(fc->GetInputSlot(0)); + weightsLayer->GetOutputSlot().Connect(fc->GetInputSlot(1)); fc->GetOutputSlot().Connect(output->GetInputSlot(0)); //Test the tensor info is correct. - CHECK(fc->m_Weight->GetTensorInfo().GetDataType() == armnn::DataType::Float16); + CHECK(weightsLayer->m_LayerOutput->GetTensorInfo().GetDataType() == armnn::DataType::Float16); // Run the optimizer armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(ConvertConstantsHalfToFloat())); //Test the tensor info is correct. - CHECK(fc->m_Weight->GetTensorInfo().GetDataType() == armnn::DataType::Float32); + CHECK(weightsLayer->m_LayerOutput->GetTensorInfo().GetDataType() == armnn::DataType::Float32); // Now test the data matches float32 data - const float* data = fc->m_Weight->GetConstTensor<float>(); + const float* data = weightsLayer->m_LayerOutput->GetConstTensor<float>(); CHECK(1.0f == data[0]); CHECK(2.0f == data[1]); CHECK(3.0f == data[2]); |