aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/ShapeInferenceTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/ShapeInferenceTests.cpp')
-rw-r--r--src/armnn/test/ShapeInferenceTests.cpp18
1 files changed, 5 insertions, 13 deletions
diff --git a/src/armnn/test/ShapeInferenceTests.cpp b/src/armnn/test/ShapeInferenceTests.cpp
index 8abcfd7595..d3c928fec1 100644
--- a/src/armnn/test/ShapeInferenceTests.cpp
+++ b/src/armnn/test/ShapeInferenceTests.cpp
@@ -401,24 +401,16 @@ TEST_CASE("FloorTest")
TEST_CASE("FullyConnectedTest")
{
- Graph graph;
-
const unsigned int inputWidth = 3u;
const unsigned int inputHeight = 2u;
const unsigned int inputChannels = 1u;
const unsigned int outputChannels = 2u;
- auto layer = BuildGraph<FullyConnectedLayer>(&graph,
- {{1, inputChannels, inputHeight, inputWidth}},
- FullyConnectedDescriptor(),
- "fc");
-
-
- const float Datum = 0.0f;
- ConstTensor weights({{inputChannels, outputChannels}, DataType::Float32}, &Datum);
- layer->m_Weight = std::make_unique<ScopedTensorHandle>(weights);
-
- RunShapeInferenceTest<FullyConnectedLayer>(layer, {{ 1, outputChannels }});
+ CreateGraphAndRunTest<FullyConnectedLayer>({{ 1, inputChannels, inputHeight, inputWidth }, // input
+ { inputChannels, outputChannels }}, // weights
+ {{ 1, outputChannels }}, // output
+ FullyConnectedDescriptor(),
+ "fc");
}
TEST_CASE("GatherTest")