aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/test/SplitChecker.hpp
blob: edef4a1cf9aa0bd3e27b57e9d68d41e7688b4226 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "TosaTestUtils.hpp"

using namespace armnn;
using namespace tosa;

/// @brief Checks the TOSA basic block generated for a Splitter layer.
///
/// The block is expected to be named "Op_SPLIT_block_...", to contain one SLICE operator per
/// split output, and to hold one tensor per block input plus one per block output. Every SLICE
/// must read the block input ("input0_...") and write block output i ("output<i>_..."), and its
/// slice attribute is validated against @p splitDescriptor via VerifyTosaAttribute.
///
/// Failed expectations are reported with doctest CHECK, which does not abort; the function keeps
/// going where it is safe and skips only the checks that would otherwise index past the end of a
/// container or dereference a null tensor.
///
/// @param splitBlock      Block under test; dereferenced unconditionally, so it must not be null.
/// @param inputShape      Expected block input shapes. Only inputShape[0] is compared, because every
///                        SLICE reads that single input.
/// @param outputShape     Expected shape of each split output, in block-output order. Its size is
///                        the expected number of block outputs and of SLICE operators.
/// @param splitDescriptor Descriptor of the split, forwarded to VerifyTosaAttribute.
/// @param dataType        Expected dtype of the input tensor and of every output tensor.
inline void VerifySplit(TosaSerializationBasicBlock* splitBlock,
                        const std::vector<std::vector<int32_t>>& inputShape,
                        const std::vector<std::vector<int32_t>>& outputShape,
                        const BaseDescriptor& splitDescriptor,
                        DType dataType = DType_FP32)
{
    const uint32_t numInputs  = static_cast<uint32_t>(inputShape.size());
    const uint32_t numOutputs = static_cast<uint32_t>(outputShape.size());

    //
    // Verify block structure
    //

    const std::string blockStr = "Op_SPLIT_block_";
    CHECK(splitBlock->GetName().find(blockStr) != std::string::npos);
    CHECK(splitBlock->GetInputs().size() == numInputs);
    CHECK(splitBlock->GetOutputs().size() == numOutputs);
    // One SLICE per output; tensors are the block input(s) plus one per output. Derived from the
    // expected shapes so the helper works for any number of splits, not just three.
    CHECK(splitBlock->GetOperators().size() == numOutputs);
    CHECK(splitBlock->GetTensors().size() == numInputs + numOutputs);

    const auto& blockInputs  = splitBlock->GetInputs();
    const auto& blockOutputs = splitBlock->GetOutputs();
    const auto& operators    = splitBlock->GetOperators();

    // Everything below reads blockInputs[0] and inputShape[0]; stop rather than index an empty vector.
    if (blockInputs.empty() || inputShape.empty())
    {
        return;
    }

    //
    // Verify the input tensor shared by every slice operator
    //

    const std::string& blockInputName = blockInputs[0];
    const std::string opInputStr = "input0_";
    CHECK(blockInputName.find(opInputStr) != std::string::npos);

    TosaSerializationTensor* inputTensor = splitBlock->GetTensorByName(blockInputName);
    CHECK(inputTensor);
    if (inputTensor != nullptr)
    {
        CHECK(inputTensor->GetDtype() == dataType);
        CHECK(inputTensor->GetData().size() == 0);
        CHECK(inputTensor->GetShape() == inputShape[0]);
    }

    //
    // Verify slice operators: operator i produces block output i
    //

    // Bounded by every container indexed with i, so a count mismatch (already reported by the
    // CHECKs above) cannot lead to out-of-range access.
    for (uint32_t i = 0; i < operators.size() && i < blockOutputs.size() && i < numOutputs; ++i)
    {
        TosaSerializationOperator* sliceOp = operators[i];
        CHECK(sliceOp->GetOp() == Op_SLICE);
        CHECK(sliceOp->GetAttributeType() == Attribute_SliceAttribute);

        const auto& opInputNames  = sliceOp->GetInputTensorNames();
        const auto& opOutputNames = sliceOp->GetOutputTensorNames();
        CHECK(opInputNames.size() == numInputs);
        CHECK(opOutputNames.size() == 1u);
        if (opInputNames.empty() || opOutputNames.empty())
        {
            continue;
        }

        CHECK(opInputNames[0] == blockInputName);

        const std::string& blockOutputName = blockOutputs[i];
        const std::string opOutputStr = "output" + std::to_string(i) + "_";
        CHECK(opOutputNames[0] == blockOutputName);
        CHECK(blockOutputName.find(opOutputStr) != std::string::npos);

        TosaSerializationTensor* outputTensor = splitBlock->GetTensorByName(blockOutputName);
        CHECK(outputTensor);
        if (outputTensor != nullptr)
        {
            CHECK(outputTensor->GetDtype() == dataType);
            CHECK(outputTensor->GetData().size() == 0);
            // Each output is compared with its own expected shape rather than outputShape[0].
            CHECK(outputTensor->GetShape() == outputShape[i]);
        }

        VerifyTosaAttribute(splitDescriptor,
                            sliceOp->GetAttribute(),
                            inputShape[0],
                            outputShape[i],
                            LayerType::Splitter,
                            i);
    }
}