ArmNN
 21.11
OutputShapeOfSqueeze.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "../TfLiteParser.hpp"
7 
8 #include <doctest/doctest.h>
9 
10 TEST_SUITE("TensorflowLiteParser_OutputShapeOfSqueeze")
11 {
12 
13 struct TfLiteParserFixture
14 {
15 
17  unsigned int m_InputShape[4];
18 
19  TfLiteParserFixture() : m_Parser( ), m_InputShape { 1, 2, 2, 1 } {}
20  ~TfLiteParserFixture() { }
21 
22 };
23 
24 TEST_CASE_FIXTURE(TfLiteParserFixture, "EmptySqueezeDims_OutputWithAllDimensionsSqueezed")
25 {
26 
27  std::vector<uint32_t> squeezeDims = { };
28 
29  armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
30  armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
31  CHECK(outputTensorInfo.GetNumElements() == 4);
32  CHECK(outputTensorInfo.GetNumDimensions() == 2);
33  CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 2, 2 })));
34 };
35 
36 TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput")
37 {
38  std::vector<uint32_t> squeezeDims = { 1, 2 };
39 
40  armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
41  armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
42  CHECK(outputTensorInfo.GetNumElements() == 4);
43  CHECK(outputTensorInfo.GetNumDimensions() == 4);
44  CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
45 };
46 
47 TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed")
48 {
49  std::vector<uint32_t> squeezeDims = { 1, 3 };
50 
51  armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
52  armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
53  CHECK(outputTensorInfo.GetNumElements() == 4);
54  CHECK(outputTensorInfo.GetNumDimensions() == 3);
55  CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2 })));
56 };
57 
58 }
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")
TEST_SUITE("TensorflowLiteParser_OutputShapeOfSqueeze")
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:195
unsigned int GetNumElements() const
Definition: Tensor.hpp:196