diff options
Diffstat (limited to 'src/backends/tosaCommon/test/QuantizeChecker.hpp')
-rw-r--r-- | src/backends/tosaCommon/test/QuantizeChecker.hpp | 105 |
1 files changed, 105 insertions, 0 deletions
diff --git a/src/backends/tosaCommon/test/QuantizeChecker.hpp b/src/backends/tosaCommon/test/QuantizeChecker.hpp new file mode 100644 index 0000000000..1a35903934 --- /dev/null +++ b/src/backends/tosaCommon/test/QuantizeChecker.hpp @@ -0,0 +1,105 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "TosaTestUtils.hpp" + +using namespace armnn; +using namespace tosa; + +void VerifyQuantize(TosaSerializationBasicBlock* quantizeBlock, + std::vector<int32_t> shape, + DType inputDataType = DType_FP32, + DType outputDataType = DType_FP32) +{ + std::string blockStr = "Op_QUANTIZE_block_"; + CHECK(quantizeBlock->GetName().find(blockStr) != std::string::npos); + CHECK(quantizeBlock->GetInputs().size() == 1); + CHECK(quantizeBlock->GetOutputs().size() == 1); + CHECK(quantizeBlock->GetOperators().size() == 5); // MUL, CONST, ADD, CONST, CAST + CHECK(quantizeBlock->GetTensors().size() == 6); + + std::basic_string<char> blockInputName = quantizeBlock->GetInputs()[0]; + std::basic_string<char> blockOutputName = quantizeBlock->GetOutputs()[0]; + + // + // Verify Constants + // + TosaSerializationOperator* constZeroPointOp = quantizeBlock->GetOperators().at(0); + CHECK(constZeroPointOp->GetAttributeType() == Attribute_NONE); + CHECK(constZeroPointOp->GetOp() == tosa::Op_CONST); + + TosaSerializationOperator* constScaleOp = quantizeBlock->GetOperators().at(1); + CHECK(constScaleOp->GetAttributeType() == Attribute_NONE); + CHECK(constScaleOp->GetOp() == tosa::Op_CONST); + + // + // Verify Multiplication + // + ElementwiseBinaryDescriptor mulDescriptor(BinaryOperation::Mul); + TosaSerializationOperator* mulOp = quantizeBlock->GetOperators().at(2); + CHECK(mulOp->GetAttributeType() == tosa::Attribute_MulAttribute); + CHECK(mulOp->GetOp() == tosa::Op_MUL); + + CHECK(mulOp->GetInputTensorNames().size() == 2); + std::basic_string<char> mulInputName0 = mulOp->GetInputTensorNames()[0]; + std::basic_string<char> mulInputName1 = mulOp->GetInputTensorNames()[1]; + + CHECK(blockInputName == mulInputName0); + + TosaSerializationTensor* mulInputTensor0 = quantizeBlock->GetTensorByName(mulInputName0); + CHECK(mulInputTensor0->GetDtype() == inputDataType); + CHECK(mulInputTensor0->GetData().size() == 0); + CHECK(mulInputTensor0->GetShape() == shape); + + TosaSerializationTensor* mulInputTensor1 = quantizeBlock->GetTensorByName(mulInputName1); + CHECK(mulInputTensor1->GetShape() == shape); + + // + // Verify Addition + // + ElementwiseBinaryDescriptor addDescriptor(BinaryOperation::Add); + TosaSerializationOperator* addOp = quantizeBlock->GetOperators().at(3); + CHECK(addOp->GetAttributeType() == Attribute_NONE); + CHECK(addOp->GetOp() == tosa::Op_ADD); + + CHECK(addOp->GetInputTensorNames().size() == 2); + std::basic_string<char> addInputName0 = addOp->GetInputTensorNames()[0]; + std::basic_string<char> addInputName1 = addOp->GetInputTensorNames()[1]; + + TosaSerializationTensor* addInputTensor0 = quantizeBlock->GetTensorByName(addInputName0); + CHECK(addInputTensor0->GetDtype() == inputDataType); + CHECK(addInputTensor0->GetData().size() == 0); + CHECK(addInputTensor0->GetShape() == shape); + + TosaSerializationTensor* addInputTensor1 = quantizeBlock->GetTensorByName(addInputName1); + CHECK(addInputTensor1->GetShape() == shape); + + // + // Verify Cast + // + TosaSerializationOperator* castOp = quantizeBlock->GetOperators().at(4); + CHECK(castOp->GetAttributeType() == Attribute_NONE); + CHECK(castOp->GetOp() == tosa::Op_CAST); + + CHECK(castOp->GetInputTensorNames().size() == 1); + CHECK(castOp->GetOutputTensorNames().size() == 1); + + std::basic_string<char> castInputName = castOp->GetInputTensorNames()[0]; + std::basic_string<char> castOutputName = castOp->GetOutputTensorNames()[0]; + + TosaSerializationTensor* castInputTensor = quantizeBlock->GetTensorByName(castInputName); + CHECK(castInputTensor->GetDtype() == inputDataType); + CHECK(castInputTensor->GetData().size() == 0); + CHECK(castInputTensor->GetShape() == shape); + + TosaSerializationTensor* castOutputTensor = quantizeBlock->GetTensorByName(castOutputName); + CHECK(castOutputTensor->GetDtype() == outputDataType); + CHECK(castOutputTensor->GetData().size() == 0); + CHECK(castOutputTensor->GetShape() == shape); + + CHECK(blockOutputName == castOutputName); + + +}
\ No newline at end of file |