Diffstat (limited to 'reference_model/src/ops')
-rw-r--r--  reference_model/src/ops/activation_funcs.cc |  13
-rw-r--r--  reference_model/src/ops/data_layout.cc      |  31
-rw-r--r--  reference_model/src/ops/ewise_binary.cc     |  60
-rw-r--r--  reference_model/src/ops/ewise_binary.h      |   3
-rw-r--r--  reference_model/src/ops/image.cc            |  12
-rw-r--r--  reference_model/src/ops/op_factory.cc       |   4
-rw-r--r--  reference_model/src/ops/op_factory.h        |   2
-rw-r--r--  reference_model/src/ops/tensor_ops.cc       |  54
-rw-r--r--  reference_model/src/ops/type_conversion.cc  | 222
-rw-r--r--  reference_model/src/ops/type_conversion.h   |  68
10 files changed, 328 insertions(+), 141 deletions(-)
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index fc2a9ac..caab7e0 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -50,9 +50,18 @@ int OpClamp<Rank, Dtype>::register_fcn()
}
break;
case TOSA_REF_TYPE_BF16: {
+ std::vector<bf16> bf16_min_float_data, bf16_max_float_data;
+ TosaSerializationHandler::ConvertU8toBF16(attribute->min_val(), /* size = */ 1, bf16_min_float_data);
+ TosaSerializationHandler::ConvertU8toBF16(attribute->max_val(), /* size = */ 1, bf16_max_float_data);
std::vector<float> min_float_data, max_float_data;
- TosaSerializationHandler::ConvertU8toBF16(attribute->min_val(), /* size = */ 1, min_float_data);
- TosaSerializationHandler::ConvertU8toBF16(attribute->max_val(), /* size = */ 1, max_float_data);
+ for (auto f : bf16_min_float_data)
+ {
+ min_float_data.push_back(f);
+ }
+ for (auto f : bf16_max_float_data)
+ {
+ max_float_data.push_back(f);
+ }
min = (InEigenType)min_float_data[0];
max = (InEigenType)max_float_data[0];
}
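
The BF16 clamp fix above deserializes into the bf16 type first and only then widens element by element to float, since the Eigen kernels operate on float. A minimal sketch of that widening idiom, using a hypothetical Bf16 stand-in for the reference model's ct::cfloat-based bf16 (bfloat16 is the top 16 bits of an IEEE-754 binary32, so widening just appends zero bits):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    struct Bf16
    {
        uint16_t bits;    // top half of a binary32
        explicit operator float() const
        {
            uint32_t u = static_cast<uint32_t>(bits) << 16;
            float f;
            std::memcpy(&f, &u, sizeof f);    // lossless widening
            return f;
        }
    };

    int main()
    {
        std::vector<Bf16> narrow = { { 0x3F80 }, { 0xC000 } };    // 1.0f, -2.0f
        std::vector<float> wide;
        for (auto v : narrow)                      // same loop shape as the patch
            wide.push_back(static_cast<float>(v));
        return (wide[0] == 1.0f && wide[1] == -2.0f) ? 0 : 1;
    }
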
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 6664ec3..3e3770e 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -211,9 +211,16 @@ int OpPad<Rank, Dtype>::eval()
break;
}
case TOSA_REF_TYPE_BF16: {
- std::vector<float> f32_data;
+ std::vector<bf16> bf16_data;
TosaSerializationHandler::ConvertU8toBF16(attribute->pad_const(),
- /* size = */ 1, f32_data);
+ /* size = */ 1, bf16_data);
+ // Some ops use Eigen APIs for float calculation, so convert bfloat16
+ // to float
+ std::vector<float> f32_data;
+ for (auto f : bf16_data)
+ {
+ f32_data.push_back(static_cast<float>(f));
+ }
pad_value = (InEigenType)f32_data[0];
break;
}
@@ -225,17 +232,27 @@ int OpPad<Rank, Dtype>::eval()
break;
}
case TOSA_REF_TYPE_FP8E4M3: {
- std::vector<float> f32_data;
+ std::vector<fp8e4m3> f8_data;
TosaSerializationHandler::ConvertU8toFP8E4M3(attribute->pad_const(),
- /* size = */ 1, f32_data);
+ /* size = */ 1, f8_data);
+ std::vector<float> f32_data;
+ for (auto f : f8_data)
+ {
+ f32_data.push_back(static_cast<float>(f));
+ }
pad_value = (InEigenType)f32_data[0];
break;
}
case TOSA_REF_TYPE_FP8E5M2: {
- std::vector<float> float_data;
+ std::vector<fp8e5m2> f8_data;
TosaSerializationHandler::ConvertU8toFP8E5M2(attribute->pad_const(),
- /* size = */ 1, float_data);
- pad_value = (InEigenType)float_data[0];
+ /* size = */ 1, f8_data);
+ std::vector<float> f32_data;
+ for (auto f : f8_data)
+ {
+ f32_data.push_back(static_cast<float>(f));
+ }
+ pad_value = (InEigenType)f32_data[0];
break;
}
default:
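
The same pattern repeats for the FP8 pad constants: deserialize into the narrow type, then widen to float for the Eigen kernels. For reference, a hedged sketch of what the fp8e4m3-to-float widening computes, assuming the OCP E4M3 encoding (bias 7, no infinities, NaN encoded as S.1111.111); the real conversion lives inside ct::cfloat:

    #include <cmath>
    #include <cstdint>

    float fp8e4m3_to_float(uint8_t b)
    {
        int sign = (b >> 7) & 1;
        int exp  = (b >> 3) & 0xF;
        int man  = b & 0x7;
        float v;
        if (exp == 0xF && man == 0x7)
            v = std::nanf("");                              // the single NaN pattern
        else if (exp == 0)
            v = std::ldexp(static_cast<float>(man), -9);    // subnormal: man/8 * 2^-6
        else
            v = std::ldexp(1.0f + man / 8.0f, exp - 7);     // normal: (1 + man/8) * 2^(exp-7)
        return sign ? -v : v;
    }
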
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 8cc1319..bc63535 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -411,6 +411,22 @@ int OpMaximum<Rank, Dtype>::register_fcn()
case TOSA_REF_TYPE_BF16:
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType {
+ if (isnan(a))
+ {
+ return a;
+ }
+ else if (isnan(b))
+ {
+ return b;
+ }
+ else
+ {
+ return a > b ? a : b;
+ }
+ };
+ break;
+
case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
break;
@@ -430,6 +446,21 @@ int OpMinimum<Rank, Dtype>::register_fcn()
case TOSA_REF_TYPE_BF16:
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType {
+ if (isnan(a))
+ {
+ return a;
+ }
+ else if (isnan(b))
+ {
+ return b;
+ }
+ else
+ {
+ return a < b ? a : b;
+ }
+ };
+ break;
case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
break;
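
Both new lambdas pin down NaN propagation explicitly: a bare `a > b ? a : b` yields whichever operand the comparison happens to favour when a NaN is involved, so without the checks the result would depend on argument order. A self-contained check of the contract the added code implements:

    #include <cassert>
    #include <cmath>

    template <typename T>
    T tosa_maximum(T a, T b)    // mirrors the lambda registered above
    {
        if (std::isnan(a))
            return a;
        if (std::isnan(b))
            return b;
        return a > b ? a : b;
    }

    int main()
    {
        float nan = std::nanf("");
        assert(std::isnan(tosa_maximum(nan, 1.0f)));    // NaN propagates...
        assert(std::isnan(tosa_maximum(1.0f, nan)));    // ...from either operand
        assert(tosa_maximum(1.0f, 2.0f) == 2.0f);
        return 0;
    }
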
@@ -635,18 +666,13 @@ template <int Rank, TOSA_REF_TYPE InDtype>
OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_TABLE, id_)
{
- setRequiredOperands(1, 1);
+ setRequiredOperands(2, 1);
setRequiredRank(0, 6);
-
- INIT_ATTRIBUTE(Table);
}
template <int Rank, TOSA_REF_TYPE InDtype>
OpTable<Rank, InDtype>::~OpTable()
-{
- if (attribute)
- delete attribute;
-}
+{}
template <int Rank, TOSA_REF_TYPE InDtype>
int OpTable<Rank, InDtype>::checkTensorAttributes()
@@ -664,16 +690,12 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
}
ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type");
+ ERROR_IF(inputs[1]->getDtype() != TableDtype, "OpTable: Unexpected table type");
ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type");
- ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries);
-
- for (uint32_t i = 0; i < TableNumEntries; i++)
- {
- table[i] = (TableEigenType)attribute->table()[i];
- }
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
ASSERT_MEM(in && out);
@@ -683,13 +705,15 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
template <int Rank, TOSA_REF_TYPE InDtype>
int OpTable<Rank, InDtype>::eval()
{
+ ERROR_IF(this->table->getTensor().size() != TableNumEntries, "OpTable: table tensor size must be %u",
+ TableNumEntries);
switch (InDtype)
{
case TOSA_REF_TYPE_INT8:
this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
int32_t index = input_truncated - QInMin;
- int32_t value = table[index];
+ int32_t value = this->table->getTensor()(index);
return value;
});
@@ -705,8 +729,8 @@ int OpTable<Rank, InDtype>::eval()
int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
// 3. Add REQUIRE CHECK for extreme large/small slopes
- int32_t base = table[index];
- int32_t next = table[index + 1];
+ int32_t base = this->table->getTensor()(index);
+ int32_t next = this->table->getTensor()(index + 1);
int32_t slope = next - base;
REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(),
"OpTable: slope out of int16_t range");
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 7ebd852..54c05e3 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -215,8 +215,7 @@ public:
protected:
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out;
- TosaTableAttribute* attribute;
- std::array<TableEigenType, TableNumEntries> table;
+ TosaReference::TensorTemplate<TTable>* table;
};
}; // namespace TosaReference
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index 7c480ad..bfb19bd 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -179,7 +179,7 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
const float fy = static_cast<float>(y) / static_cast<float>(scale_y_n);
const float fx = static_cast<float>(x) / static_cast<float>(scale_x_n);
- if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(bf16)) ||
(typeid(resize_t) == typeid(half_float::half)))
{
dy = (resize_t)(fy - iy);
@@ -212,8 +212,7 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
acc += (OutEigenType)v10 * dy * (1.0 - dx);
acc += (OutEigenType)v11 * dy * dx;
}
- else if ((typeid(resize_t) == typeid(Eigen::bfloat16)) ||
- (typeid(resize_t) == typeid(half_float::half)))
+ else if ((typeid(resize_t) == typeid(bf16)) || (typeid(resize_t) == typeid(half_float::half)))
{
resize_t f16_acc;
f16_acc = (resize_t)v00 * (resize_t)(1.0 - dy) * (resize_t)(1.0 - dx);
@@ -246,11 +245,6 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
}
acc = in->getTensor()(b, iy, ix, c);
}
- if ((typeid(resize_t) == typeid(Eigen::bfloat16)))
- {
- ASSERT_MSG(checkValidBFloat(acc),
- "Resize accumulator float value is not a valid bfloat16 value.");
- }
out->getTensor()(b, oy, ox, c) = acc;
}
@@ -263,6 +257,6 @@ DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT8, INT8, int16_t);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT48, int16_t);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT16, int16_t);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half);
-DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, Eigen::bfloat16);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, bf16);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP32, FP32, float);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP64, FP64, double);
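
The accumulation being branched over is a standard bilinear blend; as a standalone float-path sketch, with dy and dx the fractional offsets computed above:

    // Weights (1-dy)(1-dx), (1-dy)dx, dy(1-dx) and dy*dx sum to 1.
    float bilinear(float v00, float v01, float v10, float v11, float dy, float dx)
    {
        return v00 * (1.0f - dy) * (1.0f - dx) + v01 * (1.0f - dy) * dx +
               v10 * dy * (1.0f - dx) + v11 * dy * dx;
    }
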
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 74315d7..cb4e76b 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -624,6 +624,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP8E5M2);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
@@ -632,9 +634,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP64);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP64);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3);
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index f1d1680..bb1e690 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -135,7 +135,7 @@
#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \
if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, bf16>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index f38f486..28d54fd 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -744,18 +744,24 @@ int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
- ETensor4<InEigenType> input_padded = input_val.pad(pad);
-
TBias bias_val = this->bias->getTensor();
if (g_func_config.abs_mode)
{
// in abs_mode: take abs values of conv operands
- input_padded = input_padded.abs();
- weight_val = weight_val.abs();
- bias_val = bias_val.abs();
+ input_val = input_val.abs();
+ weight_val = weight_val.abs();
+ bias_val = bias_val.abs();
+
+ if (!this->attribute->local_bound())
+ {
+ Eigen::Tensor<InEigenType, 0> input_abs_max = input_val.maximum();
+ input_val.setConstant(input_abs_max(0));
+ }
}
+ ETensor4<InEigenType> input_padded = input_val.pad(pad);
+
// extract_image_patches() output [N, KH, KW, H * W, C]
// need to transpose to [N, H * W, KH, KW, C]
ETensor5<InEigenType> input_extract_patches =
@@ -938,18 +944,24 @@ int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
- ETensor5<InEigenType> input_padded = input_val.pad(pad);
-
TBias bias_val = this->bias->getTensor();
if (g_func_config.abs_mode)
{
// in abs_mode: take abs values of conv operands
- input_padded = input_padded.abs();
- weight_val = weight_val.abs();
- bias_val = bias_val.abs();
+ input_val = input_val.abs();
+ weight_val = weight_val.abs();
+ bias_val = bias_val.abs();
+
+ if (!this->attribute->local_bound())
+ {
+ Eigen::Tensor<InEigenType, 0> input_abs_max = input_val.maximum();
+ input_val.setConstant(input_abs_max(0));
+ }
}
+ ETensor5<InEigenType> input_padded = input_val.pad(pad);
+
// 1. initialize with bias
Eigen::array<Eigen::Index, 5> reshape_dim;
reshape_dim.fill(1);
@@ -1140,18 +1152,24 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
- ETensor4<InEigenType> input_padded = input_val.pad(pad);
-
TBias bias_val = this->bias->getTensor();
if (g_func_config.abs_mode)
{
// in abs_mode: take abs values of conv operands
- input_padded = input_padded.abs();
- weight_val = weight_val.abs();
- bias_val = bias_val.abs();
+ input_val = input_val.abs();
+ weight_val = weight_val.abs();
+ bias_val = bias_val.abs();
+
+ if (!this->attribute->local_bound())
+ {
+ Eigen::Tensor<InEigenType, 0> input_abs_max = input_val.maximum();
+ input_val.setConstant(input_abs_max(0));
+ }
}
+ ETensor4<InEigenType> input_padded = input_val.pad(pad);
+
// GEMM doesn't fit well with DepthwiseConv2d
// 1. use extract_image_patches() to handle stride/dilation/pad
// 2. perform direct convolution
@@ -2078,6 +2096,12 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
input_val = input_val.abs();
weight_val = weight_val.abs();
bias_val = bias_val.abs();
+
+ if (!this->attribute->local_bound())
+ {
+ Eigen::Tensor<InEigenType, 0> input_abs_max = input_val.maximum();
+ input_val.setConstant(input_abs_max(0));
+ }
}
Eigen::array<Eigen::Index, 4> reshape_dim;
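
All four convolution variants gain the same branch: in abs_mode the operands are replaced by their absolute values, and when local_bound is false the input additionally collapses to its absolute maximum, making the resulting error bound global across the tensor rather than tied to each output position. A plain-array sketch of the transformation (the real code operates on Eigen tensors):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    void apply_abs_mode(std::vector<float>& input, bool local_bound)
    {
        for (auto& v : input)
            v = std::fabs(v);                   // abs_mode: operate on magnitudes
        if (!local_bound && !input.empty())
        {
            float global_max = *std::max_element(input.begin(), input.end());
            std::fill(input.begin(), input.end(), global_max);    // Eigen setConstant()
        }
    }
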
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 9719f07..a2e8da4 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -27,7 +27,6 @@ using namespace tosa;
using fp16 = ct::cfloat<int16_t, 5, true, true, true>;
using bf16 = ct::cfloat<int16_t, 8, true, true, true>;
-using fp32 = ct::cfloat<int32_t, 8, true, true, true>;
using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>;
using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>;
@@ -146,6 +145,21 @@ static int64_t zero_extend(int16_t val)
uint16_t* rval = reinterpret_cast<uint16_t*>(&val);
return static_cast<int64_t>(*rval);
}
+static int64_t zero_extend(int32_t val)
+{
+ uint32_t* rval = reinterpret_cast<uint32_t*>(&val);
+ return static_cast<int64_t>(*rval);
+}
+static int64_t sign_extend(uint8_t val)
+{
+ int8_t* rval = reinterpret_cast<int8_t*>(&val);
+ return static_cast<int64_t>(*rval);
+}
+static int64_t sign_extend(uint16_t val)
+{
+ int16_t* rval = reinterpret_cast<int16_t*>(&val);
+ return static_cast<int64_t>(*rval);
+}
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpRescale<Rank, InDtype, OutDtype>::eval()
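
The new overloads complete the zero_extend/sign_extend family for 32-bit values and for the unsigned-to-signed direction; each one reinterprets the same bits under the opposite signedness before widening to int64_t. A worked example of why the distinction matters:

    #include <cassert>
    #include <cstdint>

    int main()
    {
        uint8_t raw = 0x80;                               // same bits, two readings
        int64_t as_signed   = static_cast<int8_t>(raw);   // sign_extend: -128
        int64_t as_unsigned = raw;                        // zero_extend:  128
        assert(as_signed == -128 && as_unsigned == 128);
        return 0;
    }
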
@@ -201,6 +215,70 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
shift.push_back(static_cast<int32_t>(shift_val(i)));
}
+ auto get_input_zp_shifted = [=](InEigenType in_val) -> int64_t {
+ int64_t input_zp_shifted;
+ if (input_unsigned)
+ {
+ int64_t in_val64;
+ int64_t in_zp64;
+ switch (GetNumBits<InDtype>::value)
+ {
+ case 8:
+ in_val64 = zero_extend(static_cast<int8_t>(in_val));
+ in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
+ break;
+ case 16:
+ in_val64 = zero_extend(static_cast<int16_t>(in_val));
+ in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
+ break;
+ case 32:
+ in_val64 = zero_extend(static_cast<int32_t>(in_val));
+ in_zp64 = zero_extend(static_cast<int32_t>(input_zp));
+ break;
+ case 48:
+ in_val64 = static_cast<int64_t>(in_val) & 0x0000FFFFFFFFFFFF;
+ in_zp64 = static_cast<int64_t>(input_zp) & 0x0000FFFFFFFFFFFF;
+ break;
+ default:
+ in_val64 = static_cast<int64_t>(in_val);
+ in_zp64 = static_cast<int64_t>(input_zp);
+ break;
+ }
+ input_zp_shifted = in_val64 - in_zp64;
+ }
+ else
+ {
+ int64_t in_val64;
+ int64_t in_zp64 = static_cast<int64_t>(input_zp);
+ switch (GetNumBits<InDtype>::value)
+ {
+ case 8:
+ in_val64 = sign_extend(static_cast<uint8_t>(in_val & 0xFF));
+ break;
+ case 16:
+ in_val64 = sign_extend(static_cast<uint16_t>(in_val & 0xFFFF));
+ break;
+ case 32:
+ in_val64 = static_cast<int64_t>(static_cast<int32_t>(in_val & 0xFFFFFFFF));
+ break;
+ case 48:
+ // sign extend i48 to i64
+ in_val64 = static_cast<int64_t>(in_val);
+ if (in_val64 & 0x800000000000)
+ {
+ // in_val contains negative i48, sign extend to i64
+ in_val64 |= 0xFFFF000000000000;
+ }
+ break;
+ default:
+ in_val64 = static_cast<int64_t>(in_val);
+ break;
+ }
+ input_zp_shifted = in_val64 - in_zp64;
+ }
+ return input_zp_shifted;
+ };
+
if (per_channel)
{
ETensor2<InEigenType> curr_channel_slice_prescaled;
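
The 48-bit branch of the lambda sign-extends by hand, since there is no native i48 type: bit 47 is the sign bit, and a negative value gets its top 16 bits filled with ones. Equivalent standalone form:

    #include <cassert>
    #include <cstdint>

    int64_t sign_extend_i48(int64_t v)
    {
        v &= 0x0000FFFFFFFFFFFF;                 // keep the low 48 bits
        if (v & 0x0000800000000000)              // bit 47: the i48 sign bit
            v |= static_cast<int64_t>(0xFFFF000000000000ULL);    // fill bits 48..63
        return v;
    }

    int main()
    {
        assert(sign_extend_i48(0x0000FFFFFFFFFFFF) == -1);
        assert(sign_extend_i48(0x00007FFFFFFFFFFF) == 0x7FFFFFFFFFFF);
        return 0;
    }
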
@@ -218,32 +296,8 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
channel_shift = shift[i];
curr_channel_slice_postscaled =
curr_channel_slice_prescaled.unaryExpr([=](InEigenType in_val) -> OutEigenType {
- int64_t input_zp_shifted;
- if (input_unsigned)
- {
- int64_t in_val64;
- int64_t in_zp64;
- switch (GetNumBits<InDtype>::value)
- {
- case 8:
- in_val64 = zero_extend(static_cast<int8_t>(in_val));
- in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
- break;
- case 16:
- in_val64 = zero_extend(static_cast<int16_t>(in_val));
- in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
- break;
- default:
- in_val64 = static_cast<int64_t>(in_val);
- in_zp64 = static_cast<int64_t>(input_zp);
- break;
- }
- input_zp_shifted = in_val64 - in_zp64;
- }
- else
- {
- input_zp_shifted = in_val - input_zp;
- }
+ int64_t input_zp_shifted = get_input_zp_shifted(in_val);
+
int32_t scaled;
if (scale32)
scaled = TosaReference::QuantUtil::apply_scale_32(static_cast<int32_t>(input_zp_shifted),
@@ -311,32 +365,8 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
try
{
output_2d = input_reshaped.unaryExpr([=](InEigenType in_val) -> OutEigenType {
- int64_t input_zp_shifted;
- if (input_unsigned)
- {
- int64_t in_val64;
- int64_t in_zp64;
- switch (GetNumBits<InDtype>::value)
- {
- case 8:
- in_val64 = zero_extend(static_cast<int8_t>(in_val));
- in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
- break;
- case 16:
- in_val64 = zero_extend(static_cast<int16_t>(in_val));
- in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
- break;
- default:
- in_val64 = static_cast<int64_t>(in_val);
- in_zp64 = static_cast<int64_t>(input_zp);
- break;
- }
- input_zp_shifted = in_val64 - in_zp64;
- }
- else
- {
- input_zp_shifted = in_val - input_zp;
- }
+ int64_t input_zp_shifted = get_input_zp_shifted(in_val);
+
int32_t scaled;
if (scale32)
scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
@@ -626,6 +656,12 @@ CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32>::CastHelper()
fcn = [](InEigenType in) -> OutEigenType { return in; };
}
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP64>::CastHelper()
+{
+ // fp8e4m3 data (stored as fp32) converted to fp64
+ fcn = [](float in) -> double { return static_cast<double>(in); };
+}
+
template <TOSA_REF_TYPE OutDtype>
CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>::CastHelper()
{
@@ -663,12 +699,18 @@ CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32>::CastHelper()
fcn = [](InEigenType in) -> OutEigenType { return in; };
}
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP64>::CastHelper()
+{
+ // fp8e5m2 data (stored as fp32) converted to fp64
+ fcn = [](float in) -> double { return static_cast<double>(in); };
+}
+
template <TOSA_REF_TYPE InDtype>
CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
{
// Integer data converted to fp8e4m3 (stored as fp32)
fcn = [](InEigenType in) -> float {
- auto f = static_cast<fp32>(static_cast<fp8e4m3>(float(in)));
+ auto f = static_cast<fp8e4m3>(float(in));
float out = static_cast<float>(f);
return out;
};
@@ -677,70 +719,60 @@ CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
{
// fp16 data (stored as fp32) converted to fp8e4m3 (stored as fp32)
- fcn = [](float in) -> float {
- auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
- float out = static_cast<float>(f);
- return out;
- };
+ fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e4m3>(in)); };
}
CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
{
// bf16 data (stored as fp32) converted to fp8e4m3 (stored as fp32)
- fcn = [](float in) -> float {
- auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
- float out = static_cast<float>(f);
- return out;
- };
+ fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e4m3>(in)); };
}
CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
{
// fp32 data converted to fp8e4m3 (stored as fp32)
- fcn = [](float in) -> float {
- auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
- float out = static_cast<float>(f);
- return out;
- };
+ fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e4m3>(in)); };
}
template <TOSA_REF_TYPE InDtype>
CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
{
// Integer data converted to fp8e5m2 (stored as fp32)
- fcn = [](InEigenType in) -> float {
- auto f = static_cast<fp32>(static_cast<fp8e5m2>(float(in)));
- float out = static_cast<float>(f);
- return out;
- };
+ fcn = [](InEigenType in) -> float { return static_cast<float>(static_cast<fp8e5m2>(in)); };
}
CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
{
// fp16 data (stored as fp32) converted to fp8e5m2 (stored as fp32)
- fcn = [](float in) -> float {
- auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
- float out = static_cast<float>(f);
- return out;
- };
+ fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e5m2>(in)); };
}
CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
{
// bf16 data (stored as fp32) converted to fp8e5m2 (stored as fp32)
- fcn = [](float in) -> float {
- auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
- float out = static_cast<float>(f);
- return out;
- };
+ fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e5m2>(in)); };
}
CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
{
// fp32 data converted to fp8e5m2 (stored as fp32)
- fcn = [](float in) -> float {
- auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
- float out = static_cast<float>(f);
+ fcn = [](float in) -> float { return static_cast<float>(static_cast<fp8e5m2>(in)); };
+}
+
+CastHelper<TOSA_REF_TYPE_FP64, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+    // fp64 data converted to fp8e4m3 (stored as fp32)
+ fcn = [](double in) -> float {
+ float out = static_cast<float>(static_cast<fp8e4m3>(in));
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP64, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+ // fp64 data converted to fp8e5m2 (stored as fp32)
+ fcn = [](double in) -> float {
+ float out = static_cast<float>(static_cast<fp8e5m2>(in));
return out;
};
}
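
The two new helpers follow the established convention that fp8 tensors are stored as fp32: narrow the double to the 8-bit format, then widen back to float. A hedged toy of what the E5M2 narrowing does to a normal-range value (2 stored mantissa bits, so the significand snaps to a multiple of 1/4; ct::cfloat additionally handles rounding mode, subnormals, saturation and NaN, all of which this omits):

    #include <cmath>

    // Snap a normal-range double to the e5m2 significand grid, then store
    // the result as float, mirroring the cast chain above. Note std::round
    // is half-away-from-zero, whereas the real cast rounds to nearest even.
    float quantize_e5m2_normal(double in)
    {
        int exp;
        double m = std::frexp(in, &exp);              // in = m * 2^exp, |m| in [0.5, 1)
        double snapped = std::round(m * 8.0) / 8.0;   // 1 implicit + 2 stored bits
        return static_cast<float>(std::ldexp(snapped, exp));
    }
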
@@ -812,6 +844,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP8E5M2);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
@@ -821,9 +855,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3);
@@ -843,7 +879,19 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, UINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, UINT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, UINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, UINT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, UINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, UINT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, UINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, UINT16);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index 0636357..9bcace4 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -371,6 +371,23 @@ private:
FcnType fcn;
};
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP64>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
template <TOSA_REF_TYPE OutDtype>
class CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>
{
@@ -441,6 +458,23 @@ private:
FcnType fcn;
};
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP64>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
template <TOSA_REF_TYPE InDtype>
class CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>
{
@@ -577,6 +611,40 @@ private:
FcnType fcn;
};
+template <>
+class CastHelper<TOSA_REF_TYPE_FP64, TOSA_REF_TYPE_FP8E4M3>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP64, TOSA_REF_TYPE_FP8E5M2>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
template <TOSA_REF_TYPE OutDtype>
class CastHelper<TOSA_REF_TYPE_FP64, OutDtype>
{