aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-08 22:19:41 +0000
committerTai Ly <tai.ly@arm.com>2024-03-17 19:56:21 -0700
commit60dc48c4ddf30f2a76d4cfcf1b40ca57b6f3bf95 (patch)
treee3d229a2d596e1a0788dfd75d77b996263055496 /reference_model
parente67115ef82bcba0718dcbd75cc8411985001b7cc (diff)
downloadreference_model-60dc48c4ddf30f2a76d4cfcf1b40ca57b6f3bf95.tar.gz
[ref model] Change Clamp and Pad attribute fields
This implements changes due to ClampAttribute and PadAttribute field changes. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ide01e2a27fe3c1ea7794e7a4b6780b7eae436caf
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/src/ops/activation_funcs.cc29
-rw-r--r--reference_model/src/ops/data_layout.cc16
2 files changed, 33 insertions, 12 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 1f4c3b3..de7d8be 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -17,6 +17,7 @@
#include "arith_util.h"
#include "quant_util.h"
#include "template_types.h"
+#include "tosa_serialization_handler.h"
#include <cmath>
using namespace TosaReference;
@@ -35,8 +36,11 @@ int OpClamp<Rank, Dtype>::register_fcn()
case TOSA_REF_TYPE_FP16:
case TOSA_REF_TYPE_BF16:
case TOSA_REF_TYPE_FP32: {
- InEigenType min = (InEigenType)attribute->min_fp();
- InEigenType max = (InEigenType)attribute->max_fp();
+ std::vector<float> min_float_data, max_float_data;
+ TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data);
+ TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data);
+ InEigenType min = (InEigenType)min_float_data[0];
+ InEigenType max = (InEigenType)max_float_data[0];
ERROR_IF(max < min, "OpClamp: max smaller than min");
this->fcn = [min, max](InEigenType a) -> OutEigenType {
@@ -45,23 +49,32 @@ int OpClamp<Rank, Dtype>::register_fcn()
}
break;
case TOSA_REF_TYPE_FP64: {
- InEigenType min = (InEigenType)attribute->min_fp();
- InEigenType max = (InEigenType)attribute->max_fp();
+ std::vector<float> min_float_data, max_float_data;
+ TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data);
+ TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data);
+ InEigenType min = (InEigenType)min_float_data[0];
+ InEigenType max = (InEigenType)max_float_data[0];
ERROR_IF(max < min, "OpClamp: max smaller than min");
this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); };
}
break;
case TOSA_REF_TYPE_INT8: {
- int8_t min = (int8_t)attribute->min_int();
- int8_t max = (int8_t)attribute->max_int();
+ std::vector<int32_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data);
+ TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data);
+ int8_t min = (int8_t)min_int_data[0];
+ int8_t max = (int8_t)max_int_data[0];
ERROR_IF(max < min, "OpClamp: max smaller than min");
this->fcn = [min, max](int8_t a) -> int8_t { return a <= min ? min : a >= max ? max : a; };
}
case TOSA_REF_TYPE_INT16: {
- int16_t min = (int16_t)attribute->min_int();
- int16_t max = (int16_t)attribute->max_int();
+ std::vector<int32_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data);
+ TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data);
+ int16_t min = (int16_t)min_int_data[0];
+ int16_t max = (int16_t)max_int_data[0];
ERROR_IF(max < min, "OpClamp: max smaller than min");
this->fcn = [min, max](int16_t a) -> int16_t { return a <= min ? min : a >= max ? max : a; };
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index b6ad704..4c17e78 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -176,17 +176,25 @@ int OpPad<Rank, Dtype>::eval()
case TOSA_REF_TYPE_BOOL:
case TOSA_REF_TYPE_INT8:
case TOSA_REF_TYPE_INT16:
- case TOSA_REF_TYPE_INT32:
- pad_value = (InEigenType)attribute->pad_const_int();
+ case TOSA_REF_TYPE_INT32: {
+ std::vector<int32_t> int32_data;
+ TosaSerializationHandler::ConvertU8toI32(attribute->pad_const(),
+ /* size = */ 1, int32_data);
+ pad_value = (InEigenType)int32_data[0];
break;
+ }
case TOSA_REF_TYPE_FP16:
case TOSA_REF_TYPE_BF16:
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_FP64:
case TOSA_REF_TYPE_FP8E4M3:
- case TOSA_REF_TYPE_FP8E5M2:
- pad_value = (InEigenType)attribute->pad_const_fp();
+ case TOSA_REF_TYPE_FP8E5M2: {
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(),
+ /* size = */ 1, float_data);
+ pad_value = (InEigenType)float_data[0];
break;
+ }
default:
ASSERT_MSG(false, "TOSA_REF_TYPE %s is not supported.", EnumNameTOSAREFTYPE(Dtype));
break;