diff options
Diffstat (limited to 'src/armnn/test/OptimizerTests.cpp')
-rw-r--r-- | src/armnn/test/OptimizerTests.cpp | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp index 80addb4bfd..3b079864c2 100644 --- a/src/armnn/test/OptimizerTests.cpp +++ b/src/armnn/test/OptimizerTests.cpp @@ -995,4 +995,59 @@ BOOST_AUTO_TEST_CASE(ResizeBilinearValidateTensorShapesFromInputsNhwc) BOOST_CHECK_NO_THROW(graph.InferTensorInfos()); } + +void CreateGatherGraph(Graph& graph, const armnn::TensorInfo& paramsInfo, const armnn::TensorInfo& indicesInfo, + const armnn::TensorInfo& outputInfo) +{ + Layer* input0 = graph.AddLayer<InputLayer>(0, "params"); + input0->GetOutputSlot().SetTensorInfo(paramsInfo); + + Layer* input1 = graph.AddLayer<InputLayer>(1, "indices"); + input1->GetOutputSlot().SetTensorInfo(indicesInfo); + + GatherLayer* layer = graph.AddLayer<GatherLayer>("gather"); + layer->GetOutputSlot().SetTensorInfo(outputInfo); + + Layer* output = graph.AddLayer<OutputLayer>(0, "output"); + input0->GetOutputSlot().Connect(layer->GetInputSlot(0)); + input1->GetOutputSlot().Connect(layer->GetInputSlot(1)); + layer->GetOutputSlot().Connect(output->GetInputSlot(0)); +} + +BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputs) +{ + Graph graph; + armnn::TensorInfo paramsInfo({10, 5}, DataType::Float32); + armnn::TensorInfo indicesInfo({3}, DataType::Signed32); + armnn::TensorInfo outputInfo({3, 5}, DataType::Float32); + + CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo); + + BOOST_CHECK_NO_THROW(graph.InferTensorInfos()); +} + +BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputs1DParams) +{ + Graph graph; + armnn::TensorInfo paramsInfo({8}, DataType::Float32); + armnn::TensorInfo indicesInfo({5}, DataType::Signed32); + armnn::TensorInfo outputInfo( {5}, DataType::Float32); + + CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo); + + BOOST_CHECK_NO_THROW(graph.InferTensorInfos()); +} + +BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputsMultiDimIndices) +{ + Graph graph; + armnn::TensorInfo paramsInfo({3, 2, 5}, DataType::Float32); + armnn::TensorInfo indicesInfo({2, 2}, DataType::Signed32); + armnn::TensorInfo outputInfo({2, 2, 2, 5}, DataType::Float32); + + CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo); + + BOOST_CHECK_NO_THROW(graph.InferTensorInfos()); +} + BOOST_AUTO_TEST_SUITE_END() |