diff options
author | Eric Kunze <eric.kunze@arm.com> | 2020-10-13 16:11:07 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2020-10-14 11:11:43 -0700 |
commit | e5e2676409a936431f87d31fb74d825257b20804 (patch) | |
tree | 304d93d993ef6417b02a515025f9030367682774 /reference_model/src/quant_util.h | |
parent | 88b7860f180f91b5b66764c61cfd97d8bc53cece (diff) | |
download | reference_model-e5e2676409a936431f87d31fb74d825257b20804.tar.gz |
Initial checkin of TOSA reference_model and tests
Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Diffstat (limited to 'reference_model/src/quant_util.h')
-rw-r--r-- | reference_model/src/quant_util.h | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h new file mode 100644 index 0000000..3638b3b --- /dev/null +++ b/reference_model/src/quant_util.h @@ -0,0 +1,103 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOSA_REFERENCE_QUANT_UTIL_H +#define TOSA_REFERENCE_QUANT_UTIL_H + +#include "arith_util.h" +#include "func_debug.h" +#include "ops/template_types.h" +#include "tosa_generated.h" + +using namespace tosa; + +namespace TosaReference +{ + +template <DType AccDType> +class QuantUtil +{ +public: + using T = typename GetEigenType<AccDType>::type; + + static void reciprocal_scale(int32_t value, + // Output + int32_t& multiplier, + int32_t& shift) + { + ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value); + uint32_t value_u32 = (uint32_t)value; + int32_t k = 32 - LEADING_ZEROS_32(value_u32 - 1); // (1<<k)/2 < value <= (1<<k) + int64_t numerator = ((1L << 30) + 1) << k; + multiplier = numerator / value; // (1<<30) <= multiplier < (1<<31) + shift = 30 + k; + } + + static int32_t apply_scale(T value, int32_t multiplier, int32_t shift, bool enabled_adjusted_rounding = true) + { + if (AccDType == DType_FLOAT) + { + return value; + } + ASSERT_MSG(multiplier >= 0, "apply_scale() error: multiplier should >= 0 but is %d", multiplier); + int64_t round = (shift > 0) ? (1L << (shift - 1)) : 0; + if (enabled_adjusted_rounding) + { + if (AccDType != DType_INT48) + { + if (shift > 31 && value >= 0) + round += (1L << 30); + if (shift > 31 && value < 0) + round -= (1L << 30); + } + else + { // input data could be int16, which leads to 48 bits accumulator + ASSERT_MSG(multiplier < (1 << 15), "apply_scale() error: multiplier should <= %d in 48 bit mode", + (1 << 15)); + } + } + int64_t result = (int64_t)value * multiplier + round; + result = result >> shift; + ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31), + "apply_scale() error: scaled result exceed int32 numeric range"); + return static_cast<int32_t>(result); + } +}; + +class TypeChecker +{ +public: + static bool is_integer(DType dtype) + { + if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_AINT8 || dtype == DType_UINT8 || + dtype == DType_INT16 || dtype == DType_INT32 || dtype == DType_INT48) + { + return true; + } + return false; + } + static bool is_symmetric(DType dtype) + { + if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_INT16 || dtype == DType_INT32 || + dtype == DType_INT48) + { + return true; + } + return false; + } +}; +}; // namespace TosaReference + +#endif |