aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/test/QuantizeChecker.hpp
blob: 1a35903934024e2ac6653bcca08b308b63dde101 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "TosaTestUtils.hpp"

using namespace armnn;
using namespace tosa;

void VerifyQuantize(TosaSerializationBasicBlock* quantizeBlock,
                    std::vector<int32_t> shape,
                    DType inputDataType = DType_FP32,
                    DType outputDataType = DType_FP32)
{
    std::string blockStr = "Op_QUANTIZE_block_";
    CHECK(quantizeBlock->GetName().find(blockStr)  != std::string::npos);
    CHECK(quantizeBlock->GetInputs().size() == 1);
    CHECK(quantizeBlock->GetOutputs().size() == 1);
    CHECK(quantizeBlock->GetOperators().size() == 5); // MUL, CONST, ADD, CONST, CAST
    CHECK(quantizeBlock->GetTensors().size() == 6);

    std::basic_string<char> blockInputName = quantizeBlock->GetInputs()[0];
    std::basic_string<char> blockOutputName = quantizeBlock->GetOutputs()[0];

    //
    // Verify Constants
    //
    TosaSerializationOperator* constZeroPointOp = quantizeBlock->GetOperators().at(0);
    CHECK(constZeroPointOp->GetAttributeType() == Attribute_NONE);
    CHECK(constZeroPointOp->GetOp() == tosa::Op_CONST);

    TosaSerializationOperator* constScaleOp = quantizeBlock->GetOperators().at(1);
    CHECK(constScaleOp->GetAttributeType() == Attribute_NONE);
    CHECK(constScaleOp->GetOp() == tosa::Op_CONST);

    //
    // Verify Multiplication
    //
    ElementwiseBinaryDescriptor mulDescriptor(BinaryOperation::Mul);
    TosaSerializationOperator* mulOp = quantizeBlock->GetOperators().at(2);
    CHECK(mulOp->GetAttributeType() == tosa::Attribute_MulAttribute);
    CHECK(mulOp->GetOp() == tosa::Op_MUL);

    CHECK(mulOp->GetInputTensorNames().size() == 2);
    std::basic_string<char> mulInputName0 = mulOp->GetInputTensorNames()[0];
    std::basic_string<char> mulInputName1 = mulOp->GetInputTensorNames()[1];

    CHECK(blockInputName == mulInputName0);

    TosaSerializationTensor* mulInputTensor0 = quantizeBlock->GetTensorByName(mulInputName0);
    CHECK(mulInputTensor0->GetDtype() == inputDataType);
    CHECK(mulInputTensor0->GetData().size() == 0);
    CHECK(mulInputTensor0->GetShape() == shape);

    TosaSerializationTensor* mulInputTensor1 = quantizeBlock->GetTensorByName(mulInputName1);
    CHECK(mulInputTensor1->GetShape() == shape);

    //
    // Verify Addition
    //
    ElementwiseBinaryDescriptor addDescriptor(BinaryOperation::Add);
    TosaSerializationOperator* addOp = quantizeBlock->GetOperators().at(3);
    CHECK(addOp->GetAttributeType() == Attribute_NONE);
    CHECK(addOp->GetOp() == tosa::Op_ADD);

    CHECK(addOp->GetInputTensorNames().size() == 2);
    std::basic_string<char> addInputName0 = addOp->GetInputTensorNames()[0];
    std::basic_string<char> addInputName1 = addOp->GetInputTensorNames()[1];

    TosaSerializationTensor* addInputTensor0 = quantizeBlock->GetTensorByName(addInputName0);
    CHECK(addInputTensor0->GetDtype() == inputDataType);
    CHECK(addInputTensor0->GetData().size() == 0);
    CHECK(addInputTensor0->GetShape() == shape);

    TosaSerializationTensor* addInputTensor1 = quantizeBlock->GetTensorByName(addInputName1);
    CHECK(addInputTensor1->GetShape() == shape);

    //
    // Verify Cast
    //
    TosaSerializationOperator* castOp = quantizeBlock->GetOperators().at(4);
    CHECK(castOp->GetAttributeType() == Attribute_NONE);
    CHECK(castOp->GetOp() == tosa::Op_CAST);

    CHECK(castOp->GetInputTensorNames().size() == 1);
    CHECK(castOp->GetOutputTensorNames().size() == 1);

    std::basic_string<char> castInputName = castOp->GetInputTensorNames()[0];
    std::basic_string<char> castOutputName = castOp->GetOutputTensorNames()[0];

    TosaSerializationTensor* castInputTensor = quantizeBlock->GetTensorByName(castInputName);
    CHECK(castInputTensor->GetDtype() == inputDataType);
    CHECK(castInputTensor->GetData().size() == 0);
    CHECK(castInputTensor->GetShape() == shape);

    TosaSerializationTensor* castOutputTensor = quantizeBlock->GetTensorByName(castOutputName);
    CHECK(castOutputTensor->GetDtype() == outputDataType);
    CHECK(castOutputTensor->GetData().size() == 0);
    CHECK(castOutputTensor->GetShape() == shape);

    CHECK(blockOutputName == castOutputName);


}