diff options
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r-- | reference_model/src/ops/activation_funcs.cc | 13 | ||||
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 31 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 60 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 3 | ||||
-rw-r--r-- | reference_model/src/ops/image.cc | 12 | ||||
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 4 | ||||
-rw-r--r-- | reference_model/src/ops/op_factory.h | 2 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 54 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 222 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.h | 68 |
10 files changed, 328 insertions, 141 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index fc2a9ac..caab7e0 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -50,9 +50,18 @@ int OpClamp<Rank, Dtype>::register_fcn() } break; case TOSA_REF_TYPE_BF16: { + std::vector<bf16> bf16_min_float_data, bf16_max_float_data; + TosaSerializationHandler::ConvertU8toBF16(attribute->min_val(), /* size = */ 1, bf16_min_float_data); + TosaSerializationHandler::ConvertU8toBF16(attribute->max_val(), /* size = */ 1, bf16_max_float_data); std::vector<float> min_float_data, max_float_data; - TosaSerializationHandler::ConvertU8toBF16(attribute->min_val(), /* size = */ 1, min_float_data); - TosaSerializationHandler::ConvertU8toBF16(attribute->max_val(), /* size = */ 1, max_float_data); + for (auto f : bf16_min_float_data) + { + min_float_data.push_back(f); + } + for (auto f : bf16_max_float_data) + { + max_float_data.push_back(f); + } min = (InEigenType)min_float_data[0]; max = (InEigenType)max_float_data[0]; } diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 6664ec3..3e3770e 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -211,9 +211,16 @@ int OpPad<Rank, Dtype>::eval() break; } case TOSA_REF_TYPE_BF16: { - std::vector<float> f32_data; + std::vector<bf16> bf16_data; TosaSerializationHandler::ConvertU8toBF16(attribute->pad_const(), - /* size = */ 1, f32_data); + /* size = */ 1, bf16_data); + // Some ops use Eigen APIs for float calculation, so convert bfloat16 + // to float + std::vector<float> f32_data; + for (auto f : bf16_data) + { + f32_data.push_back(static_cast<float>(f)); + } pad_value = (InEigenType)f32_data[0]; break; } @@ -225,17 +232,27 @@ int OpPad<Rank, Dtype>::eval() break; } case TOSA_REF_TYPE_FP8E4M3: { - std::vector<float> f32_data; + std::vector<fp8e4m3> f8_data; TosaSerializationHandler::ConvertU8toFP8E4M3(attribute->pad_const(), - /* size = */ 1, f32_data); + /* size = */ 1, f8_data); + std::vector<float> f32_data; + for (auto f : f8_data) + { + f32_data.push_back(static_cast<float>(f)); + } pad_value = (InEigenType)f32_data[0]; break; } case TOSA_REF_TYPE_FP8E5M2: { - std::vector<float> float_data; + std::vector<fp8e5m2> f8_data; TosaSerializationHandler::ConvertU8toFP8E5M2(attribute->pad_const(), - /* size = */ 1, float_data); - pad_value = (InEigenType)float_data[0]; + /* size = */ 1, f8_data); + std::vector<float> f32_data; + for (auto f : f8_data) + { + f32_data.push_back(static_cast<float>(f)); + } + pad_value = (InEigenType)f32_data[0]; break; } default: diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 8cc1319..bc63535 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -411,6 +411,22 @@ int OpMaximum<Rank, Dtype>::register_fcn() case TOSA_REF_TYPE_BF16: case TOSA_REF_TYPE_FP32: case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { + if (isnan(a)) + { + return a; + } + else if (isnan(b)) + { + return b; + } + else + { + return a > b ? a : b; + } + }; + break; + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; break; @@ -430,6 +446,21 @@ int OpMinimum<Rank, Dtype>::register_fcn() case TOSA_REF_TYPE_BF16: case TOSA_REF_TYPE_FP32: case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { + if (isnan(a)) + { + return a; + } + else if (isnan(b)) + { + return b; + } + else + { + return a < b ? a : b; + } + }; + break; case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; break; @@ -635,18 +666,13 @@ template <int Rank, TOSA_REF_TYPE InDtype> OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_TABLE, id_) { - setRequiredOperands(1, 1); + setRequiredOperands(2, 1); setRequiredRank(0, 6); - - INIT_ATTRIBUTE(Table); } template <int Rank, TOSA_REF_TYPE InDtype> OpTable<Rank, InDtype>::~OpTable() -{ - if (attribute) - delete attribute; -} +{} template <int Rank, TOSA_REF_TYPE InDtype> int OpTable<Rank, InDtype>::checkTensorAttributes() @@ -664,16 +690,12 @@ int OpTable<Rank, InDtype>::checkTensorAttributes() } ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type"); + ERROR_IF(inputs[1]->getDtype() != TableDtype, "OpTable: Unexpected table type"); ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type"); - ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries); - - for (uint32_t i = 0; i < TableNumEntries; i++) - { - table[i] = (TableEigenType)attribute->table()[i]; - } - in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); - out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); ASSERT_MEM(in && out); @@ -683,13 +705,15 @@ int OpTable<Rank, InDtype>::checkTensorAttributes() template <int Rank, TOSA_REF_TYPE InDtype> int OpTable<Rank, InDtype>::eval() { + ERROR_IF(this->table->getTensor().size() != TableNumEntries, "OpTable: table tensor size must be %u", + TableNumEntries); switch (InDtype) { case TOSA_REF_TYPE_INT8: this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); int32_t index = input_truncated - QInMin; - int32_t value = table[index]; + int32_t value = this->table->getTensor()(index); return value; }); @@ -705,8 +729,8 @@ int OpTable<Rank, InDtype>::eval() int32_t frac = (input_truncated)&0x7F; // 7-bit fraction // 3. Add REQUIRE CHECK for extreme large/small slopes - int32_t base = table[index]; - int32_t next = table[index + 1]; + int32_t base = this->table->getTensor()(index); + int32_t next = this->table->getTensor()(index + 1); int32_t slope = next - base; REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(), "OpTable: slope out of int16_t range"); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 7ebd852..54c05e3 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -215,8 +215,7 @@ public: protected: TosaReference::TensorTemplate<TIn>* in; TosaReference::TensorTemplate<TOut>* out; - TosaTableAttribute* attribute; - std::array<TableEigenType, TableNumEntries> table; + TosaReference::TensorTemplate<TTable>* table; }; }; // namespace TosaReference diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 7c480ad..bfb19bd 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -179,7 +179,7 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() const float fy = static_cast<float>(y) / static_cast<float>(scale_y_n); const float fx = static_cast<float>(x) / static_cast<float>(scale_x_n); - if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || + if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(bf16)) || (typeid(resize_t) == typeid(half_float::half))) { dy = (resize_t)(fy - iy); @@ -212,8 +212,7 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() acc += (OutEigenType)v10 * dy * (1.0 - dx); acc += (OutEigenType)v11 * dy * dx; } - else if ((typeid(resize_t) == typeid(Eigen::bfloat16)) || - (typeid(resize_t) == typeid(half_float::half))) + else if ((typeid(resize_t) == typeid(bf16)) || (typeid(resize_t) == typeid(half_float::half))) { resize_t f16_acc; f16_acc = (resize_t)v00 * (resize_t)(1.0 - dy) * (resize_t)(1.0 - dx); @@ -246,11 +245,6 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() } acc = in->getTensor()(b, iy, ix, c); } - if ((typeid(resize_t) == typeid(Eigen::bfloat16))) - { - ASSERT_MSG(checkValidBFloat(acc), - "Resize accumulator float value is not a valid bfloat16 value."); - } out->getTensor()(b, oy, ox, c) = acc; } @@ -263,6 +257,6 @@ DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT8, INT8, int16_t); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT48, int16_t); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT16, int16_t); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half); -DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, Eigen::bfloat16); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, bf16); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP32, FP32, float); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP64, FP64, double); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 74315d7..cb4e76b 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -624,6 +624,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP8E5M2); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); @@ -632,9 +634,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP64); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP64); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3); diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index f1d1680..bb1e690 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -135,7 +135,7 @@ #define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, bf16>(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \ diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index f38f486..28d54fd 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -744,18 +744,24 @@ int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval() weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } - ETensor4<InEigenType> input_padded = input_val.pad(pad); - TBias bias_val = this->bias->getTensor(); if (g_func_config.abs_mode) { // in abs_mode: take abs values of conv operands - input_padded = input_padded.abs(); - weight_val = weight_val.abs(); - bias_val = bias_val.abs(); + input_val = input_val.abs(); + weight_val = weight_val.abs(); + bias_val = bias_val.abs(); + + if (!this->attribute->local_bound()) + { + Eigen::Tensor<InEigenType, 0> input_abs_max = input_val.maximum(); + input_val.setConstant(input_abs_max(0)); + } } + ETensor4<InEigenType> input_padded = input_val.pad(pad); + // extract_image_patches() output [N, KH, KW, H * W, C] // need to transpose to [N, H * W, KH, KW, C] ETensor5<InEigenType> input_extract_patches = @@ -938,18 +944,24 @@ int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::eval() weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } - ETensor5<InEigenType> input_padded = input_val.pad(pad); - TBias bias_val = this->bias->getTensor(); if (g_func_config.abs_mode) { // in abs_mode: take abs values of conv operands - input_padded = input_padded.abs(); - weight_val = weight_val.abs(); - bias_val = bias_val.abs(); + input_val = input_val.abs(); + weight_val = weight_val.abs(); + bias_val = bias_val.abs(); + + if (!this->attribute->local_bound()) + { + Eigen::Tensor<InEigenType, 0> input_abs_max = input_val.maximum(); + input_val.setConstant(input_abs_max(0)); + } } + ETensor5<InEigenType> input_padded = input_val.pad(pad); + // 1. initialize with bias Eigen::array<Eigen::Index, 5> reshape_dim; reshape_dim.fill(1); @@ -1140,18 +1152,24 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval() weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } - ETensor4<InEigenType> input_padded = input_val.pad(pad); - TBias bias_val = this->bias->getTensor(); if (g_func_config.abs_mode) { // in abs_mode: take abs values of conv operands - input_padded = input_padded.abs(); - weight_val = weight_val.abs(); - bias_val = bias_val.abs(); + input_val = input_val.abs(); + weight_val = weight_val.abs(); + bias_val = bias_val.abs(); + + if (!this->attribute->local_bound()) + { + Eigen::Tensor<InEigenType, 0> input_abs_max = input_val.maximum(); + input_val.setConstant(input_abs_max(0)); + } } + ETensor4<InEigenType> input_padded = input_val.pad(pad); + // GEMM doesn't fit well with DepthwiseConv2d // 1. use extract_image_patches() to handle stride/dilation/pad // 2. perform direct convolution @@ -2078,6 +2096,12 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval() input_val = input_val.abs(); weight_val = weight_val.abs(); bias_val = bias_val.abs(); + + if (!this->attribute->local_bound()) + { + Eigen::Tensor<InEigenType, 0> input_abs_max = input_val.maximum(); + input_val.setConstant(input_abs_max(0)); + } } Eigen::array<Eigen::Index, 4> reshape_dim; diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 9719f07..a2e8da4 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -27,7 +27,6 @@ using namespace tosa; using fp16 = ct::cfloat<int16_t, 5, true, true, true>; using bf16 = ct::cfloat<int16_t, 8, true, true, true>; -using fp32 = ct::cfloat<int32_t, 8, true, true, true>; using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>; using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>; @@ -146,6 +145,21 @@ static int64_t zero_extend(int16_t val) uint16_t* rval = reinterpret_cast<uint16_t*>(&val); return static_cast<int64_t>(*rval); } +static int64_t zero_extend(int32_t val) +{ + uint32_t* rval = reinterpret_cast<uint32_t*>(&val); + return static_cast<int64_t>(*rval); +} +static int64_t sign_extend(uint8_t val) +{ + int8_t* rval = reinterpret_cast<int8_t*>(&val); + return static_cast<int64_t>(*rval); +} +static int64_t sign_extend(uint16_t val) +{ + int16_t* rval = reinterpret_cast<int16_t*>(&val); + return static_cast<int64_t>(*rval); +} template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpRescale<Rank, InDtype, OutDtype>::eval() @@ -201,6 +215,70 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() shift.push_back(static_cast<int32_t>(shift_val(i))); } + auto get_input_zp_shifted = [=](InEigenType in_val) -> int64_t { + int64_t input_zp_shifted; + if (input_unsigned) + { + int64_t in_val64; + int64_t in_zp64; + switch (GetNumBits<InDtype>::value) + { + case 8: + in_val64 = zero_extend(static_cast<int8_t>(in_val)); + in_zp64 = zero_extend(static_cast<int8_t>(input_zp)); + break; + case 16: + in_val64 = zero_extend(static_cast<int16_t>(in_val)); + in_zp64 = zero_extend(static_cast<int16_t>(input_zp)); + break; + case 32: + in_val64 = zero_extend(static_cast<int32_t>(in_val)); + in_zp64 = zero_extend(static_cast<int32_t>(input_zp)); + break; + case 48: + in_val64 = static_cast<int64_t>(in_val) & 0x0000FFFFFFFFFFFF; + in_zp64 = static_cast<int64_t>(input_zp) & 0x0000FFFFFFFFFFFF; + break; + default: + in_val64 = static_cast<int64_t>(in_val); + in_zp64 = static_cast<int64_t>(input_zp); + break; + } + input_zp_shifted = in_val64 - in_zp64; + } + else + { + int64_t in_val64; + int64_t in_zp64 = static_cast<int64_t>(input_zp); + switch (GetNumBits<InDtype>::value) + { + case 8: + in_val64 = sign_extend(static_cast<uint8_t>(in_val & 0xFF)); + break; + case 16: + in_val64 = sign_extend(static_cast<uint16_t>(in_val & 0xFFFF)); + break; + case 32: + in_val64 = static_cast<int64_t>(static_cast<int32_t>(in_val & 0xFFFFFFFF)); + break; + case 48: + // sign extend i48 to i64 + in_val64 = static_cast<int64_t>(in_val); + if (in_val64 & 0x800000000000) + { + // in_val contains negative i48, sign extend to i64 + in_val64 |= 0xFFFF000000000000; + } + break; + default: + in_val64 = static_cast<int64_t>(in_val); + break; + } + input_zp_shifted = in_val64 - in_zp64; + } + return input_zp_shifted; + }; + if (per_channel) { ETensor2<InEigenType> curr_channel_slice_prescaled; @@ -218,32 +296,8 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() channel_shift = shift[i]; curr_channel_slice_postscaled = curr_channel_slice_prescaled.unaryExpr([=](InEigenType in_val) -> OutEigenType { - int64_t input_zp_shifted; - if (input_unsigned) - { - int64_t in_val64; - int64_t in_zp64; - switch (GetNumBits<InDtype>::value) - { - case 8: - in_val64 = zero_extend(static_cast<int8_t>(in_val)); - in_zp64 = zero_extend(static_cast<int8_t>(input_zp)); - break; - case 16: - in_val64 = zero_extend(static_cast<int16_t>(in_val)); - in_zp64 = zero_extend(static_cast<int16_t>(input_zp)); - break; - default: - in_val64 = static_cast<int64_t>(in_val); - in_zp64 = static_cast<int64_t>(input_zp); - break; - } - input_zp_shifted = in_val64 - in_zp64; - } - else - { - input_zp_shifted = in_val - input_zp; - } + int64_t input_zp_shifted = get_input_zp_shifted(in_val); + int32_t scaled; if (scale32) scaled = TosaReference::QuantUtil::apply_scale_32(static_cast<int32_t>(input_zp_shifted), @@ -311,32 +365,8 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() try { output_2d = input_reshaped.unaryExpr([=](InEigenType in_val) -> OutEigenType { - int64_t input_zp_shifted; - if (input_unsigned) - { - int64_t in_val64; - int64_t in_zp64; - switch (GetNumBits<InDtype>::value) - { - case 8: - in_val64 = zero_extend(static_cast<int8_t>(in_val)); - in_zp64 = zero_extend(static_cast<int8_t>(input_zp)); - break; - case 16: - in_val64 = zero_extend(static_cast<int16_t>(in_val)); - in_zp64 = zero_extend(static_cast<int16_t>(input_zp)); - break; - default: - in_val64 = static_cast<int64_t>(in_val); - in_zp64 = static_cast<int64_t>(input_zp); - break; - } - input_zp_shifted = in_val64 - in_zp64; - } - else - { - input_zp_shifted = in_val - input_zp; - } + int64_t input_zp_shifted = get_input_zp_shifted(in_val); + int32_t scaled; if (scale32) scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift, @@ -626,6 +656,12 @@ CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32>::CastHelper() fcn = [](InEigenType in) -> OutEigenType { return in; }; } +CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP64>::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to fp64 + fcn = [](float in) -> double { return static_cast<double>(in); }; +} + template <TOSA_REF_TYPE OutDtype> CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>::CastHelper() { @@ -663,12 +699,18 @@ CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32>::CastHelper() fcn = [](InEigenType in) -> OutEigenType { return in; }; } +CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP64>::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to fp64 + fcn = [](float in) -> double { return static_cast<double>(in); }; +} + template <TOSA_REF_TYPE InDtype> CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>::CastHelper() { // Integer data converted to fp8e4m3 (stored as fp32) fcn = [](InEigenType in) -> float { - auto f = static_cast<fp32>(static_cast<fp8e4m3>(float(in))); + auto f = static_cast<fp8e4m3>(float(in)); float out = static_cast<float>(f); return out; }; @@ -677,70 +719,60 @@ CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>::CastHelper() CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3>::CastHelper() { // fp16 data (stored as fp32) converted to fp8e4m3 (stored as fp32) - fcn = [](float in) -> float { - auto f = static_cast<fp32>(static_cast<fp8e4m3>(in)); - float out = static_cast<float>(f); - return out; - }; + fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e4m3>(in)); }; } CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3>::CastHelper() { // bf16 data (stored as fp32) converted to fp8e4m3 (stored as fp32) - fcn = [](float in) -> float { - auto f = static_cast<fp32>(static_cast<fp8e4m3>(in)); - float out = static_cast<float>(f); - return out; - }; + fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e4m3>(in)); }; } CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3>::CastHelper() { // fp32 data converted to fp8e4m3 (stored as fp32) - fcn = [](float in) -> float { - auto f = static_cast<fp32>(static_cast<fp8e4m3>(in)); - float out = static_cast<float>(f); - return out; - }; + fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e4m3>(in)); }; } template <TOSA_REF_TYPE InDtype> CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2>::CastHelper() { // Integer data converted to fp8e5m2 (stored as fp32) - fcn = [](InEigenType in) -> float { - auto f = static_cast<fp32>(static_cast<fp8e5m2>(float(in))); - float out = static_cast<float>(f); - return out; - }; + fcn = [](InEigenType in) -> float { return static_cast<float>(static_cast<fp8e5m2>(in)); }; } CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2>::CastHelper() { // fp16 data (stored as fp32) converted to fp8e5m2 (stored as fp32) - fcn = [](float in) -> float { - auto f = static_cast<fp32>(static_cast<fp8e5m2>(in)); - float out = static_cast<float>(f); - return out; - }; + fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e5m2>(in)); }; } CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2>::CastHelper() { // bf16 data (stored as fp32) converted to fp8e5m2 (stored as fp32) - fcn = [](float in) -> float { - auto f = static_cast<fp32>(static_cast<fp8e5m2>(in)); - float out = static_cast<float>(f); - return out; - }; + fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e5m2>(in)); }; } CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2>::CastHelper() { // fp32 data converted to fp8e5m2 (stored as fp32) - fcn = [](float in) -> float { - auto f = static_cast<fp32>(static_cast<fp8e5m2>(in)); - float out = static_cast<float>(f); + fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e5m2>(in)); }; +} + +CastHelper<TOSA_REF_TYPE_FP64, TOSA_REF_TYPE_FP8E4M3>::CastHelper() +{ + // fp64 data converted to fp8e5m2 (stored as fp32) + fcn = [](double in) -> float { + float out = static_cast<float>(static_cast<fp8e4m3>(in)); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_FP64, TOSA_REF_TYPE_FP8E5M2>::CastHelper() +{ + // fp64 data converted to fp8e5m2 (stored as fp32) + fcn = [](double in) -> float { + float out = static_cast<float>(static_cast<fp8e5m2>(in)); return out; }; } @@ -812,6 +844,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP8E5M2); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); @@ -821,9 +855,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3); @@ -843,7 +879,19 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, UINT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, UINT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, UINT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, UINT16); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 0636357..9bcace4 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -371,6 +371,23 @@ private: FcnType fcn; }; +template <> +class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP64> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + template <TOSA_REF_TYPE OutDtype> class CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype> { @@ -441,6 +458,23 @@ private: FcnType fcn; }; +template <> +class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP64> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + template <TOSA_REF_TYPE InDtype> class CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3> { @@ -577,6 +611,40 @@ private: FcnType fcn; }; +template <> +class CastHelper<TOSA_REF_TYPE_FP64, TOSA_REF_TYPE_FP8E4M3> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP64, TOSA_REF_TYPE_FP8E5M2> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + template <TOSA_REF_TYPE OutDtype> class CastHelper<TOSA_REF_TYPE_FP64, OutDtype> { |