aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test
diff options
context:
space:
mode:
authorCathal Corbett <catcor01@e127348.nice.arm.com>2021-10-07 11:46:40 +0100
committerCathal Corbett <cathal.corbett@arm.com>2021-10-08 11:28:35 +0000
commit521032fd424cf86681eb125afbf5eaee47d8c585 (patch)
tree65162778f203638f1c039097b8240422f99dad76 /src/backends/backendsCommon/test
parent723bc3b5d8a911a369eee658631d9f107ea09896 (diff)
downloadarmnn-521032fd424cf86681eb125afbf5eaee47d8c585.tar.gz
IVGCVSW-6417: Catch AddFullyConnected API error when weights TensorInfo isn't set
* Updated code in Graph.cpp InferTensorInfos() to be more descriptive. * Added method VerifyConstantLayerSetTensorInfo() in Graph.cpp/hpp to error when ConstantLayer TensorInfo is not set. * Updated Optimize() in Network.cpp to call VerifyConstantLayerSetTensorInfo(). * Added unit test with ConstantLayer TensorInfo not set to catch error in VerifyConstantLayerSetTensorInfo(). * Added comments around method VerifyConstantLayerSetTensorInfo(). Signed-off-by: Cathal Corbett <cathal.corbett@arm.com> Change-Id: I366596243f7c5823676222e2d0cce1335bc8c325
Diffstat (limited to 'src/backends/backendsCommon/test')
-rw-r--r--src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp41
1 files changed, 39 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp
index af6b56852a..7345ff5151 100644
--- a/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp
+++ b/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp
@@ -84,6 +84,25 @@ armnn::INetworkPtr CreateFullyConnectedNetworkConstWeightsNonConstBias(const arm
return network;
}
+armnn::INetworkPtr CreateFullyConnectedNetworkNoTensorInfoConstWeights(const armnn::TensorInfo& inputTensorInfo,
+ const armnn::TensorInfo& outputTensorInfo,
+ const armnn::ConstTensor& weightsConstantTensor,
+ armnn::FullyConnectedDescriptor descriptor)
+{
+ armnn::INetworkPtr network(armnn::INetwork::Create());
+
+ armnn::IConnectableLayer* inputLayer = network->AddInputLayer(0, "Input");
+ armnn::IConnectableLayer* weightsLayer = network->AddConstantLayer(weightsConstantTensor, "Weights");
+ armnn::IConnectableLayer* fullyConnectedLayer = network->AddFullyConnectedLayer(descriptor, "Fully_Connected");
+ armnn::IConnectableLayer* outputLayer = network->AddOutputLayer(0, "Output");
+
+ Connect(inputLayer, fullyConnectedLayer, inputTensorInfo, 0, 0);
+ weightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
+ Connect(fullyConnectedLayer, outputLayer, outputTensorInfo, 0, 0);
+
+ return network;
+}
+
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
void FullyConnectedWithDynamicWeightsEndToEnd(const std::vector<armnn::BackendId>& backends)
{
@@ -141,7 +160,8 @@ void FullyConnectedWithDynamicWeightsEndToEnd(const std::vector<armnn::BackendId
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
void FullyConnectedWithDynamicOrConstantInputsEndToEnd(const std::vector<armnn::BackendId>& backends,
const bool transposeWeights,
- const bool constantWeightsOrBias)
+ const bool constantWeightsOrBias,
+ const bool tensorInfoSet)
{
unsigned int inputWidth = 1;
unsigned int inputHeight = 1;
@@ -210,7 +230,24 @@ void FullyConnectedWithDynamicOrConstantInputsEndToEnd(const std::vector<armnn::
descriptor.m_TransposeWeightMatrix = transposeWeights;
descriptor.m_ConstantWeights = constantWeightsOrBias;
- if (!constantWeightsOrBias)
+ if(!tensorInfoSet)
+ {
+ // Tests constant weights and non constant bias.
+ ConstTensor weightsConstantTensor(weightsDesc, weights.data());
+
+ armnn::INetworkPtr network = CreateFullyConnectedNetworkNoTensorInfoConstWeights(inputTensorInfo,
+ outputTensorInfo,
+ weightsConstantTensor,
+ descriptor);
+ CHECK(network);
+
+ // Create runtime in which test will run
+ IRuntime::CreationOptions options;
+ IRuntimePtr runtime(IRuntime::Create(options));
+
+ CHECK_THROWS_AS( Optimize(*network, backends, runtime->GetDeviceSpec()), LayerValidationException );
+ }
+ else if (!constantWeightsOrBias)
{
// Tests non constant weights and constant bias.
ConstTensor biasConstantTensor(biasesDesc, biasValues.data());