aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2022-10-07 23:34:48 +0000
committerEric Kunze <eric.kunze@arm.com>2022-10-11 10:28:21 -0700
commit9e899bb534956d2d08829beabe58a42a96531d08 (patch)
treeccc441e74019cf23682e45cdfc474f99ffeaa0db
parent2d31541b4303397f618b1090ecafe7998d30444b (diff)
downloadtosa_mlir_translator-9e899bb534956d2d08829beabe58a42a96531d08.tar.gz
Updates to work with the new FP16 serialization code
Adds accumulator data type where needed, and incorporates the new submodule Change-Id: Ice1d5508bd94812b0092e6a6238abb14f1bbc399 Signed-off-by: Eric Kunze <eric.kunze@arm.com>
-rw-r--r--src/TosaSerialize.cpp61
m---------third_party/serialization_lib0
2 files changed, 52 insertions, 9 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 6692932..7ba4bf2 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -61,7 +61,7 @@ template <> struct equal_to<mlir::Value> {
} // namespace std
-ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
+static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
if (mode_str == "NEAREST_NEIGHBOR")
return ResizeMode_NEAREST;
else if (mode_str == "BILINEAR")
@@ -70,7 +70,7 @@ ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
return ResizeMode_UNKNOWN;
}
-DType Type2DType(mlir::Type element_type) {
+static DType Type2DType(mlir::Type element_type) {
if (element_type.isF64() || element_type.isF32() || element_type.isF16() ||
element_type.isBF16()) {
return DType_FLOAT;
@@ -94,6 +94,27 @@ DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
+static DType Type2AccumDType(mlir::Type element_type) {
+ if (element_type.isF64() || element_type.isF32() || element_type.isF16() ||
+ element_type.isBF16()) {
+ return DType_FLOAT;
+ } else if (element_type.isInteger(8)) {
+ return DType_INT32;
+ } else if (element_type.isInteger(16)) {
+ return DType_INT48;
+ }
+ return DType_UNKNOWN;
+}
+
+static DType Type2PoolAccumDType(mlir::Type element_type) {
+ if (element_type.isF64() || element_type.isF32() || element_type.isF16() ||
+ element_type.isBF16()) {
+ return DType_FLOAT;
+ } else if (element_type.isInteger(8) || element_type.isInteger(16)) {
+ return DType_INT32;
+ }
+ return DType_UNKNOWN;
+}
class TosaSerializationBlockBuilder;
class TosaSerializationOperatorBuilder {
@@ -207,7 +228,10 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
int32_t input_zp = quant_info ? quant_info.getInputZp() : 0;
int32_t output_zp = quant_info ? quant_info.getOutputZp() : 0;
- TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp);
+ mlir::RankedTensorType tensor =
+ op.getOperand(0).getType().cast<mlir::RankedTensorType>();
+ DType type = Type2PoolAccumDType(tensor.getElementType());
+ TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, type);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
opcode, Attribute_PoolAttribute, &attribute,
@@ -530,8 +554,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
int32_t input_zp = quant_info ? quant_info.getInputZp() : 0;
int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0;
+ mlir::RankedTensorType tensor =
+ op.getOperand(0).getType().cast<mlir::RankedTensorType>();
+ DType type = Type2AccumDType(tensor.getElementType());
- TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -577,8 +604,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>(
int32_t input_zp = quant_info ? quant_info.getInputZp() : 0;
int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0;
+ mlir::RankedTensorType tensor =
+ op.getOperand(0).getType().cast<mlir::RankedTensorType>();
+ DType type = Type2AccumDType(tensor.getElementType());
- TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONV3D, Attribute_ConvAttribute, &attribute,
@@ -624,8 +654,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
int32_t input_zp = quant_info ? quant_info.getInputZp() : 0;
int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0;
+ mlir::RankedTensorType tensor =
+ op.getOperand(0).getType().cast<mlir::RankedTensorType>();
+ DType type = Type2AccumDType(tensor.getElementType());
- TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -672,8 +705,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>(
int32_t input_zp = quant_info ? quant_info.getInputZp() : 0;
int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0;
+ mlir::RankedTensorType tensor =
+ op.getOperand(0).getType().cast<mlir::RankedTensorType>();
+ DType type = Type2AccumDType(tensor.getElementType());
- TosaTransposeConvAttribute attribute(outpad, stride, output_shape, input_zp, weight_zp);
+ TosaTransposeConvAttribute attribute(outpad, stride, output_shape, input_zp, weight_zp, type);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute,
@@ -697,7 +733,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FullyConnectedOp>(
int32_t input_zp = quant_info ? quant_info.getInputZp() : 0;
int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0;
- TosaFullyConnectedAttribute attribute(input_zp, weight_zp);
+
+ mlir::RankedTensorType tensor =
+ op.getOperand(0).getType().cast<mlir::RankedTensorType>();
+ DType type = Type2AccumDType(tensor.getElementType());
+ TosaFullyConnectedAttribute attribute(input_zp, weight_zp, type);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_FULLY_CONNECTED, Attribute_FullyConnectedAttribute, &attribute,
@@ -720,8 +760,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::MatMulOp>(
int32_t A_zp = quant_info ? quant_info.getAZp() : 0;
int32_t B_zp = quant_info ? quant_info.getBZp() : 0;
+ mlir::RankedTensorType tensor =
+ op.getOperand(0).getType().cast<mlir::RankedTensorType>();
+ DType type = Type2AccumDType(tensor.getElementType());
- TosaMatMulAttribute attribute(A_zp, B_zp);
+ TosaMatMulAttribute attribute(A_zp, B_zp, type);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_MATMUL, Attribute_MatMulAttribute, &attribute,
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 4381b3d7fcb7cab975f46c62c86a35c53ade047
+Subproject 485a11d8cb67c8062c632f0987cd31cedbe93d6