diff options
author | Cathal Corbett <catcor01@e127348.nice.arm.com> | 2021-10-07 11:46:40 +0100 |
---|---|---|
committer | Cathal Corbett <cathal.corbett@arm.com> | 2021-10-08 11:28:35 +0000 |
commit | 521032fd424cf86681eb125afbf5eaee47d8c585 (patch) | |
tree | 65162778f203638f1c039097b8240422f99dad76 /src/backends/backendsCommon | |
parent | 723bc3b5d8a911a369eee658631d9f107ea09896 (diff) | |
download | armnn-521032fd424cf86681eb125afbf5eaee47d8c585.tar.gz |
IVGCVSW-6417: Catch AddFullyConnected API error when weights TensorInfo isn't set
* Updated code in Graph.cpp InferTensorInfos() to be more descriptive.
* Added method VerifyConstantLayerSetTensorInfo() in Graph.cpp/hpp
to error when ConstantLayer TensorInfo is not set.
* Updated Optimize() in Network.cpp to call VerifyConstantLayerSetTensorInfo().
* Added unit test with ConstantLayer TensorInfo not
set to catch error in VerifyConstantLayerSetTensorInfo().
* Added comments around method VerifyConstantLayerSetTensorInfo().
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I366596243f7c5823676222e2d0cce1335bc8c325
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp | 41 |
1 files changed, 39 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp index af6b56852a..7345ff5151 100644 --- a/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp +++ b/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp @@ -84,6 +84,25 @@ armnn::INetworkPtr CreateFullyConnectedNetworkConstWeightsNonConstBias(const arm return network; } +armnn::INetworkPtr CreateFullyConnectedNetworkNoTensorInfoConstWeights(const armnn::TensorInfo& inputTensorInfo, + const armnn::TensorInfo& outputTensorInfo, + const armnn::ConstTensor& weightsConstantTensor, + armnn::FullyConnectedDescriptor descriptor) +{ + armnn::INetworkPtr network(armnn::INetwork::Create()); + + armnn::IConnectableLayer* inputLayer = network->AddInputLayer(0, "Input"); + armnn::IConnectableLayer* weightsLayer = network->AddConstantLayer(weightsConstantTensor, "Weights"); + armnn::IConnectableLayer* fullyConnectedLayer = network->AddFullyConnectedLayer(descriptor, "Fully_Connected"); + armnn::IConnectableLayer* outputLayer = network->AddOutputLayer(0, "Output"); + + Connect(inputLayer, fullyConnectedLayer, inputTensorInfo, 0, 0); + weightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1)); + Connect(fullyConnectedLayer, outputLayer, outputTensorInfo, 0, 0); + + return network; +} + template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> void FullyConnectedWithDynamicWeightsEndToEnd(const std::vector<armnn::BackendId>& backends) { @@ -141,7 +160,8 @@ void FullyConnectedWithDynamicWeightsEndToEnd(const std::vector<armnn::BackendId template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> void FullyConnectedWithDynamicOrConstantInputsEndToEnd(const std::vector<armnn::BackendId>& backends, const bool transposeWeights, - const bool constantWeightsOrBias) + const bool constantWeightsOrBias, + const bool tensorInfoSet) { unsigned int inputWidth = 1; unsigned int inputHeight = 1; @@ -210,7 +230,24 @@ void FullyConnectedWithDynamicOrConstantInputsEndToEnd(const std::vector<armnn:: descriptor.m_TransposeWeightMatrix = transposeWeights; descriptor.m_ConstantWeights = constantWeightsOrBias; - if (!constantWeightsOrBias) + if(!tensorInfoSet) + { + // Tests constant weights and non constant bias. + ConstTensor weightsConstantTensor(weightsDesc, weights.data()); + + armnn::INetworkPtr network = CreateFullyConnectedNetworkNoTensorInfoConstWeights(inputTensorInfo, + outputTensorInfo, + weightsConstantTensor, + descriptor); + CHECK(network); + + // Create runtime in which test will run + IRuntime::CreationOptions options; + IRuntimePtr runtime(IRuntime::Create(options)); + + CHECK_THROWS_AS( Optimize(*network, backends, runtime->GetDeviceSpec()), LayerValidationException ); + } + else if (!constantWeightsOrBias) { // Tests non constant weights and constant bias. ConstTensor biasConstantTensor(biasesDesc, biasValues.data()); |