aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r--reference_model/src/ops/activation_funcs.cc96
-rw-r--r--reference_model/src/ops/activation_funcs.h1
-rw-r--r--reference_model/src/ops/data_layout.cc67
-rw-r--r--reference_model/src/ops/ewise_unary.cc7
-rw-r--r--reference_model/src/ops/type_conversion.cc10
5 files changed, 128 insertions, 53 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index de7d8be..fc2a9ac 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -31,53 +31,79 @@ int OpClamp<Rank, Dtype>::register_fcn()
auto tosa_level = g_func_config.tosa_level;
LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
- switch (Dtype)
+ ASSERT_MSG(!(static_cast<GraphNode*>(this))->getOutputs().empty(),
+ "Must call register_fcn after tensors are linked to nodes");
+
+ InEigenType min, max;
+
+ // need to use input tensor's serializationDtype to deserialize min/max values
+ // because Dtype may be FP64 in precise_mode
+ auto serializationDtype = (static_cast<GraphNode*>(this))->getInputs()[0]->getSerializationDtype();
+ switch (DType2RefType(serializationDtype))
{
- case TOSA_REF_TYPE_FP16:
- case TOSA_REF_TYPE_BF16:
- case TOSA_REF_TYPE_FP32: {
+ case TOSA_REF_TYPE_FP16: {
+ std::vector<half_float::half> min_float_data, max_float_data;
+ TosaSerializationHandler::ConvertU8toF16(attribute->min_val(), /* size = */ 1, min_float_data);
+ TosaSerializationHandler::ConvertU8toF16(attribute->max_val(), /* size = */ 1, max_float_data);
+ min = (InEigenType)min_float_data[0];
+ max = (InEigenType)max_float_data[0];
+ }
+ break;
+ case TOSA_REF_TYPE_BF16: {
std::vector<float> min_float_data, max_float_data;
- TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data);
- TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data);
- InEigenType min = (InEigenType)min_float_data[0];
- InEigenType max = (InEigenType)max_float_data[0];
- ERROR_IF(max < min, "OpClamp: max smaller than min");
-
- this->fcn = [min, max](InEigenType a) -> OutEigenType {
- return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a);
- };
+ TosaSerializationHandler::ConvertU8toBF16(attribute->min_val(), /* size = */ 1, min_float_data);
+ TosaSerializationHandler::ConvertU8toBF16(attribute->max_val(), /* size = */ 1, max_float_data);
+ min = (InEigenType)min_float_data[0];
+ max = (InEigenType)max_float_data[0];
}
break;
- case TOSA_REF_TYPE_FP64: {
+ case TOSA_REF_TYPE_FP32: {
std::vector<float> min_float_data, max_float_data;
TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data);
TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data);
- InEigenType min = (InEigenType)min_float_data[0];
- InEigenType max = (InEigenType)max_float_data[0];
- ERROR_IF(max < min, "OpClamp: max smaller than min");
-
- this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); };
+ min = (InEigenType)min_float_data[0];
+ max = (InEigenType)max_float_data[0];
}
break;
case TOSA_REF_TYPE_INT8: {
- std::vector<int32_t> min_int_data, max_int_data;
- TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data);
- TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data);
- int8_t min = (int8_t)min_int_data[0];
- int8_t max = (int8_t)max_int_data[0];
-
- ERROR_IF(max < min, "OpClamp: max smaller than min");
- this->fcn = [min, max](int8_t a) -> int8_t { return a <= min ? min : a >= max ? max : a; };
+ std::vector<int8_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI8(attribute->min_val(), /* size = */ 1, min_int_data);
+ TosaSerializationHandler::ConvertU8toI8(attribute->max_val(), /* size = */ 1, max_int_data);
+ min = (InEigenType)min_int_data[0];
+ max = (InEigenType)max_int_data[0];
+ }
+ break;
+ case TOSA_REF_TYPE_INT16: {
+ std::vector<int16_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI16(attribute->min_val(), /* size = */ 1, min_int_data);
+ TosaSerializationHandler::ConvertU8toI16(attribute->max_val(), /* size = */ 1, max_int_data);
+ min = (InEigenType)min_int_data[0];
+ max = (InEigenType)max_int_data[0];
}
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ ERROR_IF(max < min, "OpClamp: max smaller than min");
+
+ // evaluation function is still based on Dtype
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32: {
+ // apply fpTrunc<Dtype> after min/max
+ this->fcn = [min, max](InEigenType a) -> OutEigenType {
+ return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a);
+ };
+ }
+ break;
+ case TOSA_REF_TYPE_FP64:
+ case TOSA_REF_TYPE_INT8:
case TOSA_REF_TYPE_INT16: {
- std::vector<int32_t> min_int_data, max_int_data;
- TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data);
- TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data);
- int16_t min = (int16_t)min_int_data[0];
- int16_t max = (int16_t)max_int_data[0];
-
- ERROR_IF(max < min, "OpClamp: max smaller than min");
- this->fcn = [min, max](int16_t a) -> int16_t { return a <= min ? min : a >= max ? max : a; };
+ // simply min/max
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); };
}
break;
default:
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
index 1696668..055642a 100644
--- a/reference_model/src/ops/activation_funcs.h
+++ b/reference_model/src/ops/activation_funcs.h
@@ -32,7 +32,6 @@ public:
: UnaryNode<Rank, Dtype>(sgt_, Op_CLAMP, id_)
{
INIT_ATTRIBUTE(Clamp);
- register_fcn();
}
virtual ~OpClamp();
static constexpr int32_t QMin = GetQMin<Dtype>::value;
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index e264284..6664ec3 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -171,11 +171,31 @@ int OpPad<Rank, Dtype>::eval()
{
InEigenType pad_value = 0;
- switch (Dtype)
- {
- case TOSA_REF_TYPE_BOOL:
- case TOSA_REF_TYPE_INT8:
- case TOSA_REF_TYPE_INT16:
+ // need to use input tensor's serializationDtype to deserialize pad_const
+ // because Dtype may be FP64 in precise_mode
+ switch (DType2RefType(inputs[0]->getSerializationDtype()))
+ {
+ case TOSA_REF_TYPE_BOOL: {
+ std::vector<bool> bool_data;
+ TosaSerializationHandler::ConvertU8toBool(attribute->pad_const(),
+ /* size = */ 1, bool_data);
+ pad_value = (InEigenType)bool_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_INT8: {
+ std::vector<int8_t> int8_data;
+ TosaSerializationHandler::ConvertU8toI8(attribute->pad_const(),
+ /* size = */ 1, int8_data);
+ pad_value = (InEigenType)int8_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_INT16: {
+ std::vector<int16_t> int16_data;
+ TosaSerializationHandler::ConvertU8toI16(attribute->pad_const(),
+ /* size = */ 1, int16_data);
+ pad_value = (InEigenType)int16_data[0];
+ break;
+ }
case TOSA_REF_TYPE_INT32: {
std::vector<int32_t> int32_data;
TosaSerializationHandler::ConvertU8toI32(attribute->pad_const(),
@@ -183,15 +203,38 @@ int OpPad<Rank, Dtype>::eval()
pad_value = (InEigenType)int32_data[0];
break;
}
- case TOSA_REF_TYPE_FP16:
- case TOSA_REF_TYPE_BF16:
- case TOSA_REF_TYPE_FP32:
- case TOSA_REF_TYPE_FP64:
- case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP16: {
+ std::vector<half_float::half> f16_data;
+ TosaSerializationHandler::ConvertU8toF16(attribute->pad_const(),
+ /* size = */ 1, f16_data);
+ pad_value = (InEigenType)f16_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_BF16: {
+ std::vector<float> f32_data;
+ TosaSerializationHandler::ConvertU8toBF16(attribute->pad_const(),
+ /* size = */ 1, f32_data);
+ pad_value = (InEigenType)f32_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_FP32: {
+ std::vector<float> f32_data;
+ TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(),
+ /* size = */ 1, f32_data);
+ pad_value = (InEigenType)f32_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_FP8E4M3: {
+ std::vector<float> f32_data;
+ TosaSerializationHandler::ConvertU8toFP8E4M3(attribute->pad_const(),
+ /* size = */ 1, f32_data);
+ pad_value = (InEigenType)f32_data[0];
+ break;
+ }
case TOSA_REF_TYPE_FP8E5M2: {
std::vector<float> float_data;
- TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(),
- /* size = */ 1, float_data);
+ TosaSerializationHandler::ConvertU8toFP8E5M2(attribute->pad_const(),
+ /* size = */ 1, float_data);
pad_value = (InEigenType)float_data[0];
break;
}
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index dd9ea5a..310a174 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -66,6 +66,13 @@ int UnaryNode<Rank, Dtype>::checkTensorAttributes()
template <int Rank, TOSA_REF_TYPE Dtype>
int UnaryNode<Rank, Dtype>::eval()
{
+ // call register_fcn() here to ensure inputs/outputs have been connected
+ // to the node by the time register_fcn() is called for Clamp Operator
+ if (register_fcn())
+ {
+ return 1;
+ }
+
this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);
return GraphNode::eval();
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 7bca697..40e6c64 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -25,11 +25,11 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-using fp16 = tosa::reference::internal::float_t<int16_t, 5, true, true, true>;
-using bf16 = tosa::reference::internal::float_t<int16_t, 8, true, true, true>;
-using fp32 = tosa::reference::internal::float_t<int32_t, 8, true, true, true>;
-using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>;
-using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>;
+using fp16 = tosa::float_t<int16_t, 5, true, true, true>;
+using bf16 = tosa::float_t<int16_t, 8, true, true, true>;
+using fp32 = tosa::float_t<int32_t, 8, true, true, true>;
+using fp8e4m3 = tosa::float_t<int8_t, 4, true, true, false>;
+using fp8e5m2 = tosa::float_t<int8_t, 5, true, true, true>;
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)