aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/test/QuantizeChecker.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaCommon/test/QuantizeChecker.hpp')
-rw-r--r--src/backends/tosaCommon/test/QuantizeChecker.hpp105
1 files changed, 105 insertions, 0 deletions
diff --git a/src/backends/tosaCommon/test/QuantizeChecker.hpp b/src/backends/tosaCommon/test/QuantizeChecker.hpp
new file mode 100644
index 0000000000..1a35903934
--- /dev/null
+++ b/src/backends/tosaCommon/test/QuantizeChecker.hpp
@@ -0,0 +1,105 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "TosaTestUtils.hpp"
+
+using namespace armnn;
+using namespace tosa;
+
+void VerifyQuantize(TosaSerializationBasicBlock* quantizeBlock,
+ std::vector<int32_t> shape,
+ DType inputDataType = DType_FP32,
+ DType outputDataType = DType_FP32)
+{
+ std::string blockStr = "Op_QUANTIZE_block_";
+ CHECK(quantizeBlock->GetName().find(blockStr) != std::string::npos);
+ CHECK(quantizeBlock->GetInputs().size() == 1);
+ CHECK(quantizeBlock->GetOutputs().size() == 1);
+ CHECK(quantizeBlock->GetOperators().size() == 5); // MUL, CONST, ADD, CONST, CAST
+ CHECK(quantizeBlock->GetTensors().size() == 6);
+
+ std::basic_string<char> blockInputName = quantizeBlock->GetInputs()[0];
+ std::basic_string<char> blockOutputName = quantizeBlock->GetOutputs()[0];
+
+ //
+ // Verify Constants
+ //
+ TosaSerializationOperator* constZeroPointOp = quantizeBlock->GetOperators().at(0);
+ CHECK(constZeroPointOp->GetAttributeType() == Attribute_NONE);
+ CHECK(constZeroPointOp->GetOp() == tosa::Op_CONST);
+
+ TosaSerializationOperator* constScaleOp = quantizeBlock->GetOperators().at(1);
+ CHECK(constScaleOp->GetAttributeType() == Attribute_NONE);
+ CHECK(constScaleOp->GetOp() == tosa::Op_CONST);
+
+ //
+ // Verify Multiplication
+ //
+ ElementwiseBinaryDescriptor mulDescriptor(BinaryOperation::Mul);
+ TosaSerializationOperator* mulOp = quantizeBlock->GetOperators().at(2);
+ CHECK(mulOp->GetAttributeType() == tosa::Attribute_MulAttribute);
+ CHECK(mulOp->GetOp() == tosa::Op_MUL);
+
+ CHECK(mulOp->GetInputTensorNames().size() == 2);
+ std::basic_string<char> mulInputName0 = mulOp->GetInputTensorNames()[0];
+ std::basic_string<char> mulInputName1 = mulOp->GetInputTensorNames()[1];
+
+ CHECK(blockInputName == mulInputName0);
+
+ TosaSerializationTensor* mulInputTensor0 = quantizeBlock->GetTensorByName(mulInputName0);
+ CHECK(mulInputTensor0->GetDtype() == inputDataType);
+ CHECK(mulInputTensor0->GetData().size() == 0);
+ CHECK(mulInputTensor0->GetShape() == shape);
+
+ TosaSerializationTensor* mulInputTensor1 = quantizeBlock->GetTensorByName(mulInputName1);
+ CHECK(mulInputTensor1->GetShape() == shape);
+
+ //
+ // Verify Addition
+ //
+ ElementwiseBinaryDescriptor addDescriptor(BinaryOperation::Add);
+ TosaSerializationOperator* addOp = quantizeBlock->GetOperators().at(3);
+ CHECK(addOp->GetAttributeType() == Attribute_NONE);
+ CHECK(addOp->GetOp() == tosa::Op_ADD);
+
+ CHECK(addOp->GetInputTensorNames().size() == 2);
+ std::basic_string<char> addInputName0 = addOp->GetInputTensorNames()[0];
+ std::basic_string<char> addInputName1 = addOp->GetInputTensorNames()[1];
+
+ TosaSerializationTensor* addInputTensor0 = quantizeBlock->GetTensorByName(addInputName0);
+ CHECK(addInputTensor0->GetDtype() == inputDataType);
+ CHECK(addInputTensor0->GetData().size() == 0);
+ CHECK(addInputTensor0->GetShape() == shape);
+
+ TosaSerializationTensor* addInputTensor1 = quantizeBlock->GetTensorByName(addInputName1);
+ CHECK(addInputTensor1->GetShape() == shape);
+
+ //
+ // Verify Cast
+ //
+ TosaSerializationOperator* castOp = quantizeBlock->GetOperators().at(4);
+ CHECK(castOp->GetAttributeType() == Attribute_NONE);
+ CHECK(castOp->GetOp() == tosa::Op_CAST);
+
+ CHECK(castOp->GetInputTensorNames().size() == 1);
+ CHECK(castOp->GetOutputTensorNames().size() == 1);
+
+ std::basic_string<char> castInputName = castOp->GetInputTensorNames()[0];
+ std::basic_string<char> castOutputName = castOp->GetOutputTensorNames()[0];
+
+ TosaSerializationTensor* castInputTensor = quantizeBlock->GetTensorByName(castInputName);
+ CHECK(castInputTensor->GetDtype() == inputDataType);
+ CHECK(castInputTensor->GetData().size() == 0);
+ CHECK(castInputTensor->GetShape() == shape);
+
+ TosaSerializationTensor* castOutputTensor = quantizeBlock->GetTensorByName(castOutputName);
+ CHECK(castOutputTensor->GetDtype() == outputDataType);
+ CHECK(castOutputTensor->GetData().size() == 0);
+ CHECK(castOutputTensor->GetShape() == shape);
+
+ CHECK(blockOutputName == castOutputName);
+
+
+} \ No newline at end of file