From 9e899bb534956d2d08829beabe58a42a96531d08 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Fri, 7 Oct 2022 23:34:48 +0000 Subject: Updates to work with the new FP16 serialization code Adds accumulator data type where needed, and incorporates the new submodule Change-Id: Ice1d5508bd94812b0092e6a6238abb14f1bbc399 Signed-off-by: Eric Kunze --- src/TosaSerialize.cpp | 61 ++++++++++++++++++++++++++++++++++++------- third_party/serialization_lib | 2 +- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 6692932..7ba4bf2 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -61,7 +61,7 @@ template <> struct equal_to { } // namespace std -ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { +static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { if (mode_str == "NEAREST_NEIGHBOR") return ResizeMode_NEAREST; else if (mode_str == "BILINEAR") @@ -70,7 +70,7 @@ ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { return ResizeMode_UNKNOWN; } -DType Type2DType(mlir::Type element_type) { +static DType Type2DType(mlir::Type element_type) { if (element_type.isF64() || element_type.isF32() || element_type.isF16() || element_type.isBF16()) { return DType_FLOAT; @@ -94,6 +94,27 @@ DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } +static DType Type2AccumDType(mlir::Type element_type) { + if (element_type.isF64() || element_type.isF32() || element_type.isF16() || + element_type.isBF16()) { + return DType_FLOAT; + } else if (element_type.isInteger(8)) { + return DType_INT32; + } else if (element_type.isInteger(16)) { + return DType_INT48; + } + return DType_UNKNOWN; +} + +static DType Type2PoolAccumDType(mlir::Type element_type) { + if (element_type.isF64() || element_type.isF32() || element_type.isF16() || + element_type.isBF16()) { + return DType_FLOAT; + } else if (element_type.isInteger(8) || element_type.isInteger(16)) { + return DType_INT32; + } + return DType_UNKNOWN; +} class TosaSerializationBlockBuilder; class TosaSerializationOperatorBuilder { @@ -207,7 +228,10 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t output_zp = quant_info ? quant_info.getOutputZp() : 0; - TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp); + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast(); + DType type = Type2PoolAccumDType(tensor.getElementType()); + TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( opcode, Attribute_PoolAttribute, &attribute, @@ -530,8 +554,11 @@ TosaSerializationOperatorBuilder::build( int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast(); + DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV2D, Attribute_ConvAttribute, &attribute, @@ -577,8 +604,11 @@ TosaSerializationOperatorBuilder::build( int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast(); + DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV3D, Attribute_ConvAttribute, &attribute, @@ -624,8 +654,11 @@ TosaSerializationOperatorBuilder::build( int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast(); + DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, @@ -672,8 +705,11 @@ TosaSerializationOperatorBuilder::build( int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast(); + DType type = Type2AccumDType(tensor.getElementType()); - TosaTransposeConvAttribute attribute(outpad, stride, output_shape, input_zp, weight_zp); + TosaTransposeConvAttribute attribute(outpad, stride, output_shape, input_zp, weight_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, @@ -697,7 +733,11 @@ TosaSerializationOperatorBuilder::build( int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; - TosaFullyConnectedAttribute attribute(input_zp, weight_zp); + + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast(); + DType type = Type2AccumDType(tensor.getElementType()); + TosaFullyConnectedAttribute attribute(input_zp, weight_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_FULLY_CONNECTED, Attribute_FullyConnectedAttribute, &attribute, @@ -720,8 +760,11 @@ TosaSerializationOperatorBuilder::build( int32_t A_zp = quant_info ? quant_info.getAZp() : 0; int32_t B_zp = quant_info ? quant_info.getBZp() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast(); + DType type = Type2AccumDType(tensor.getElementType()); - TosaMatMulAttribute attribute(A_zp, B_zp); + TosaMatMulAttribute attribute(A_zp, B_zp, type); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_MATMUL, Attribute_MatMulAttribute, &attribute, diff --git a/third_party/serialization_lib b/third_party/serialization_lib index 4381b3d..485a11d 160000 --- a/third_party/serialization_lib +++ b/third_party/serialization_lib @@ -1 +1 @@ -Subproject commit 4381b3d7fcb7cab975f46c62c86a35c53ade047f +Subproject commit 485a11d8cb67c8062c632f0987cd31cedbe93d6d -- cgit v1.2.1