From 6103155b3a2a555c3fc4a3a2173b35ea573c9600 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Wed, 13 Mar 2024 19:19:53 +0000 Subject: [tosa_mlir_translator] Fix fp16, bf16 and fp8 serialization Fix serialization and deserialization of fp16, bf16 and fp8 for pad_const, clamp min_val/max_val, and const values Signed-off-by: Tai Ly Change-Id: Ia39a17d2f395584d5555d2c86cdae7113cf14e3f --- src/TosaDeserialize.cpp | 277 +++++++++++++++++------------------------- src/TosaSerialize.cpp | 270 +++++++++++++++++++++++----------------- third_party/serialization_lib | 2 +- 3 files changed, 270 insertions(+), 279 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 6fa691e..fdbd892 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -166,58 +166,70 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, return mlir::success(); } -mlir::DenseElementsAttr -ConstructConstAttr(const mlir::RankedTensorType &output_type, - TosaSerializationTensor *ts, const std::string &op_name) { - const auto &data = ts->GetData(); - auto &shape = ts->GetShape(); - // compute output data size - uint32_t out_size = 1; - for (const auto dim : shape) { - out_size *= dim; - } - mlir::DenseElementsAttr value_attr; - switch (ts->GetDtype()) { - case DType_FP32: - case DType_BF16: - case DType_FP8E4M3: - case DType_FP8E5M2: { - // for FP32, FP16 and FP8 types, value attributes are stored as FP32 values +mlir::DenseElementsAttr GetConstAttr(const std::vector &data, + const mlir::RankedTensorType &output_type, + uint32_t out_size) { + auto element_type = output_type.getElementType(); + if (element_type.isF32()) { + // for FP32, value attributes are stored as FP32 values std::vector float_data; TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data)); - break; + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(float_data)); + } + if (element_type.isBF16()) { + mlir::SmallVector bf16_data; + for (uint32_t i = 0; i < out_size; i++) { + uint64_t byte0 = data[i * sizeof(int16_t)]; + uint64_t byte1 = data[i * sizeof(int16_t) + 1]; + uint64_t bits = byte0 + (byte1 << 8); + mlir::APInt bf16_bits(16, bits); + mlir::APFloat bf16(mlir::APFloat::BFloat(), bf16_bits); + bf16_data.push_back(bf16); + } + return mlir::DenseElementsAttr::get(output_type, bf16_data); + } + if (element_type.isFloat8E4M3FN()) { + mlir::SmallVector f8_data; + for (uint32_t i = 0; i < out_size; i++) { + mlir::APInt f8_bits(8, static_cast(data[i])); + mlir::APFloat f8(mlir::APFloat::Float8E4M3FN(), f8_bits); + f8_data.push_back(f8); + } + return mlir::DenseElementsAttr::get(output_type, f8_data); + } + if (element_type.isFloat8E5M2()) { + mlir::SmallVector f8_data; + for (uint32_t i = 0; i < out_size; i++) { + mlir::APInt f8_bits(8, static_cast(data[i])); + mlir::APFloat f8(mlir::APFloat::Float8E5M2(), f8_bits); + f8_data.push_back(f8); + } + return mlir::DenseElementsAttr::get(output_type, f8_data); } - case DType_INT4: { + if (element_type.isInteger(4)) { std::vector int4_data; TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data)); - break; + return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data)); } - case DType_INT8: { + if (element_type.isInteger(8)) { std::vector int8_data; TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data)); - break; + return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data)); } - case DType_INT16: { + if (element_type.isInteger(16)) { std::vector int16_data; TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int16_data)); - break; + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(int16_data)); } - case DType_INT32: { + if (element_type.isInteger(32)) { std::vector int32_data; TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int32_data)); - break; + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(int32_data)); } - case DType_INT48: { + if (element_type.isInteger(48)) { std::vector int48_data; TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data); std::vector apint_data; @@ -226,34 +238,38 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type, /* isSigned = */ false); apint_data.push_back(apint_value); } - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(apint_data)); - break; + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(apint_data)); } - case DType_BOOL: { + if (element_type.isInteger(1)) { std::vector bool_data; TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data); llvm::SmallVector bool_values(bool_data.begin(), bool_data.end()); - value_attr = mlir::DenseElementsAttr::get(output_type, bool_values); - break; + return mlir::DenseElementsAttr::get(output_type, bool_values); } - case DType_FP16: { + if (element_type.isF16()) { std::vector half_data; TosaSerializationHandler::ConvertU8toF16(data, out_size, half_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(half_data)); - break; + return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(half_data)); } - case DType_UINT8: - case DType_UINT16: - default: { + + return nullptr; +} + +mlir::DenseElementsAttr +ConstructConstAttr(const mlir::RankedTensorType &output_type, + TosaSerializationTensor *ts, const std::string &op_name) { + // compute output data size + uint32_t out_size = 1; + for (const auto dim : ts->GetShape()) { + out_size *= dim; + } + auto attr = GetConstAttr(ts->GetData(), output_type, out_size); + if (!attr) { llvm::errs() << "ERROR: " << op_name << " contains unsupported element type\n"; - return nullptr; } - } - - return value_attr; + return attr; } mlir::LogicalResult ConstructVariableOps(mlir::ModuleOp &module) { @@ -942,56 +958,31 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { TosaClampAttribute *attr = static_cast(op->GetAttribute()); - mlir::Type input_element_type = + mlir::Type element_type = llvm::cast(input_val.getType()).getElementType(); - if (auto quantType = llvm::dyn_cast( - input_element_type)) { - input_element_type = quantType.getStorageType(); + if (auto quantType = + llvm::dyn_cast(element_type)) { + element_type = quantType.getStorageType(); } + auto element_const_type = mlir::RankedTensorType::get({1}, element_type); + auto min_values_attr = GetConstAttr(attr->min_val(), element_const_type, 1); + auto max_values_attr = GetConstAttr(attr->max_val(), element_const_type, 1); + mlir::Attribute min_val_attr, max_val_attr; - if (input_element_type.isa()) { - std::vector min_float_data, max_float_data; - TosaSerializationHandler::ConvertU8toF32(attr->min_val(), /* size = */ 1, - min_float_data); - TosaSerializationHandler::ConvertU8toF32(attr->max_val(), /* size = */ 1, - max_float_data); - min_val_attr = - op_builder->getFloatAttr(input_element_type, min_float_data[0]); - max_val_attr = - op_builder->getFloatAttr(input_element_type, max_float_data[0]); + if (element_type.isa()) { + min_val_attr = op_builder->getFloatAttr( + element_type, min_values_attr.getValues()[0]); + max_val_attr = op_builder->getFloatAttr( + element_type, max_values_attr.getValues()[0]); } else { - std::vector min_int_data, max_int_data; - TosaSerializationHandler::ConvertU8toI32(attr->min_val(), /* size = */ 1, - min_int_data); - TosaSerializationHandler::ConvertU8toI32(attr->max_val(), /* size = */ 1, - max_int_data); - if (input_element_type.isUnsignedInteger()) { - if (input_element_type.isUnsignedInteger(8)) { - uint8_t min_val = min_int_data[0]; - uint8_t max_val = max_int_data[0]; - min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); - max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); - } else if (input_element_type.isUnsignedInteger(16)) { - uint16_t min_val = min_int_data[0]; - uint16_t max_val = max_int_data[0]; - min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); - max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); - } else { - llvm::errs() - << "ERROR: " << get_string(op) - << " contains unsupported unsigned int element data type.\n"; - return {}; - } - } else { - min_val_attr = - op_builder->getIntegerAttr(input_element_type, min_int_data[0]); - max_val_attr = - op_builder->getIntegerAttr(input_element_type, max_int_data[0]); - } + min_val_attr = op_builder->getIntegerAttr( + element_type, min_values_attr.getValues()[0]); + max_val_attr = op_builder->getIntegerAttr( + element_type, max_values_attr.getValues()[0]); } - mlir::Operation *mlir_op = op_builder->create( + auto mlir_op = op_builder->create( loc, output_type, input_val, min_val_attr, max_val_attr); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); @@ -1101,80 +1092,36 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { assert(op->GetAttributeType() == Attribute_PadAttribute); // double check attribute type TosaPadAttribute *attr = static_cast(op->GetAttribute()); - - float pad_const_fp = 0.0f; - int32_t pad_const_int = 0; - - if (element_type.isa()) { - std::vector float_data; - TosaSerializationHandler::ConvertU8toF32(attr->pad_const(), - /* size = */ 1, float_data); - pad_const_fp = float_data[0]; - } else { - std::vector int32_data; - TosaSerializationHandler::ConvertU8toI32(attr->pad_const(), - /* size = */ 1, int32_data); - pad_const_int = int32_data[0]; + const auto &pad_const_u8_data = attr->pad_const(); + + // check for any value in pad_const_u8_data + bool has_pad_const = false; + for (auto v : pad_const_u8_data) { + if (v != 0) { + has_pad_const = true; + break; + } } - - // todo: int input_zp = attr->pad_input_zp(); - - mlir::Operation *mlir_op; - mlir::Value pad_const_value; - - bool isBoolType = element_type.isInteger(1); - // First handle boolean type. - if (isBoolType) { - mlir::Type boolType = op_builder->getIntegerType(1); - auto pad_const_type = mlir::RankedTensorType::get({}, boolType); - // Treat zero integer is `false`, and any non-zero integner evaluates to - // `true`. - bool pad_const = pad_const_int == 0 ? false : true; - auto pad_const_attr = - mlir::DenseElementsAttr::get(pad_const_type, {pad_const}); - mlir::Operation *pad_const_op = op_builder->create( - loc, pad_const_type, pad_const_attr); - - block->push_back(pad_const_op); - pad_const_value = pad_const_op->getResult(0); - mlir_op = op_builder->create( - loc, output_type, input_val, padding_val, pad_const_value); - + if (!has_pad_const) { + // handle the cases where no explicit pad_const input. + auto mlir_op = op_builder->create( + loc, output_type, input_val, padding_val); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } - // Second handle the cases where no explicit pad_const input. - if (pad_const_int == 0 && pad_const_fp == 0.0f) { - mlir_op = op_builder->create(loc, output_type, input_val, - padding_val); - block->push_back(mlir_op); - return std::vector({mlir_op->getResult(0)}); - } + // has pad const - create a const op for pad_const input + auto pad_const_type = mlir::RankedTensorType::get({}, element_type); + auto pad_const_attr = GetConstAttr(pad_const_u8_data, pad_const_type, 1); - // Then handle explicit numerical pad_const cases. - if (pad_const_int != 0) { - assert(pad_const_fp == 0.0f && llvm::isa(element_type)); - auto pad_const_int_type = mlir::RankedTensorType::get({}, element_type); - auto pad_const_int_attr = - mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int}); - mlir::Operation *pad_const_int_op = op_builder->create( - loc, pad_const_int_type, pad_const_int_attr); - block->push_back(pad_const_int_op); - pad_const_value = pad_const_int_op->getResult(0); - } else { - assert(pad_const_fp != 0 && llvm::isa(element_type)); - auto pad_const_fp_type = mlir::RankedTensorType::get({}, element_type); - auto pad_const_fp_attr = - mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp}); - mlir::Operation *pad_const_fp_op = op_builder->create( - loc, pad_const_fp_type, pad_const_fp_attr); - block->push_back(pad_const_fp_op); - pad_const_value = pad_const_fp_op->getResult(0); - } - - mlir_op = op_builder->create(loc, output_type, input_val, - padding_val, pad_const_value); + auto pad_const_op = op_builder->create( + loc, pad_const_type, pad_const_attr); + + block->push_back(pad_const_op); + mlir::Value pad_const_value = pad_const_op->getResult(0); + + auto mlir_op = op_builder->create( + loc, output_type, input_val, padding_val, pad_const_value); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 875303e..6553944 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -152,9 +152,29 @@ public: TosaSerializationHandler *GetTsh() const; TosaSerializationRegionBuilder *GetRegionBuilder() const; mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op, - mlir::Attribute &attr, DType dtype, + mlir::Attribute &attr, + mlir::Type element_type, std::vector &u8_data) const; + // populate u8_data with either int64_value or float_value depending on + // element_type + mlir::LogicalResult + GetU8DataFromIntOrFloatValue(int64_t int64_value, float fp_value, + mlir::Type element_type, + std::vector &u8_data) const; + + // populate u8_data with int_value depending on non-float element_type + mlir::LogicalResult + GetU8DataFromIntValues(const std::vector &int_values, + mlir::Type element_type, + std::vector &u8_data) const; + + // populate u8_data with fp_value depending on float element_type + mlir::LogicalResult + GetU8DataFromFloatValues(const std::vector &fp_values, + mlir::Type element_type, + std::vector &u8_data) const; + private: std::string GetTensorName(mlir::Value val) const; std::string GetVariableTensorName(mlir::Operation *op) const; @@ -303,134 +323,146 @@ std::string TosaSerializationOperatorBuilder::GetVariableTensorName( } mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( - mlir::Operation &op, mlir::Attribute &attr, DType type, + mlir::Operation &op, mlir::Attribute &attr, mlir::Type element_type, std::vector &u8_data) const { + if (!element_type.isIntOrFloat()) { + return mlir::failure(); + } auto dense_attr = attr.dyn_cast(); - switch (type) { - case DType_FP32: - case DType_BF16: - case DType_FP16: - case DType_FP8E4M3: - case DType_FP8E5M2: { - std::vector data; + // handle float types + if (element_type.isa()) { + std::vector fp_data; auto val_attr = attr.dyn_cast(); if (dense_attr) { for (auto val : dense_attr.getValues()) { - data.push_back(val.convertToFloat()); + fp_data.push_back(val.convertToFloat()); } } else if (val_attr) { - data.push_back((float)val_attr.getValueAsDouble()); + fp_data.push_back((float)val_attr.getValueAsDouble()); } else { op.emitOpError("Unknown const attribute"); return mlir::failure(); } - if (type == DType_FP16) { - TosaSerializationHandler::ConvertF16toU8(data, u8_data); - } else { - // for all other floating types, store as F32 values - TosaSerializationHandler::ConvertF32toU8(data, u8_data); - } - break; + return GetU8DataFromFloatValues(fp_data, element_type, u8_data); } - case DType_INT8: { - std::vector data; - auto val_attr = attr.dyn_cast(); - if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return mlir::failure(); + // element_type is integer type + + bool isInt48 = element_type.isInteger(48); + std::vector i64_data; + + auto val_attr = attr.dyn_cast(); + if (dense_attr) { + for (auto valueIt : dense_attr.getValues()) { + int64_t val = isInt48 ? static_cast(valueIt.getLimitedValue()) + : valueIt.getSExtValue(); + i64_data.push_back(val); } - TosaSerializationHandler::ConvertI8toU8(data, u8_data); - break; + } else if (val_attr) { + i64_data.push_back(val_attr.getInt()); + } else { + op.emitOpError("Unknown const attribute"); + return mlir::failure(); } - case DType_INT16: { - std::vector data; - auto val_attr = attr.dyn_cast(); - if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return mlir::failure(); + return GetU8DataFromIntValues(i64_data, element_type, u8_data); +} + +mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromIntValues( + const std::vector &int64_values, mlir::Type element_type, + std::vector &u8_data) const { + switch (element_type.getIntOrFloatBitWidth()) { + case 1: { + // bool use bool vec + std::vector bool_values; + for (auto v : int64_values) { + bool bool_value = v == 0 ? false : true; + bool_values.push_back(bool_value); } - TosaSerializationHandler::ConvertI16toU8(data, u8_data); + TosaSerializationHandler::ConvertBooltoU8(bool_values, u8_data); break; } - case DType_INT32: { - std::vector data; - auto val_attr = attr.dyn_cast(); - - if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); + case 4: + case 8: { + // I4 and I8 use int8_t vec + std::vector i8_values; + for (auto v : int64_values) { + i8_values.push_back(static_cast(v)); + } + if (element_type.isInteger(4)) { + TosaSerializationHandler::ConvertI4toU8(i8_values, u8_data); } else { - op.emitOpError("Unknown const attribute"); - return mlir::failure(); + TosaSerializationHandler::ConvertI8toU8(i8_values, u8_data); } - TosaSerializationHandler::ConvertI32toU8(data, u8_data); break; } - case DType_INT48: { - std::vector data; - auto val_attr = attr.dyn_cast(); - - if (dense_attr) { - for (auto valueIt : dense_attr.getValues()) { - uint64_t val = valueIt.getLimitedValue(); - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return mlir::failure(); + case 16: { + // I16 use int16_t vec + std::vector i16_values; + for (auto v : int64_values) { + i16_values.push_back(static_cast(v)); } - TosaSerializationHandler::ConvertI48toU8(data, u8_data); + TosaSerializationHandler::ConvertI16toU8(i16_values, u8_data); break; } - case DType_BOOL: { - std::vector data; - auto val_attr = attr.dyn_cast(); - - if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getValue()); - } else { - op.emitOpError("Unknown const attribute"); - return mlir::failure(); + case 32: { + // I32 use int32_t vec + std::vector i32_values; + for (auto v : int64_values) { + i32_values.push_back(static_cast(v)); } - - TosaSerializationHandler::ConvertBooltoU8(data, u8_data); + TosaSerializationHandler::ConvertI32toU8(i32_values, u8_data); + break; + } + case 48: { + // I48 use int64_t vec + TosaSerializationHandler::ConvertI48toU8(int64_values, u8_data); break; } default: { - op.emitOpError("Unknown element type of const attribute"); + // unsupported bit widths return mlir::failure(); } } + return mlir::success(); +} +mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromFloatValues( + const std::vector &fp_values, mlir::Type element_type, + std::vector &u8_data) const { + assert( + element_type + .isa()); // this should only be called for float type + if (element_type.isF16()) { + TosaSerializationHandler::ConvertF16toU8(fp_values, u8_data); + } else if (element_type.isBF16()) { + TosaSerializationHandler::ConvertBF16toU8(fp_values, u8_data); + } else if (element_type.isFloat8E4M3FN()) { + TosaSerializationHandler::ConvertFP8E4M3toU8(fp_values, u8_data); + } else if (element_type.isFloat8E5M2()) { + TosaSerializationHandler::ConvertFP8E5M2toU8(fp_values, u8_data); + } else if (element_type.isF32()) { + TosaSerializationHandler::ConvertF32toU8(fp_values, u8_data); + } else { + return mlir::failure(); + } return mlir::success(); } +mlir::LogicalResult +TosaSerializationOperatorBuilder::GetU8DataFromIntOrFloatValue( + int64_t int64_value, float fp_value, mlir::Type element_type, + std::vector &u8_data) const { + if (element_type.isa()) { + return GetU8DataFromFloatValues({fp_value}, element_type, u8_data); + } else { + return GetU8DataFromIntValues({int64_value}, element_type, u8_data); + } +} + // Main template to catch unimplemented translation. template TosaSerializationOperator * @@ -691,9 +723,12 @@ TosaSerializationOperatorBuilder::build( } std::vector u8_data; mlir::Attribute attr = op.getAttr(llvm::StringRef("value")); - DType type = ts->GetDtype(); + mlir::Type element_type = + llvm::cast(op.getResult(0).getType()).getElementType(); - if (GetDataFromAttribute(op, attr, type, u8_data).failed()) { + if (GetDataFromAttribute(op, attr, element_type, u8_data).failed()) { + op.emitOpError("ERROR: GetDataFromAttribute() fails when building value of " + "const tensor"); return nullptr; } @@ -977,26 +1012,32 @@ TosaSerializationOperatorBuilder::build( } std::vector min_val, max_val; + float min_fp, max_fp; + int64_t min_int, max_int; + if (input_element_type.isa()) { - auto min_fp = + min_fp = mlir::cast(min_val_attr).getValue().convertToFloat(); - auto max_fp = + max_fp = mlir::cast(max_val_attr).getValue().convertToFloat(); - TosaSerializationHandler::ConvertF32toU8({min_fp}, min_val); - TosaSerializationHandler::ConvertF32toU8({max_fp}, max_val); + min_int = max_int = 0; } else { - int32_t min_int = 0; - int32_t max_int = 0; - if (input_element_type.isUnsignedInteger()) { - min_int = mlir::cast(min_val_attr).getUInt(); - max_int = mlir::cast(max_val_attr).getUInt(); - } else { - assert(input_element_type.isa()); - min_int = mlir::cast(min_val_attr).getInt(); - max_int = mlir::cast(max_val_attr).getInt(); - } - TosaSerializationHandler::ConvertI32toU8({min_int}, min_val); - TosaSerializationHandler::ConvertI32toU8({max_int}, max_val); + assert(input_element_type.isa()); + min_int = mlir::cast(min_val_attr).getInt(); + max_int = mlir::cast(max_val_attr).getInt(); + min_fp = max_fp = 0.f; + } + + if (GetU8DataFromIntOrFloatValue(min_int, min_fp, input_element_type, min_val) + .failed()) { + op.emitOpError("Failed to serialize min value"); + return nullptr; + } + + if (GetU8DataFromIntOrFloatValue(max_int, max_fp, input_element_type, max_val) + .failed()) { + op.emitOpError("Failed to serialize max value"); + return nullptr; } std::string input_name = GetTensorName(op.getOperand(0)); @@ -1150,11 +1191,14 @@ TosaSerializationOperatorBuilder::build( std::vector pad_const; mlir::Type input_element_type = llvm::cast(op.getOperand(0).getType()).getElementType(); - if (input_element_type.isa()) { - TosaSerializationHandler::ConvertF32toU8({pad_const_fp}, pad_const); - } else { - TosaSerializationHandler::ConvertI32toU8({pad_const_int}, pad_const); + + if (GetU8DataFromIntOrFloatValue(pad_const_int, pad_const_fp, + input_element_type, pad_const) + .failed()) { + op.emitOpError("Failed to serialize pad_const value"); + return nullptr; } + TosaPadAttribute attribute(pad_const); TosaSerializationOperator *tyop = new TosaSerializationOperator( @@ -1806,11 +1850,11 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( // zeros mlir::Attribute initial_value = op->getAttr("initial_value"); std::vector u8_data; - DType element_type = Type2DType(tensor_type.getElementType()); if (initial_value) { if (initial_value.isa()) { if (op_builder - .GetDataFromAttribute(*op, initial_value, element_type, u8_data) + .GetDataFromAttribute(*op, initial_value, + tensor_type.getElementType(), u8_data) .failed()) { llvm::errs() << "ERROR: GetDataFromAttribute() fails when building " "initial_value of variable tensor\n"; diff --git a/third_party/serialization_lib b/third_party/serialization_lib index ad78daa..57d7818 160000 --- a/third_party/serialization_lib +++ b/third_party/serialization_lib @@ -1 +1 @@ -Subproject commit ad78daaf0fa1e41742cbed314459c3dbbb483c20 +Subproject commit 57d781883142db8a45fe98ac1a1dfacc49cba78a -- cgit v1.2.1