diff options
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r-- | reference_model/src/ops/activation_funcs.cc | 96 | ||||
-rw-r--r-- | reference_model/src/ops/activation_funcs.h | 1 | ||||
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 67 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_unary.cc | 7 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 10 |
5 files changed, 128 insertions, 53 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index de7d8be..fc2a9ac 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -31,53 +31,79 @@ int OpClamp<Rank, Dtype>::register_fcn() auto tosa_level = g_func_config.tosa_level; LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK"); - switch (Dtype) + ASSERT_MSG(!(static_cast<GraphNode*>(this))->getOutputs().empty(), + "Must call register_fcn after tensors are linked to nodes"); + + InEigenType min, max; + + // need to use input tensor's serializationDtype to deserialize min/max values + // because Dtype may be FP64 in precise_mode + auto serializationDtype = (static_cast<GraphNode*>(this))->getInputs()[0]->getSerializationDtype(); + switch (DType2RefType(serializationDtype)) { - case TOSA_REF_TYPE_FP16: - case TOSA_REF_TYPE_BF16: - case TOSA_REF_TYPE_FP32: { + case TOSA_REF_TYPE_FP16: { + std::vector<half_float::half> min_float_data, max_float_data; + TosaSerializationHandler::ConvertU8toF16(attribute->min_val(), /* size = */ 1, min_float_data); + TosaSerializationHandler::ConvertU8toF16(attribute->max_val(), /* size = */ 1, max_float_data); + min = (InEigenType)min_float_data[0]; + max = (InEigenType)max_float_data[0]; + } + break; + case TOSA_REF_TYPE_BF16: { std::vector<float> min_float_data, max_float_data; - TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data); - TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data); - InEigenType min = (InEigenType)min_float_data[0]; - InEigenType max = (InEigenType)max_float_data[0]; - ERROR_IF(max < min, "OpClamp: max smaller than min"); - - this->fcn = [min, max](InEigenType a) -> OutEigenType { - return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a); - }; + TosaSerializationHandler::ConvertU8toBF16(attribute->min_val(), /* size = */ 1, min_float_data); + TosaSerializationHandler::ConvertU8toBF16(attribute->max_val(), /* size = */ 1, max_float_data); + min = (InEigenType)min_float_data[0]; + max = (InEigenType)max_float_data[0]; } break; - case TOSA_REF_TYPE_FP64: { + case TOSA_REF_TYPE_FP32: { std::vector<float> min_float_data, max_float_data; TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data); TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data); - InEigenType min = (InEigenType)min_float_data[0]; - InEigenType max = (InEigenType)max_float_data[0]; - ERROR_IF(max < min, "OpClamp: max smaller than min"); - - this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); }; + min = (InEigenType)min_float_data[0]; + max = (InEigenType)max_float_data[0]; } break; case TOSA_REF_TYPE_INT8: { - std::vector<int32_t> min_int_data, max_int_data; - TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data); - TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data); - int8_t min = (int8_t)min_int_data[0]; - int8_t max = (int8_t)max_int_data[0]; - - ERROR_IF(max < min, "OpClamp: max smaller than min"); - this->fcn = [min, max](int8_t a) -> int8_t { return a <= min ? min : a >= max ? max : a; }; + std::vector<int8_t> min_int_data, max_int_data; + TosaSerializationHandler::ConvertU8toI8(attribute->min_val(), /* size = */ 1, min_int_data); + TosaSerializationHandler::ConvertU8toI8(attribute->max_val(), /* size = */ 1, max_int_data); + min = (InEigenType)min_int_data[0]; + max = (InEigenType)max_int_data[0]; + } + break; + case TOSA_REF_TYPE_INT16: { + std::vector<int16_t> min_int_data, max_int_data; + TosaSerializationHandler::ConvertU8toI16(attribute->min_val(), /* size = */ 1, min_int_data); + TosaSerializationHandler::ConvertU8toI16(attribute->max_val(), /* size = */ 1, max_int_data); + min = (InEigenType)min_int_data[0]; + max = (InEigenType)max_int_data[0]; } + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + ERROR_IF(max < min, "OpClamp: max smaller than min"); + + // evaluation function is still based on Dtype + switch (Dtype) + { + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: { + // apply fpTrunc<Dtype> after min/max + this->fcn = [min, max](InEigenType a) -> OutEigenType { + return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a); + }; + } + break; + case TOSA_REF_TYPE_FP64: + case TOSA_REF_TYPE_INT8: case TOSA_REF_TYPE_INT16: { - std::vector<int32_t> min_int_data, max_int_data; - TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data); - TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data); - int16_t min = (int16_t)min_int_data[0]; - int16_t max = (int16_t)max_int_data[0]; - - ERROR_IF(max < min, "OpClamp: max smaller than min"); - this->fcn = [min, max](int16_t a) -> int16_t { return a <= min ? min : a >= max ? max : a; }; + // simply min/max + this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); }; } break; default: diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h index 1696668..055642a 100644 --- a/reference_model/src/ops/activation_funcs.h +++ b/reference_model/src/ops/activation_funcs.h @@ -32,7 +32,6 @@ public: : UnaryNode<Rank, Dtype>(sgt_, Op_CLAMP, id_) { INIT_ATTRIBUTE(Clamp); - register_fcn(); } virtual ~OpClamp(); static constexpr int32_t QMin = GetQMin<Dtype>::value; diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index e264284..6664ec3 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -171,11 +171,31 @@ int OpPad<Rank, Dtype>::eval() { InEigenType pad_value = 0; - switch (Dtype) - { - case TOSA_REF_TYPE_BOOL: - case TOSA_REF_TYPE_INT8: - case TOSA_REF_TYPE_INT16: + // need to use input tensor's serializationDtype to deserialize pad_const + // because Dtype may be FP64 in precise_mode + switch (DType2RefType(inputs[0]->getSerializationDtype())) + { + case TOSA_REF_TYPE_BOOL: { + std::vector<bool> bool_data; + TosaSerializationHandler::ConvertU8toBool(attribute->pad_const(), + /* size = */ 1, bool_data); + pad_value = (InEigenType)bool_data[0]; + break; + } + case TOSA_REF_TYPE_INT8: { + std::vector<int8_t> int8_data; + TosaSerializationHandler::ConvertU8toI8(attribute->pad_const(), + /* size = */ 1, int8_data); + pad_value = (InEigenType)int8_data[0]; + break; + } + case TOSA_REF_TYPE_INT16: { + std::vector<int16_t> int16_data; + TosaSerializationHandler::ConvertU8toI16(attribute->pad_const(), + /* size = */ 1, int16_data); + pad_value = (InEigenType)int16_data[0]; + break; + } case TOSA_REF_TYPE_INT32: { std::vector<int32_t> int32_data; TosaSerializationHandler::ConvertU8toI32(attribute->pad_const(), @@ -183,15 +203,38 @@ int OpPad<Rank, Dtype>::eval() pad_value = (InEigenType)int32_data[0]; break; } - case TOSA_REF_TYPE_FP16: - case TOSA_REF_TYPE_BF16: - case TOSA_REF_TYPE_FP32: - case TOSA_REF_TYPE_FP64: - case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP16: { + std::vector<half_float::half> f16_data; + TosaSerializationHandler::ConvertU8toF16(attribute->pad_const(), + /* size = */ 1, f16_data); + pad_value = (InEigenType)f16_data[0]; + break; + } + case TOSA_REF_TYPE_BF16: { + std::vector<float> f32_data; + TosaSerializationHandler::ConvertU8toBF16(attribute->pad_const(), + /* size = */ 1, f32_data); + pad_value = (InEigenType)f32_data[0]; + break; + } + case TOSA_REF_TYPE_FP32: { + std::vector<float> f32_data; + TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(), + /* size = */ 1, f32_data); + pad_value = (InEigenType)f32_data[0]; + break; + } + case TOSA_REF_TYPE_FP8E4M3: { + std::vector<float> f32_data; + TosaSerializationHandler::ConvertU8toFP8E4M3(attribute->pad_const(), + /* size = */ 1, f32_data); + pad_value = (InEigenType)f32_data[0]; + break; + } case TOSA_REF_TYPE_FP8E5M2: { std::vector<float> float_data; - TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(), - /* size = */ 1, float_data); + TosaSerializationHandler::ConvertU8toFP8E5M2(attribute->pad_const(), + /* size = */ 1, float_data); pad_value = (InEigenType)float_data[0]; break; } diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index dd9ea5a..310a174 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -66,6 +66,13 @@ int UnaryNode<Rank, Dtype>::checkTensorAttributes() template <int Rank, TOSA_REF_TYPE Dtype> int UnaryNode<Rank, Dtype>::eval() { + // call register_fcn() here to ensure inputs/outputs have been connected + // to the node by the time register_fcn() is called for Clamp Operator + if (register_fcn()) + { + return 1; + } + this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn); return GraphNode::eval(); diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 7bca697..40e6c64 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -25,11 +25,11 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -using fp16 = tosa::reference::internal::float_t<int16_t, 5, true, true, true>; -using bf16 = tosa::reference::internal::float_t<int16_t, 8, true, true, true>; -using fp32 = tosa::reference::internal::float_t<int32_t, 8, true, true, true>; -using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>; -using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>; +using fp16 = tosa::float_t<int16_t, 5, true, true, true>; +using bf16 = tosa::float_t<int16_t, 8, true, true, true>; +using fp32 = tosa::float_t<int32_t, 8, true, true, true>; +using fp8e4m3 = tosa::float_t<int8_t, 4, true, true, false>; +using fp8e5m2 = tosa::float_t<int8_t, 5, true, true, true>; template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) |