aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/data_layout.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r--reference_model/src/ops/data_layout.cc67
1 files changed, 55 insertions, 12 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index e264284..6664ec3 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -171,11 +171,31 @@ int OpPad<Rank, Dtype>::eval()
{
InEigenType pad_value = 0;
- switch (Dtype)
- {
- case TOSA_REF_TYPE_BOOL:
- case TOSA_REF_TYPE_INT8:
- case TOSA_REF_TYPE_INT16:
+ // need to use input tensor's serializationDtype to deserialize pad_const
+ // because Dtype may be FP64 in precise_mode
+ switch (DType2RefType(inputs[0]->getSerializationDtype()))
+ {
+ case TOSA_REF_TYPE_BOOL: {
+ std::vector<bool> bool_data;
+ TosaSerializationHandler::ConvertU8toBool(attribute->pad_const(),
+ /* size = */ 1, bool_data);
+ pad_value = (InEigenType)bool_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_INT8: {
+ std::vector<int8_t> int8_data;
+ TosaSerializationHandler::ConvertU8toI8(attribute->pad_const(),
+ /* size = */ 1, int8_data);
+ pad_value = (InEigenType)int8_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_INT16: {
+ std::vector<int16_t> int16_data;
+ TosaSerializationHandler::ConvertU8toI16(attribute->pad_const(),
+ /* size = */ 1, int16_data);
+ pad_value = (InEigenType)int16_data[0];
+ break;
+ }
case TOSA_REF_TYPE_INT32: {
std::vector<int32_t> int32_data;
TosaSerializationHandler::ConvertU8toI32(attribute->pad_const(),
@@ -183,15 +203,38 @@ int OpPad<Rank, Dtype>::eval()
pad_value = (InEigenType)int32_data[0];
break;
}
- case TOSA_REF_TYPE_FP16:
- case TOSA_REF_TYPE_BF16:
- case TOSA_REF_TYPE_FP32:
- case TOSA_REF_TYPE_FP64:
- case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP16: {
+ std::vector<half_float::half> f16_data;
+ TosaSerializationHandler::ConvertU8toF16(attribute->pad_const(),
+ /* size = */ 1, f16_data);
+ pad_value = (InEigenType)f16_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_BF16: {
+ std::vector<float> f32_data;
+ TosaSerializationHandler::ConvertU8toBF16(attribute->pad_const(),
+ /* size = */ 1, f32_data);
+ pad_value = (InEigenType)f32_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_FP32: {
+ std::vector<float> f32_data;
+ TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(),
+ /* size = */ 1, f32_data);
+ pad_value = (InEigenType)f32_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_FP8E4M3: {
+ std::vector<float> f32_data;
+ TosaSerializationHandler::ConvertU8toFP8E4M3(attribute->pad_const(),
+ /* size = */ 1, f32_data);
+ pad_value = (InEigenType)f32_data[0];
+ break;
+ }
case TOSA_REF_TYPE_FP8E5M2: {
std::vector<float> float_data;
- TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(),
- /* size = */ 1, float_data);
+ TosaSerializationHandler::ConvertU8toFP8E5M2(attribute->pad_const(),
+ /* size = */ 1, float_data);
pad_value = (InEigenType)float_data[0];
break;
}