// path: root/src/armnnTfLiteParser/test/OutputShapeOfSqueeze.cpp
// blob: 395038d9593ddd437c1338c08948281e96b7a111
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "../TfLiteParser.hpp"
#include <iostream>
#include <string>

#include <doctest/doctest.h>

TEST_SUITE("TensorflowLiteParser_OutputShapeOfSqueeze")
{

/// Fixture shared by the test cases in this suite. It owns the parser whose
/// OutputShapeOfSqueeze member the cases call, plus the four-element input
/// shape { 1, 2, 2, 1 } from which each case builds its input TensorInfo.
struct TfLiteParserFixture
{
    TfLiteParserFixture()
        : m_Parser()
        , m_InputShape{ 1, 2, 2, 1 }
    {
    }

    ~TfLiteParserFixture() = default;

    /// Parser instance the test cases invoke OutputShapeOfSqueeze on.
    armnnTfLiteParser::TfLiteParserImpl m_Parser;

    /// Input dimensions { 1, 2, 2, 1 }: indices 0 and 3 have extent 1,
    /// indices 1 and 2 have extent 2.
    unsigned int m_InputShape[4];
};

// An empty squeeze-dimension list is passed for the { 1, 2, 2, 1 } input.
// The checks expect a rank-2 result of shape { 2, 2 } that still holds
// 4 elements.
TEST_CASE_FIXTURE(TfLiteParserFixture, "EmptySqueezeDims_OutputWithAllDimensionsSqueezed")
{
    std::vector<uint32_t> noSqueezeDims;
    armnn::TensorInfo input(4, m_InputShape, armnn::DataType::Float32);

    const armnn::TensorInfo squeezed = m_Parser.OutputShapeOfSqueeze(noSqueezeDims, input);

    CHECK(squeezed.GetNumElements() == 4);
    CHECK(squeezed.GetNumDimensions() == 2);
    CHECK((squeezed.GetShape() == armnn::TensorShape({ 2, 2 })));
}

// Squeeze dimensions 1 and 2 both index extent-2 entries of the
// { 1, 2, 2, 1 } input. The checks expect the shape to come back unchanged:
// rank 4, { 1, 2, 2, 1 }, 4 elements.
TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput")
{
    std::vector<uint32_t> nonUnitDims = { 1, 2 };
    armnn::TensorInfo input(4, m_InputShape, armnn::DataType::Float32);

    const armnn::TensorInfo result = m_Parser.OutputShapeOfSqueeze(nonUnitDims, input);

    CHECK(result.GetNumElements() == 4);
    CHECK(result.GetNumDimensions() == 4);
    CHECK((result.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
}

// Squeeze dimensions 1 and 3 are requested on the { 1, 2, 2, 1 } input:
// index 1 has extent 2, index 3 has extent 1. The checks expect a rank-3
// result of shape { 1, 2, 2 } with 4 elements.
TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed")
{
    std::vector<uint32_t> mixedDims = { 1, 3 };
    armnn::TensorInfo input(4, m_InputShape, armnn::DataType::Float32);

    const armnn::TensorInfo result = m_Parser.OutputShapeOfSqueeze(mixedDims, input);

    CHECK(result.GetNumElements() == 4);
    CHECK(result.GetNumDimensions() == 3);
    CHECK((result.GetShape() == armnn::TensorShape({ 1, 2, 2 })));
}

}