diff options
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 67 |
1 files changed, 55 insertions, 12 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index e264284..6664ec3 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -171,11 +171,31 @@ int OpPad<Rank, Dtype>::eval() { InEigenType pad_value = 0; - switch (Dtype) - { - case TOSA_REF_TYPE_BOOL: - case TOSA_REF_TYPE_INT8: - case TOSA_REF_TYPE_INT16: + // need to use input tensor's serializationDtype to deserialize pad_const + // because Dtype may be FP64 in precise_mode + switch (DType2RefType(inputs[0]->getSerializationDtype())) + { + case TOSA_REF_TYPE_BOOL: { + std::vector<bool> bool_data; + TosaSerializationHandler::ConvertU8toBool(attribute->pad_const(), + /* size = */ 1, bool_data); + pad_value = (InEigenType)bool_data[0]; + break; + } + case TOSA_REF_TYPE_INT8: { + std::vector<int8_t> int8_data; + TosaSerializationHandler::ConvertU8toI8(attribute->pad_const(), + /* size = */ 1, int8_data); + pad_value = (InEigenType)int8_data[0]; + break; + } + case TOSA_REF_TYPE_INT16: { + std::vector<int16_t> int16_data; + TosaSerializationHandler::ConvertU8toI16(attribute->pad_const(), + /* size = */ 1, int16_data); + pad_value = (InEigenType)int16_data[0]; + break; + } case TOSA_REF_TYPE_INT32: { std::vector<int32_t> int32_data; TosaSerializationHandler::ConvertU8toI32(attribute->pad_const(), @@ -183,15 +203,38 @@ int OpPad<Rank, Dtype>::eval() pad_value = (InEigenType)int32_data[0]; break; } - case TOSA_REF_TYPE_FP16: - case TOSA_REF_TYPE_BF16: - case TOSA_REF_TYPE_FP32: - case TOSA_REF_TYPE_FP64: - case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP16: { + std::vector<half_float::half> f16_data; + TosaSerializationHandler::ConvertU8toF16(attribute->pad_const(), + /* size = */ 1, f16_data); + pad_value = (InEigenType)f16_data[0]; + break; + } + case TOSA_REF_TYPE_BF16: { + std::vector<float> f32_data; + TosaSerializationHandler::ConvertU8toBF16(attribute->pad_const(), + /* size = */ 1, f32_data); + pad_value = (InEigenType)f32_data[0]; + break; + } + case TOSA_REF_TYPE_FP32: { + std::vector<float> f32_data; + TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(), + /* size = */ 1, f32_data); + pad_value = (InEigenType)f32_data[0]; + break; + } + case TOSA_REF_TYPE_FP8E4M3: { + std::vector<float> f32_data; + TosaSerializationHandler::ConvertU8toFP8E4M3(attribute->pad_const(), + /* size = */ 1, f32_data); + pad_value = (InEigenType)f32_data[0]; + break; + } case TOSA_REF_TYPE_FP8E5M2: { std::vector<float> float_data; - TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(), - /* size = */ 1, float_data); + TosaSerializationHandler::ConvertU8toFP8E5M2(attribute->pad_const(), + /* size = */ 1, float_data); pad_value = (InEigenType)float_data[0]; break; } |