10 #include <doctest/doctest.h> 12 using namespace armnn;
17 TEST_CASE(
"ExpandDimsAxis0Test")
23 CHECK(outputShape.GetNumDimensions() == 4);
24 CHECK(outputShape[0] == 1);
25 CHECK(outputShape[1] == 2);
26 CHECK(outputShape[2] == 3);
27 CHECK(outputShape[3] == 4);
30 TEST_CASE(
"ExpandDimsAxis1Test")
36 CHECK(outputShape.GetNumDimensions() == 4);
37 CHECK(outputShape[0] == 2);
38 CHECK(outputShape[1] == 1);
39 CHECK(outputShape[2] == 3);
40 CHECK(outputShape[3] == 4);
43 TEST_CASE(
"ExpandDimsAxis2Test")
49 CHECK(outputShape.GetNumDimensions() == 4);
50 CHECK(outputShape[0] == 2);
51 CHECK(outputShape[1] == 3);
52 CHECK(outputShape[2] == 1);
53 CHECK(outputShape[3] == 4);
56 TEST_CASE(
"ExpandDimsAxis3Test")
62 CHECK(outputShape.GetNumDimensions() == 4);
63 CHECK(outputShape[0] == 2);
64 CHECK(outputShape[1] == 3);
65 CHECK(outputShape[2] == 4);
66 CHECK(outputShape[3] == 1);
69 TEST_CASE(
"ExpandDimsNegativeAxis1Test")
75 CHECK(outputShape.GetNumDimensions() == 4);
76 CHECK(outputShape[0] == 2);
77 CHECK(outputShape[1] == 3);
78 CHECK(outputShape[2] == 4);
79 CHECK(outputShape[3] == 1);
82 TEST_CASE(
"ExpandDimsNegativeAxis2Test")
88 CHECK(outputShape.GetNumDimensions() == 4);
89 CHECK(outputShape[0] == 2);
90 CHECK(outputShape[1] == 3);
91 CHECK(outputShape[2] == 1);
92 CHECK(outputShape[3] == 4);
95 TEST_CASE(
"ExpandDimsNegativeAxis3Test")
101 CHECK(outputShape.GetNumDimensions() == 4);
102 CHECK(outputShape[0] == 2);
103 CHECK(outputShape[1] == 1);
104 CHECK(outputShape[2] == 3);
105 CHECK(outputShape[3] == 4);
108 TEST_CASE(
"ExpandDimsNegativeAxis4Test")
114 CHECK(outputShape.GetNumDimensions() == 4);
115 CHECK(outputShape[0] == 1);
116 CHECK(outputShape[1] == 2);
117 CHECK(outputShape[2] == 3);
118 CHECK(outputShape[3] == 4);
121 TEST_CASE(
"ExpandDimsInvalidAxisTest")
129 TEST_CASE(
"ExpandDimsInvalidNegativeAxisTest")
TEST_SUITE("TestConstTensorLayerVisitor")
Copyright (c) 2021 ARM Limited and Contributors.
armnn::TensorShape ExpandDims(const armnn::TensorShape &tensorShape, int axis)