about | summary | refs | log | tree | commit | diff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- src/TosaSerialize.cpp 252
1 file changed, 55 insertions, 197 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 04709b7..fc6655b 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -77,6 +77,10 @@ static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
static DType Type2DType(mlir::Type element_type) {
if (element_type.isF64() || element_type.isF32()) {
return DType_FP32;
+ } else if (element_type.isFloat8E5M2()) {
+ return DType_FP8E5M2;
+ } else if (element_type.isFloat8E4M3FN()) {
+ return DType_FP8E4M3;
} else if (element_type.isF16()) {
return DType_FP16;
} else if (element_type.isBF16()) {
@@ -101,29 +105,6 @@ static DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
-// Returns number of bits TOSA flatbuffer store in tensor raw data array
-uint64_t GetDTypeSize(DType dtype) {
- switch (dtype) {
- case DType_INT4:
- return 4;
- case DType_BOOL:
- case DType_UINT8:
- case DType_INT8:
- return 8;
- case DType_INT16:
- return 16;
- case DType_FP32:
- case DType_INT32:
- return 32;
- case DType_INT48:
- return 48;
- default:
- llvm::errs() << "WARNING: unsupported dtype " << EnumNamesDType()[dtype]
- << "\n";
- return 1;
- }
-}
-
static DType Type2PoolAccumDType(mlir::Type element_type) {
// def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
if (element_type.isF32()) {
@@ -168,7 +149,8 @@ public:
TosaSerializationOperator *build(mlir::Operation &op) const;
TosaSerializationHandler *GetTsh() const;
TosaSerializationRegionBuilder *GetRegionBuilder() const;
- mlir::LogicalResult GetDataFromAttribute(mlir::Attribute &attr, DType dtype,
+ mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op,
+ mlir::Attribute &attr, DType dtype,
std::vector<uint8_t> &u8_data) const;
private:
@@ -319,24 +301,39 @@ std::string TosaSerializationOperatorBuilder::GetVariableTensorName(
}
mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
- mlir::Attribute &attr, DType type, std::vector<uint8_t> &u8_data) const {
+ mlir::Operation &op, mlir::Attribute &attr, DType type,
+ std::vector<uint8_t> &u8_data) const {
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
- if (type == DType_FP32) {
+
+ switch (type) {
+ case DType_FP32:
+ case DType_BF16:
+ case DType_FP16:
+ case DType_FP8E4M3:
+ case DType_FP8E5M2: {
std::vector<float> data;
auto val_attr = attr.dyn_cast<mlir::FloatAttr>();
if (dense_attr) {
- for (auto val : dense_attr.getValues<float>()) {
- data.push_back(val);
+ for (auto val : dense_attr.getValues<mlir::APFloat>()) {
+ data.push_back(val.convertToFloat());
}
} else if (val_attr) {
data.push_back((float)val_attr.getValueAsDouble());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- } else if (type == DType_INT8) {
+
+ if (type == DType_FP16) {
+ TosaSerializationHandler::ConvertF16toU8(data, u8_data);
+ } else {
+ // for all other floating types, store as F32 values
+ TosaSerializationHandler::ConvertF32toU8(data, u8_data);
+ }
+ break;
+ }
+ case DType_INT8: {
std::vector<int8_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
@@ -347,11 +344,13 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI8toU8(data, u8_data);
- } else if (type == DType_INT16) {
+ break;
+ }
+ case DType_INT16: {
std::vector<int16_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
@@ -362,11 +361,13 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI16toU8(data, u8_data);
- } else if (type == DType_INT32) {
+ break;
+ }
+ case DType_INT32: {
std::vector<int32_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
@@ -377,28 +378,32 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI32toU8(data, u8_data);
- } else if (type == DType_INT48) {
+ break;
+ }
+ case DType_INT48: {
std::vector<int64_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
if (dense_attr) {
- for (auto val : dense_attr.getValues<int64_t>()) {
+ for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
+ uint64_t val = valueIt.getLimitedValue();
data.push_back(val);
}
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI48toU8(data, u8_data);
- } else if (type == DType_BOOL) {
+ break;
+ }
+ case DType_BOOL: {
std::vector<bool> data;
-
auto val_attr = attr.dyn_cast<mlir::BoolAttr>();
if (dense_attr) {
@@ -408,15 +413,18 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getValue());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
- } else {
- llvm::errs() << "Unknown element type of const attribute\n";
+ break;
+ }
+ default: {
+ op.emitOpError("Unknown element type of const attribute");
return mlir::failure();
}
+ }
return mlir::success();
}
@@ -669,19 +677,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
return nullptr;
}
-#if 0
- // Gracefully handle constants of "constant unit" type which have no value
- // by creating a numpy value of 0.
- auto unit_val = op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::UnitAttr>();
-
- if (unit_val)
- {
- std::vector<float> data = { 0.0 };
- type = DType_FP32;
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- }
-#endif
-
// Update tensor.data array with Const value attribute
mlir::Attribute value_attr = op.getAttr("value");
if (!value_attr) {
@@ -689,139 +684,10 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
return nullptr;
}
std::vector<uint8_t> u8_data;
-
+ mlir::Attribute attr = op.getAttr(llvm::StringRef("value"));
DType type = ts->GetDtype();
- if (type == DType_FP32) {
- std::vector<float> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<mlir::APFloat>()) {
- data.push_back(val.convertToFloat());
- }
- } else if (val_attr) {
- data.push_back((float)val_attr.getValueAsDouble());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- } else if (type == DType_FP16) {
- std::vector<float> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<mlir::APFloat>()) {
- data.push_back(val.convertToFloat());
- }
- } else if (val_attr) {
- data.push_back((float)val_attr.getValueAsDouble());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertF16toU8(data, u8_data);
- } else if (type == DType_INT8) {
- std::vector<int8_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int8_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI8toU8(data, u8_data);
- } else if (type == DType_INT16) {
- std::vector<int16_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int16_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI16toU8(data, u8_data);
- } else if (type == DType_INT32) {
- std::vector<int32_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int32_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI32toU8(data, u8_data);
- } else if (type == DType_INT48) {
- std::vector<int64_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
- uint64_t val = valueIt.getLimitedValue();
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI48toU8(data, u8_data);
- } else if (type == DType_BOOL) {
- std::vector<bool> data;
-
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::BoolAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<bool>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getValue());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
-
- TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
- } else {
- op.emitOpError("Unknown element type of const attribute");
+ if (GetDataFromAttribute(op, attr, type, u8_data).failed()) {
return nullptr;
}
@@ -1898,7 +1764,7 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
if (initial_value) {
if (initial_value.isa<mlir::DenseElementsAttr>()) {
if (op_builder
- .GetDataFromAttribute(initial_value, element_type, u8_data)
+ .GetDataFromAttribute(*op, initial_value, element_type, u8_data)
.failed()) {
llvm::errs() << "ERROR: GetDataFromAttribute() fails when building "
"initial_value of variable tensor\n";
@@ -1909,14 +1775,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
return mlir::failure();
}
} else {
- uint64_t num_elements = 1;
- for (int64_t dim : tensor_type.getShape()) {
- num_elements *= dim;
- }
- uint64_t num_bits = num_elements * GetDTypeSize(element_type);
- uint64_t num_bytes =
- (num_bits % 8 == 0) ? (num_bits / 8) : (num_bits / 8) + 1;
- // std::fill_n(u8_data.begin(), num_bytes, 0);
TosaSerializationHandler::ForceAlignTensorData(u8_data);
}
ser_tensor->SetData(u8_data);