about summary refs log tree commit diff
diff options
context:
space:
mode:
author TatWai Chong <tatwai.chong@arm.com> 2024-03-21 14:34:33 -0700
committer Eric Kunze <eric.kunze@arm.com> 2024-03-28 00:35:36 +0000
commit 08fe7a5b7e2c7c1a77968130e11267ef61490ac8 (patch)
tree 29e974b12bc5540a4a364aaa7b0d6aa47c30c923
parent d5b1512b1d2cea3b87e52a0ecc123db2a7a7cad3 (diff)
download reference_model-08fe7a5b7e2c7c1a77968130e11267ef61490ac8.tar.gz
Take `output_unsigned` into account in the rescale operation
Set QMin and QMax based on the value of the `output_unsigned` attribute.

Change-Id: I7f21f3edd7311295285fb3988b3c800de114777a
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
-rw-r--r-- reference_model/include/dtype.h 29
-rw-r--r-- reference_model/src/arith_util.h 70
-rw-r--r-- reference_model/src/ops/type_conversion.cc 151
-rw-r--r-- reference_model/src/ops/type_conversion.h 3
4 files changed, 176 insertions, 77 deletions
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h
index 3e8bdf5..a283f39 100644
--- a/reference_model/include/dtype.h
+++ b/reference_model/include/dtype.h
@@ -145,6 +145,35 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype)
return TOSA_REF_TYPE_UNKNOWN;
}
+template <TOSA_REF_TYPE Dtype>
+bool IsSignedInt()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_INT4:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_INT48:
+ return true;
+
+ case TOSA_REF_TYPE_UINT8:
+ case TOSA_REF_TYPE_UINT16:
+ return false;
+
+ case TOSA_REF_TYPE_BOOL:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_SHAPE:
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
+ default:
+ FATAL_ERROR("dtype is not an integer type");
+ break;
+ }
+}
+
}; // namespace TosaReference
#endif
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
index f0d184c..fee9fef 100644
--- a/reference_model/src/arith_util.h
+++ b/reference_model/src/arith_util.h
@@ -22,6 +22,7 @@
* fix point arithmetic
* fp16 type conversion(in binary translation)
* fp16 arithmetic (disguised with fp32 now)
+ * arithmetic helpers listed in Section 4.3.1 of the spec
*/
#ifndef ARITH_UTIL_H
@@ -35,6 +36,7 @@
#include "func_debug.h"
#include "half.hpp"
#include "inttypes.h"
+#include "ops/template_types.h"
#include <bitset>
#include <cassert>
#include <limits>
@@ -247,4 +249,72 @@ float fpTrunc(float f_in)
return f_in;
}
+// Return the maximum value when interpreting Dtype as a signed value.
+template <TOSA_REF_TYPE Dtype>
+int32_t getSignedMaximum()
+{
+ if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
+ return GetQMax<TOSA_REF_TYPE_INT8>::value;
+
+ if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
+ return GetQMax<TOSA_REF_TYPE_INT16>::value;
+
+ if (Dtype == TOSA_REF_TYPE_INT32)
+ return GetQMax<TOSA_REF_TYPE_INT32>::value;
+
+ FATAL_ERROR("Get maximum_s for the dtype input is not supported");
+ return 0;
+}
+
+// Return the minimum value when interpreting Dtype as a signed value.
+template <TOSA_REF_TYPE Dtype>
+int32_t getSignedMinimum()
+{
+ if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
+ return GetQMin<TOSA_REF_TYPE_INT8>::value;
+
+ if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
+ return GetQMin<TOSA_REF_TYPE_INT16>::value;
+
+ if (Dtype == TOSA_REF_TYPE_INT32)
+ return GetQMin<TOSA_REF_TYPE_INT32>::value;
+
+ FATAL_ERROR("Get minimum_s for the dtype input is not supported");
+ return 0;
+}
+
+// Return the maximum value when interpreting Dtype as an unsigned value. NOTE(review): for INT32 the uint32_t maximum does not fit in the int32_t return type.
+template <TOSA_REF_TYPE Dtype>
+int32_t getUnsignedMaximum()
+{
+ if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
+ return GetQMax<TOSA_REF_TYPE_UINT8>::value;
+
+ if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
+ return GetQMax<TOSA_REF_TYPE_UINT16>::value;
+
+ if (Dtype == TOSA_REF_TYPE_INT32)
+ return std::numeric_limits<uint32_t>::max();
+
+ FATAL_ERROR("Get maximum_u for the dtype input is not supported");
+ return 0;
+}
+
+// Return the minimum value when interpreting Dtype as an unsigned value.
+template <TOSA_REF_TYPE Dtype>
+int32_t getUnsignedMinimum()
+{
+ if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
+ return GetQMin<TOSA_REF_TYPE_UINT8>::value;
+
+ if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
+ return GetQMin<TOSA_REF_TYPE_UINT16>::value;
+
+ if (Dtype == TOSA_REF_TYPE_INT32)
+ return std::numeric_limits<uint32_t>::min();
+
+ FATAL_ERROR("Get minimum_u for the dtype input is not supported");
+ return 0;
+}
+
#endif /* _ARITH_UTIL_H */
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index d58cfeb..835b656 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -148,6 +148,9 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
bool input_unsigned = attribute->input_unsigned();
bool output_unsigned = attribute->output_unsigned();
+ int32_t QMin = output_unsigned ? getUnsignedMinimum<OutDtype>() : getSignedMinimum<OutDtype>();
+ int32_t QMax = output_unsigned ? getUnsignedMaximum<OutDtype>() : getSignedMaximum<OutDtype>();
+
// reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
Eigen::array<Eigen::Index, 2> shape_2d;
shape_2d[0] = 1;
@@ -200,13 +203,12 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
{
for (int32_t i = 0; i < shape_2d[1]; i++)
{
- begin = Eigen::array<Eigen::Index, 2>({ 0, i });
- curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
- channel_multiplier = multiplier[i];
- channel_shift = shift[i];
- curr_channel_slice_postscaled = curr_channel_slice_prescaled.unaryExpr(
- [input_zp, output_zp, channel_multiplier, channel_shift, double_round, scale32, input_unsigned,
- output_unsigned](InEigenType in_val) -> OutEigenType {
+ begin = Eigen::array<Eigen::Index, 2>({ 0, i });
+ curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
+ channel_multiplier = multiplier[i];
+ channel_shift = shift[i];
+ curr_channel_slice_postscaled =
+ curr_channel_slice_prescaled.unaryExpr([=](InEigenType in_val) -> OutEigenType {
int64_t input_zp_shifted;
if (input_unsigned)
{
@@ -293,78 +295,79 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
int32_t tensor_shift = shift[0];
try
{
- output_2d =
- input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, scale32,
- input_unsigned, output_unsigned](InEigenType in_val) -> OutEigenType {
- int64_t input_zp_shifted;
- if (input_unsigned)
- {
- int64_t in_val64;
- int64_t in_zp64;
- switch (GetNumBits<InDtype>::value)
- {
- case 8:
- in_val64 = zero_extend(static_cast<int8_t>(in_val));
- in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
- break;
- case 16:
- in_val64 = zero_extend(static_cast<int16_t>(in_val));
- in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
- break;
- default:
- in_val64 = static_cast<int64_t>(in_val);
- in_zp64 = static_cast<int64_t>(input_zp);
- break;
- }
- input_zp_shifted = in_val64 - in_zp64;
- }
- else
- {
- input_zp_shifted = in_val - input_zp;
- }
- int32_t scaled;
- if (scale32)
- scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
- tensor_shift, double_round);
- else
- scaled =
- TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
-
- int64_t output_zp_extended;
- if (output_unsigned)
- {
- switch (GetNumBits<OutDtype>::value)
- {
- case 8:
- output_zp_extended = zero_extend(static_cast<int8_t>(output_zp));
- break;
- case 16:
- output_zp_extended = zero_extend(static_cast<int16_t>(output_zp));
- break;
- default:
- output_zp_extended = static_cast<int64_t>(output_zp);
- break;
- }
- }
- else
+ output_2d = input_reshaped.unaryExpr([=](InEigenType in_val) -> OutEigenType {
+ int64_t input_zp_shifted;
+ if (input_unsigned)
+ {
+ int64_t in_val64;
+ int64_t in_zp64;
+ switch (GetNumBits<InDtype>::value)
{
- output_zp_extended = static_cast<int64_t>(output_zp);
+ case 8:
+ in_val64 = zero_extend(static_cast<int8_t>(in_val));
+ in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
+ break;
+ case 16:
+ in_val64 = zero_extend(static_cast<int16_t>(in_val));
+ in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
+ break;
+ default:
+ in_val64 = static_cast<int64_t>(in_val);
+ in_zp64 = static_cast<int64_t>(input_zp);
+ break;
}
- int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended;
- int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
- int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
- if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
+ input_zp_shifted = in_val64 - in_zp64;
+ }
+ else
+ {
+ input_zp_shifted = in_val - input_zp;
+ }
+ int32_t scaled;
+ if (scale32)
+ scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
+ double_round);
+ else
+ scaled =
+ TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
+
+ int64_t output_zp_extended;
+ if (output_unsigned)
+ {
+ switch (GetNumBits<OutDtype>::value)
{
- std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
- std::to_string(output_zp) + "] not in i32 range";
- throw desc;
+ case 8:
+ output_zp_extended = zero_extend(static_cast<int8_t>(output_zp));
+ break;
+ case 16:
+ output_zp_extended = zero_extend(static_cast<int16_t>(output_zp));
+ break;
+ default:
+ output_zp_extended = static_cast<int64_t>(output_zp);
+ break;
}
+ }
+ else
+ {
+ output_zp_extended = static_cast<int64_t>(output_zp);
+ }
+ int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended;
+ int64_t i32_max_in_64 = IsSignedInt<OutDtype>()
+ ? static_cast<int64_t>(std::numeric_limits<int32_t>::max())
+ : static_cast<int64_t>(std::numeric_limits<uint32_t>::max());
+ int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
+
+ if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
+ {
+ std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
+ std::to_string(output_zp) + "] not in i32 range";
+ throw desc;
+ }
- OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
- out_val = std::max<OutEigenType>(out_val, QMin);
- out_val = std::min<OutEigenType>(out_val, QMax);
- return out_val;
- });
+ OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
}
catch (std::string desc)
{
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index a06dccc..da5537e 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -43,9 +43,6 @@ public:
using TMultiplierI32 = Eigen::Tensor<I32EigenType, 1>;
using TShift = Eigen::Tensor<I8EigenType, 1>;
- static constexpr int32_t QMin = GetQMin<OutDtype>::value;
- static constexpr int32_t QMax = GetQMax<OutDtype>::value;
-
protected:
TosaRescaleAttribute* attribute;
TosaReference::TensorTemplate<TIn>* in;