diff options
author | Eric Kunze <eric.kunze@arm.com> | 2022-10-07 23:34:48 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-10-11 10:28:21 -0700 |
commit | 9e899bb534956d2d08829beabe58a42a96531d08 (patch) | |
tree | ccc441e74019cf23682e45cdfc474f99ffeaa0db | |
parent | 2d31541b4303397f618b1090ecafe7998d30444b (diff) | |
download | tosa_mlir_translator-9e899bb534956d2d08829beabe58a42a96531d08.tar.gz |
Updates to work with the new FP16 serialization code
Adds accumulator data type where needed, and incorporates the new submodule
Change-Id: Ice1d5508bd94812b0092e6a6238abb14f1bbc399
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
-rw-r--r-- | src/TosaSerialize.cpp | 61 | ||||
m--------- | third_party/serialization_lib | 0 |
2 files changed, 52 insertions, 9 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 6692932..7ba4bf2 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -61,7 +61,7 @@ template <> struct equal_to<mlir::Value> { } // namespace std -ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { +static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { if (mode_str == "NEAREST_NEIGHBOR") return ResizeMode_NEAREST; else if (mode_str == "BILINEAR") @@ -70,7 +70,7 @@ ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { return ResizeMode_UNKNOWN; } -DType Type2DType(mlir::Type element_type) { +static DType Type2DType(mlir::Type element_type) { if (element_type.isF64() || element_type.isF32() || element_type.isF16() || element_type.isBF16()) { return DType_FLOAT; @@ -94,6 +94,27 @@ DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } +static DType Type2AccumDType(mlir::Type element_type) { + if (element_type.isF64() || element_type.isF32() || element_type.isF16() || + element_type.isBF16()) { + return DType_FLOAT; + } else if (element_type.isInteger(8)) { + return DType_INT32; + } else if (element_type.isInteger(16)) { + return DType_INT48; + } + return DType_UNKNOWN; +} + +static DType Type2PoolAccumDType(mlir::Type element_type) { + if (element_type.isF64() || element_type.isF32() || element_type.isF16() || + element_type.isBF16()) { + return DType_FLOAT; + } else if (element_type.isInteger(8) || element_type.isInteger(16)) { + return DType_INT32; + } + return DType_UNKNOWN; +} class TosaSerializationBlockBuilder; class TosaSerializationOperatorBuilder { @@ -207,7 +228,10 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t output_zp = quant_info ? quant_info.getOutputZp() : 0; - TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp); + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + DType type = Type2PoolAccumDType(tensor.getElementType()); + TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( opcode, Attribute_PoolAttribute, &attribute, @@ -530,8 +554,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>( int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV2D, Attribute_ConvAttribute, &attribute, @@ -577,8 +604,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>( int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV3D, Attribute_ConvAttribute, &attribute, @@ -624,8 +654,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>( int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, @@ -672,8 +705,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>( int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + DType type = Type2AccumDType(tensor.getElementType()); - TosaTransposeConvAttribute attribute(outpad, stride, output_shape, input_zp, weight_zp); + TosaTransposeConvAttribute attribute(outpad, stride, output_shape, input_zp, weight_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, @@ -697,7 +733,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FullyConnectedOp>( int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; - TosaFullyConnectedAttribute attribute(input_zp, weight_zp); + + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + DType type = Type2AccumDType(tensor.getElementType()); + TosaFullyConnectedAttribute attribute(input_zp, weight_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_FULLY_CONNECTED, Attribute_FullyConnectedAttribute, &attribute, @@ -720,8 +760,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::MatMulOp>( int32_t A_zp = quant_info ? quant_info.getAZp() : 0; int32_t B_zp = quant_info ? quant_info.getBZp() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + DType type = Type2AccumDType(tensor.getElementType()); - TosaMatMulAttribute attribute(A_zp, B_zp); + TosaMatMulAttribute attribute(A_zp, B_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_MATMUL, Attribute_MatMulAttribute, &attribute, diff --git a/third_party/serialization_lib b/third_party/serialization_lib -Subproject 4381b3d7fcb7cab975f46c62c86a35c53ade047 +Subproject 485a11d8cb67c8062c632f0987cd31cedbe93d6 |