aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/OptimizerTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/OptimizerTests.cpp')
-rw-r--r--src/armnn/test/OptimizerTests.cpp55
1 files changed, 55 insertions, 0 deletions
diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp
index 80addb4bfd..3b079864c2 100644
--- a/src/armnn/test/OptimizerTests.cpp
+++ b/src/armnn/test/OptimizerTests.cpp
@@ -995,4 +995,59 @@ BOOST_AUTO_TEST_CASE(ResizeBilinearValidateTensorShapesFromInputsNhwc)
BOOST_CHECK_NO_THROW(graph.InferTensorInfos());
}
+
+void CreateGatherGraph(Graph& graph, const armnn::TensorInfo& paramsInfo, const armnn::TensorInfo& indicesInfo,
+ const armnn::TensorInfo& outputInfo)
+{
+ Layer* input0 = graph.AddLayer<InputLayer>(0, "params");
+ input0->GetOutputSlot().SetTensorInfo(paramsInfo);
+
+ Layer* input1 = graph.AddLayer<InputLayer>(1, "indices");
+ input1->GetOutputSlot().SetTensorInfo(indicesInfo);
+
+ GatherLayer* layer = graph.AddLayer<GatherLayer>("gather");
+ layer->GetOutputSlot().SetTensorInfo(outputInfo);
+
+ Layer* output = graph.AddLayer<OutputLayer>(0, "output");
+ input0->GetOutputSlot().Connect(layer->GetInputSlot(0));
+ input1->GetOutputSlot().Connect(layer->GetInputSlot(1));
+ layer->GetOutputSlot().Connect(output->GetInputSlot(0));
+}
+
+BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputs)
+{
+ Graph graph;
+ armnn::TensorInfo paramsInfo({10, 5}, DataType::Float32);
+ armnn::TensorInfo indicesInfo({3}, DataType::Signed32);
+ armnn::TensorInfo outputInfo({3, 5}, DataType::Float32);
+
+ CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
+
+ BOOST_CHECK_NO_THROW(graph.InferTensorInfos());
+}
+
+BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputs1DParams)
+{
+ Graph graph;
+ armnn::TensorInfo paramsInfo({8}, DataType::Float32);
+ armnn::TensorInfo indicesInfo({5}, DataType::Signed32);
+ armnn::TensorInfo outputInfo( {5}, DataType::Float32);
+
+ CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
+
+ BOOST_CHECK_NO_THROW(graph.InferTensorInfos());
+}
+
+BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputsMultiDimIndices)
+{
+ Graph graph;
+ armnn::TensorInfo paramsInfo({3, 2, 5}, DataType::Float32);
+ armnn::TensorInfo indicesInfo({2, 2}, DataType::Signed32);
+ armnn::TensorInfo outputInfo({2, 2, 2, 5}, DataType::Float32);
+
+ CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
+
+ BOOST_CHECK_NO_THROW(graph.InferTensorInfos());
+}
+
BOOST_AUTO_TEST_SUITE_END()