aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp18
1 files changed, 9 insertions, 9 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 1f24f71..82c107e 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -356,9 +356,9 @@ const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) {
return "";
}
-// this is a counter part to Type2PoolAccumDType
-mlir::TypeAttr AccumDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) {
- // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
+// this is a counter part to Type2PoolAccDType
+mlir::TypeAttr AccDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>;
if (dtype == DType_INT32) {
return mlir::TypeAttr::get(op_builder->getI32Type());
} else if (dtype == DType_FP32) {
@@ -366,7 +366,7 @@ mlir::TypeAttr AccumDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) {
} else if (dtype == DType_FP16) {
return mlir::TypeAttr::get(op_builder->getF16Type());
} else {
- // unknown accum type
+ // unknown acc type
// for now, default to F32
return mlir::TypeAttr::get(op_builder->getF32Type());
}
@@ -504,7 +504,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>(
mlir::DenseI64ArrayAttr stride =
BuildDenseI64ArrayAttr(op_builder, attr->stride());
mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
- auto acc_attr = AccumDType2TypeAttr(op_builder, attr->accum_dtype());
+ auto acc_attr = AccDType2TypeAttr(op_builder, attr->acc_type());
int32_t input_zp = attr->input_zp();
int32_t output_zp = attr->output_zp();
@@ -1526,8 +1526,8 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_COND_IF>(
Attribute_CondIfAttribute); // double check attribute type
TosaCondIfAttribute *attr =
static_cast<TosaCondIfAttribute *>(op->GetAttribute());
- auto ser_then_region = GetTsh()->GetRegionByName(attr->then_branch());
- auto ser_else_region = GetTsh()->GetRegionByName(attr->else_branch());
+ auto ser_then_region = GetTsh()->GetRegionByName(attr->then_graph());
+ auto ser_else_region = GetTsh()->GetRegionByName(attr->else_graph());
if (!ser_then_region || !ser_else_region) {
llvm::errs() << "ERROR: " << get_string(op)
@@ -1591,8 +1591,8 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_WHILE_LOOP>(
Attribute_WhileLoopAttribute); // double check attribute type
TosaWhileLoopAttribute *attr =
static_cast<TosaWhileLoopAttribute *>(op->GetAttribute());
- auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_branch());
- auto ser_body_region = GetTsh()->GetRegionByName(attr->body_branch());
+ auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_graph());
+ auto ser_body_region = GetTsh()->GetRegionByName(attr->body_graph());
mlir::Operation *mlir_op =
op_builder->create<mlir::tosa::WhileOp>(loc, output_types, input_values);