aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/arith_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/arith_util.h')
-rw-r--r--reference_model/src/arith_util.h89
1 files changed, 89 insertions, 0 deletions
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
index 554a7a2..33bdeed 100644
--- a/reference_model/src/arith_util.h
+++ b/reference_model/src/arith_util.h
@@ -31,13 +31,18 @@
#include <math.h>
#define __STDC_LIMIT_MACROS //enable min/max of plain data type
#include "func_debug.h"
+#include "func_config.h"
#include "inttypes.h"
+#include "tosa_generated.h"
#include <cassert>
#include <iostream>
#include <limits>
#include <stdint.h>
#include <typeinfo>
+#include <Eigen/Core>
+#include <bitset>
+using namespace tosa;
using namespace std;
inline size_t _count_one(uint64_t val)
@@ -191,4 +196,88 @@ constexpr T saturate(const uint32_t width, const intmax_t value)
// clang-format on
}
+inline void float_trunc_bytes(float* src)
+{
+ /* Set the least significant two bytes to zero for the input float value.*/
+ char src_as_bytes[sizeof(float)];
+ memcpy(src_as_bytes, src, sizeof(float));
+
+ if (g_func_config.float_is_big_endian)
+ {
+ src_as_bytes[2] = '\000';
+ src_as_bytes[3] = '\000';
+ }
+ else
+ {
+ src_as_bytes[0] = '\000';
+ src_as_bytes[1] = '\000';
+ }
+
+ memcpy(src, &src_as_bytes, sizeof(float));
+}
+
+inline void truncateFloatToBFloat(float* src, int64_t size) {
+ /* Set the least significant two bytes to zero for each float
+ value in the input src buffer. */
+ ASSERT_MEM(src);
+ ASSERT_MSG(size > 0, "Size of src (representing number of values in src) must be a positive integer.");
+ for (; size != 0; src++, size--)
+ {
+ float_trunc_bytes(src);
+ }
+}
+
+inline bool checkValidBFloat(float src)
+{
+ /* Checks if the least significant two bytes are zero. */
+ ASSERT_MEM(src);
+ char src_as_bytes[sizeof(float)];
+ memcpy(src_as_bytes, &src, sizeof(float));
+
+ if (g_func_config.float_is_big_endian)
+ {
+ return (src_as_bytes[2] == '\000' && src_as_bytes[3] == '\000');
+ }
+ else
+ {
+ return (src_as_bytes[0] == '\000' && src_as_bytes[1] == '\000');
+ }
+}
+
+inline bool float_is_big_endian()
+{
+ /* Compares float values 1.0 and -1.0 by checking whether the
+ negation causes the first or the last byte to change.
+ First byte changing would indicate the float representation
+ is big-endian.*/
+ float f = 1.0;
+ char f_as_bytes[sizeof(float)];
+ memcpy(f_as_bytes, &f, sizeof(float));
+ f = -f;
+ char f_neg_as_bytes[sizeof(float)];
+ memcpy(f_neg_as_bytes, &f, sizeof(float));
+ return f_as_bytes[0] != f_neg_as_bytes[0];
+}
+
+template <DType Dtype>
+float fpTrunc(float f_in)
+{
+ /* Truncates a float value based on the DType it represents.*/
+ switch (Dtype)
+ {
+ case DType_BF16:
+ truncateFloatToBFloat(&f_in, 1);
+ break;
+ case DType_FP16:
+ // TODO(jw): implement FP16 truncate function (no-op placeholder for now)
+ break;
+ case DType_FP32:
+ // No-op for fp32
+ break;
+ default:
+ ASSERT_MSG(false, "DType %s should not be float-truncated.", EnumNameDType(Dtype));
+ }
+ return f_in;
+}
+
#endif /* _ARITH_UTIL_H */