diff options
Diffstat (limited to 'src/backends/backendsCommon/test/QuantizedLstmEndToEndTestImpl.cpp')
-rw-r--r-- | src/backends/backendsCommon/test/QuantizedLstmEndToEndTestImpl.cpp | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/test/QuantizedLstmEndToEndTestImpl.cpp b/src/backends/backendsCommon/test/QuantizedLstmEndToEndTestImpl.cpp index a2fadc7b92..f178951873 100644 --- a/src/backends/backendsCommon/test/QuantizedLstmEndToEndTestImpl.cpp +++ b/src/backends/backendsCommon/test/QuantizedLstmEndToEndTestImpl.cpp @@ -46,14 +46,14 @@ armnn::INetworkPtr CreateQuantizedLstmNetwork(armnn::TensorShape& inputShape, armnn::TensorInfo inputWeightsInfo({outputSize, inputSize}, armnn::DataType::QAsymmU8, weightsScale, - weightsOffset); + weightsOffset, true); armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize}, armnn::DataType::QAsymmU8, weightsScale, - weightsOffset); + weightsOffset, true); - armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset); + armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset, true); armnn::QuantizedLstmInputParams data; @@ -210,9 +210,16 @@ void QuantizedLstmEndToEnd(const std::vector<armnn::BackendId>& backends) inputTensors.reserve(3); // input - inputTensors.push_back({0, ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputVector.data())}); - inputTensors.push_back({1, ConstTensor(runtime->GetInputTensorInfo(netId, 1), cellStateInVector.data())}); - inputTensors.push_back({2, ConstTensor(runtime->GetInputTensorInfo(netId, 2), outputStateInVector.data())}); + TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0); + TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1); + TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2); + inputTensorInfo0.SetConstant(true); + inputTensorInfo1.SetConstant(true); + inputTensorInfo2.SetConstant(true); + + inputTensors.push_back({0, ConstTensor(inputTensorInfo0, inputVector.data())}); + inputTensors.push_back({1, ConstTensor(inputTensorInfo1, cellStateInVector.data())}); + inputTensors.push_back({2, ConstTensor(inputTensorInfo2, outputStateInVector.data())}); OutputTensors outputTensors; outputTensors.reserve(2); |