aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFredrik Svedberg <fredrik.svedberg@arm.com>2020-08-18 13:19:18 +0200
committertim.hall <tim.hall@arm.com>2020-08-21 15:22:05 +0000
commit1575b9413de2569de25bb2520b898a91f24ad3b0 (patch)
tree13ecfc66b104d135c8c58b0236ee1aca17c9f109
parent1cdc4675bab71c8a8d15b1687790954dab42ddd1 (diff)
downloadethos-u-vela-1575b9413de2569de25bb2520b898a91f24ad3b0.tar.gz
[MLBEDSW-2730] Implement LUT generation for softmax uint8/int8
Implemented LUT generation for softmax uint8/int8 to match the reference. Change-Id: Ib9acaa295ee1066591e800023d75f364520b44c1 Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
-rw-r--r--ethosu/vela/fp_math.py138
-rw-r--r--ethosu/vela/register_command_stream_generator.py3
-rw-r--r--ethosu/vela/softmax.py133
-rw-r--r--ethosu/vela/supported_operators.py21
-rw-r--r--ethosu/vela/test/test_fp_math.py118
5 files changed, 312 insertions, 101 deletions
diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py
new file mode 100644
index 00000000..2055879a
--- /dev/null
+++ b/ethosu/vela/fp_math.py
@@ -0,0 +1,138 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Contains various fixed point math functions based on the gemmlowp fixed
+# point implementation.
+import numpy as np
+
+
+def saturating_rounding_mul(a, b):
+    """Return the rounded high part of the doubled product of two int32 values.
+
+    Computes (a * b + nudge) / 2**31 with the quotient truncated toward zero,
+    where nudge is 2**30 for a non-negative product and 1 - 2**30 otherwise.
+    The single overflowing case, int32 min times int32 min, saturates to
+    int32 max.
+
+    Both arguments must be exactly representable as np.int32 (asserted).
+    Returns an np.int32.
+    """
+    assert np.int32(a) == a
+    assert np.int32(b) == b
+    i32info = np.iinfo(np.int32)
+    if a == b and a == i32info.min:
+        return np.int32(i32info.max)
+    # Plain Python ints keep the 64-bit intermediate exact.
+    ab_plus_nudge = int(a) * int(b)
+    ab_plus_nudge += (1 << 30) if ab_plus_nudge >= 0 else (1 - (1 << 30))
+    result = ab_plus_nudge >> 31
+    # ">>" floors whereas the division above truncates toward zero, so a
+    # negative quotient is bumped by one, but only when the shift actually
+    # discarded bits. An exact negative quotient is already correct.
+    if result < 0 and (result << 31) != ab_plus_nudge:
+        result += 1
+    return np.int32(result)
+
+
+def shift_left(a, offset):
+    """Return a * 2**offset saturated to the int32 range, as an np.int32.
+
+    a must be exactly representable as np.int32 and offset must be
+    non-negative (both asserted).
+    """
+    assert np.int32(a) == a
+    assert offset >= 0
+    # Always clamp against int32: np.iinfo(a) would take its limits from
+    # whatever type the caller happened to pass in.
+    i32info = np.iinfo(np.int32)
+    # Shift as Python ints so a large offset cannot overflow or depend on
+    # NumPy's scalar promotion rules.
+    shifted = int(a) << int(offset)
+    if shifted < i32info.min:
+        return np.int32(i32info.min)
+    if shifted > i32info.max:
+        return np.int32(i32info.max)
+    return np.int32(shifted)
+
+
+# Divide x by 2**exponent with rounding to nearest; exact halves round away
+# from zero (e.g. 1032 / 16 -> 65 and -1032 / 16 -> -65).
+# Both arguments must be exactly representable as np.int32 (asserted).
+def rounding_divide_by_pot(x, exponent):
+ assert np.int32(x) == x
+ assert np.int32(exponent) == exponent
+ # Low "exponent" bits of x, i.e. the part discarded by the shift below
+ mask = (1 << exponent) - 1
+ remainder = x & mask
+ # Just under half of the divisor; one higher for negative x so that an exact
+ # half keeps the floored (more negative) result
+ threshold = mask >> 1
+ if x < 0:
+ threshold += 1
+ # Arithmetic shift floors; bump by one when the discarded part is large enough
+ result = x >> exponent
+ if remainder > threshold:
+ result += 1
+ return result
+
+
+# Multiply x by 2**exponent, saturating to the int32 range.
+# Both arguments must be exactly representable as np.int32 (asserted).
+# The saturated branches return the plain np.iinfo limits; the in-range
+# branch returns whatever shift_left() returns.
+def saturating_rounding_multiply_by_pot(exponent, x):
+ assert np.int32(x) == x
+ assert np.int32(exponent) == exponent
+ # Largest magnitude that still fits in int32 after shifting left by "exponent"
+ threshold = (1 << (np.iinfo(np.int32).bits - 1 - exponent)) - 1
+ if x > threshold:
+ return np.iinfo(np.int32).max
+ elif x < -threshold:
+ return np.iinfo(np.int32).min
+ else:
+ # shift_left() asserts offset >= 0, so a negative exponent ends up here
+ # and fails with an AssertionError
+ return shift_left(x, exponent)
+
+
+# Rescale the fixed point value x from integer_bits_src to integer_bits_dst
+# integer bits by multiplying it with 2**(integer_bits_src - integer_bits_dst),
+# saturating via saturating_rounding_multiply_by_pot().
+# All three arguments must be exactly representable as np.int32 (asserted).
+def rescale(integer_bits_src, integer_bits_dst, x):
+ assert np.int32(integer_bits_src) == integer_bits_src
+ assert np.int32(integer_bits_dst) == integer_bits_dst
+ assert np.int32(x) == x
+ exponent = integer_bits_src - integer_bits_dst
+ result = saturating_rounding_multiply_by_pot(exponent, x)
+ return result
+
+
+# Input Q0.31
+# Fixed point polynomial evaluated for a in [-2**29, 0), i.e. [-1/4, 0) when a
+# is read as Q0.31 (range asserted below). Returns an np.int32.
+# NOTE(review): the constants match exp(-1/8) and 1/3 scaled by 2**31, which
+# suggests a Taylor expansion of exp() around -1/8 as in gemmlowp - confirm.
+def exp_on_interval_between_negative_one_quarter_and_0_excl(a):
+ assert np.int32(a) == a
+ assert -1 << (31 - 2) <= a < 0
+ offset = 28
+ constant_term = 1895147668
+ constant_1_over_3 = 715827883
+ # Shift the argument by 2**28 (1/8 in Q0.31) so that x lies in [-1/8, 1/8)
+ x = a + (1 << offset)
+ x2 = saturating_rounding_mul(x, x)
+ x3 = saturating_rounding_mul(x2, x)
+ x4 = saturating_rounding_mul(x2, x2)
+ x4_over_4 = rounding_divide_by_pot(x4, 2)
+ # ((x4 / 4 + x3) / 3 + x2) / 2 == x4 / 24 + x3 / 6 + x2 / 2
+ x4_over_24_plus_x3_over_6_plus_x2_over_2 = rounding_divide_by_pot(
+ saturating_rounding_mul((x4_over_4 + x3), constant_1_over_3) + x2, 1
+ )
+
+ # constant_term * (1 + x + x2 / 2 + x3 / 6 + x4 / 24)
+ return np.int32(
+ constant_term + saturating_rounding_mul(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
+ )
+
+
+# Input Q5.26
+# Fixed point exp() for a <= 0 (asserted). The argument is split into a part in
+# [-1/4, 0), handled by exp_on_interval_between_negative_one_quarter_and_0_excl(),
+# and a non-negative remainder whose set bits each select one extra multiplier.
+# Returns the int32 maximum when a == 0.
+def exp_on_negative_values(a):
+ assert np.int32(a) == a
+ assert a <= 0
+ # 2**24 is 1/4 in Q5.26; mask keeps the bits below it
+ one_quarter = np.int32(16777216)
+ mask = np.int32(16777215)
+ a_mod_quarter_minus_one_quarter = np.int32((a & mask) - one_quarter)
+
+ # rescale(5, 0, ...) shifts left by 5 bits before the interval evaluation
+ result = exp_on_interval_between_negative_one_quarter_and_0_excl(rescale(5, 0, a_mod_quarter_minus_one_quarter))
+ # What is left of a after removing the part handled above (>= 0)
+ remainder = np.int32(a_mod_quarter_minus_one_quarter - a)
+
+ # Multiply result by "multiplier" when the remainder bit worth 2**exponent
+ # (bit 26 + exponent) is set; reads "remainder" from the enclosing scope.
+ # NOTE(review): multipliers look like exp(-2**exponent) scaled by 2**31 - confirm.
+ def exp_barrel_shifter(exponent, multiplier, result):
+ fractional_bits = 26
+ integer_bits = 5
+ shift = fractional_bits + exponent if integer_bits > exponent else 0
+ if remainder & (1 << shift):
+ return saturating_rounding_mul(result, multiplier)
+ else:
+ return result
+
+ result = exp_barrel_shifter(-2, 1672461947, result)
+ result = exp_barrel_shifter(-1, 1302514674, result)
+ result = exp_barrel_shifter(+0, 790015084, result)
+ result = exp_barrel_shifter(+1, 290630308, result)
+ result = exp_barrel_shifter(+2, 39332535, result)
+ result = exp_barrel_shifter(+3, 720401, result)
+ result = exp_barrel_shifter(+4, 242, result)
+
+ # a == 0 is special-cased to the largest int32 value
+ if a == 0:
+ return np.iinfo(np.int32).max
+ else:
+ return result
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 013128b4..7b1e9a69 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -50,7 +50,6 @@ from .numeric_util import quantise_float32
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
from .operation import NpuBlockType
-from .shared_buffer_allocation import SharedBufferAllocation
from .tensor import MemType
from .tensor import TensorBlockTraversal
from .tensor import TensorFormat
@@ -837,7 +836,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", -1)
assert activation.LUT_START.value <= lut_index <= activation.LUT_END.value, "LUT index out of range."
if cmd.ofm_tensor.dtype == DataType.int32:
- lut_index |= (3 << 12) # Force I8 range
+ lut_index |= 3 << 12 # Force I8 range
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, lut_index)
faf_min = ofm_quant_qmin
faf_max = ofm_quant_qmax
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index c67cc376..eb97c792 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -1,22 +1,28 @@
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
# SPDX-License-Identifier: Apache-2.0
#
-# Licensed under the Apache License, Version 2.0 (the License); you may
-# not use this file except in compliance with the License.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an AS IS BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+#
# Description:
# Contains SoftMax
+import math
+
import numpy as np
+from . import fp_math
from . import scaling
from .data_type import DataType
from .operation import Operation
@@ -30,76 +36,6 @@ class SoftMax:
# Turn off black formatting for the LUT tables to keep them compact
# fmt: off
- EXP_LUT_U8 = [
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
- 0x00000291, 0x000006fa, 0x000012f6, 0x0000338b, 0x00008c1b, 0x00017cd8, 0x00040b3d, 0x000afe11,
- 0x001de16c, 0x00513949, 0x00dcca03, 0x02582ac2, 0x065f6c52, 0x1152aaf6, 0x2f16ad4c, 0x7fffffff
- ]
-
- EXP_LUT_I8 = [
- 0x000011c9, 0x000012b8, 0x000013b4, 0x000014bd, 0x000015d4, 0x000016fa, 0x0000182f, 0x00001975,
- 0x00001acb, 0x00001c34, 0x00001daf, 0x00001f3f, 0x000020e3, 0x0000229e, 0x00002470, 0x0000265a,
- 0x0000285e, 0x00002a7d, 0x00002cb9, 0x00002f13, 0x0000318c, 0x00003427, 0x000036e5, 0x000039c8,
- 0x00003cd1, 0x00004004, 0x00004361, 0x000046ec, 0x00004aa6, 0x00004e93, 0x000052b4, 0x0000570d,
- 0x00005ba1, 0x00006072, 0x00006583, 0x00006ada, 0x00007077, 0x00007661, 0x00007c9a, 0x00008327,
- 0x00008a0c, 0x0000914d, 0x000098f1, 0x0000a0fb, 0x0000a971, 0x0000b259, 0x0000bbb9, 0x0000c597,
- 0x0000cffa, 0x0000dae9, 0x0000e66b, 0x0000f288, 0x0000ff48, 0x00010cb3, 0x00011ad3, 0x000129b1,
- 0x00013957, 0x000149d0, 0x00015b26, 0x00016d65, 0x0001809b, 0x000194d2, 0x0001aa1a, 0x0001c080,
- 0x0001d814, 0x0001f0e4, 0x00020b03, 0x00022681, 0x00024371, 0x000261e7, 0x000281f7, 0x0002a3b5,
- 0x0002c73b, 0x0002ec9e, 0x000313f8, 0x00033d64, 0x000368fd, 0x000396e0, 0x0003c72e, 0x0003fa05,
- 0x00042f89, 0x000467dd, 0x0004a326, 0x0004e18e, 0x0005233d, 0x00056860, 0x0005b126, 0x0005fdbf,
- 0x00064e5f, 0x0006a33b, 0x0006fc8e, 0x00075a93, 0x0007bd89, 0x000825b3, 0x00089356, 0x000906bd,
- 0x00098034, 0x000a000f, 0x000a86a2, 0x000b1447, 0x000ba95f, 0x000c464d, 0x000ceb7c, 0x000d9959,
- 0x000e505a, 0x000f10f9, 0x000fdbb8, 0x0010b120, 0x001191c0, 0x00127e2f, 0x0013770b, 0x00147cfc,
- 0x001590b2, 0x0016b2e6, 0x0017e45c, 0x001925e1, 0x001a784c, 0x001bdc81, 0x001d536f, 0x001ede14,
- 0x00207d76, 0x002232af, 0x0023fee3, 0x0025e348, 0x0027e125, 0x0029f9ce, 0x002c2ead, 0x002e813e,
- 0x0030f30f, 0x003385c7, 0x00363b1e, 0x003914e9, 0x003c150f, 0x003f3d97, 0x004290a0, 0x00461065,
- 0x0049bf40, 0x004d9fac, 0x0051b444, 0x0055ffc2, 0x005a850e, 0x005f472f, 0x00644959, 0x00698eea,
- 0x006f1b6b, 0x0074f298, 0x007b185e, 0x008190dd, 0x00886073, 0x008f8bad, 0x00971761, 0x009f08a0,
- 0x00a764c0, 0x00b03163, 0x00b9746c, 0x00c3341a, 0x00cd76f8, 0x00d843eb, 0x00e3a23a, 0x00ef9981,
- 0x00fc31d0, 0x0109739d, 0x011767cf, 0x012617cd, 0x01358d6e, 0x0145d319, 0x0156f3be, 0x0168fadc,
- 0x017bf49d, 0x018fedb3, 0x01a4f391, 0x01bb1457, 0x01d25ede, 0x01eae2e1, 0x0204b0c5, 0x021fd9e9,
- 0x023c708e, 0x025a87f5, 0x027a343a, 0x029b8ac1, 0x02bea1ea, 0x02e39148, 0x030a71be, 0x03335d49,
- 0x035e6f88, 0x038bc564, 0x03bb7d53, 0x03edb776, 0x0422956d, 0x045a3add, 0x0494cd23, 0x04d27398,
- 0x051357c1, 0x0557a511, 0x059f8990, 0x05eb3585, 0x063adbc4, 0x068eb1f7, 0x06e6f042, 0x0743d212,
- 0x07a595d0, 0x080c7d1f, 0x0878cd5d, 0x08eacf11, 0x0962cefe, 0x09e11dc0, 0x0a661028, 0x0af1ffdf,
- 0x0b854a8e, 0x0c205363, 0x0cc38284, 0x0d6f4577, 0x0e241032, 0x0ee25ba2, 0x0faaa7e6, 0x107d7b92,
- 0x115b64b1, 0x1244f774, 0x133ad1b8, 0x143d9876, 0x154df988, 0x166cac69, 0x179a70c9, 0x18d81250,
- 0x1a266643, 0x1b864d38, 0x1cf8b430, 0x1e7e9307, 0x2018f0a9, 0x21c8e098, 0x238f850c, 0x256e1033,
- 0x2765c273, 0x2977ef40, 0x2ba5faa9, 0x2df15b73, 0x305b9d6b, 0x32e65e8a, 0x3593552c, 0x38644d67,
- 0x3b5b2b66, 0x3e79ee87, 0x41c2adcb, 0x45379f4e, 0x48db158a, 0x4caf81e6, 0x50b7797f, 0x54f5af16,
- 0x596cfe2f, 0x5e2066d0, 0x631310c8, 0x684852d8, 0x6dc3a909, 0x7388c421, 0x799b84b7, 0x7fffffff,
- ]
-
EXP_LUT = [
0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
@@ -239,8 +175,27 @@ class SoftMax:
self.op = op
def generate_exp_table(self, beta, input_scale):
- # TODO: Generate the exp table using the same math as the reference
- return self.EXP_LUT_U8 if input_scale == 1.0 else self.EXP_LUT_I8
+ # Builds and returns a 256-entry list. Entry x is
+ # fp_math.exp_on_negative_values() of the rescaled difference (x - 255),
+ # or 0 when that difference is below diff_min.
+ integer_bits = 5
+ total_signed_bits = 31
+ # Calculate scaling
+ # beta * input_scale * 2**(31 - integer_bits), capped at 2**31 - 1
+ real_beta = min(
+ np.double(beta) * np.double(input_scale) * (1 << (31 - integer_bits)), np.double((1 << 31) - 1.0)
+ )
+ # NOTE(review): presumably quantise_scale() returns a multiplier and a right
+ # shift, which is turned into a left shift here - confirm against scaling.py
+ scale, shift = scaling.quantise_scale(real_beta)
+ shift = 31 - shift
+ # Differences below diff_min are not passed to exp_on_negative_values()
+ diff_min = -1.0 * math.floor(
+ 1.0 * ((1 << integer_bits) - 1) * (1 << (total_signed_bits - integer_bits)) / (1 << shift)
+ )
+ # Generate the exp LUT
+ lut = []
+ for x in range(256):
+ # Ranges from -255 to 0 over the table
+ input_diff = x - 255
+ if input_diff >= diff_min:
+ rescale = fp_math.saturating_rounding_mul(input_diff * (1 << shift), scale)
+ lut.append(fp_math.exp_on_negative_values(rescale))
+ else:
+ lut.append(0)
+ return lut
def get_graph(self):
ifm = self.op.inputs[0]
@@ -339,7 +294,12 @@ class SoftMax:
sub5_op = Operation("SubAct", self.op.name + "_sub5")
sub5_op.add_input_tensor(
create_const_tensor(
- "headroom_offset_const", [1, 1, 1, 1], DataType.int32, [12 + 31 - 8], np.int32, quantization=no_scale_quant
+ "headroom_offset_const",
+ [1, 1, 1, 1],
+ DataType.int32,
+ [12 + 31 - 8],
+ np.int32,
+ quantization=no_scale_quant,
),
)
sub5_op.add_input_tensor(headroom_plus_one)
@@ -348,9 +308,7 @@ class SoftMax:
sub5_op.set_output_tensor(right_shift)
# PASS 6 - Sub
- one = create_const_tensor(
- "one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
- )
+ one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
sub6_op = Operation("SubAct", self.op.name + "_sub6")
sub6_op.add_input_tensor(headroom_plus_one)
sub6_op.add_input_tensor(one)
@@ -404,7 +362,12 @@ class SoftMax:
mul11_op.add_input_tensor(half_denominator)
mul11_op.add_input_tensor(
create_const_tensor(
- "neg_32_over_17_const", [1, 1, 1, 1], DataType.int32, [-1010580540], np.int32, quantization=one_scale_quant
+ "neg_32_over_17_const",
+ [1, 1, 1, 1],
+ DataType.int32,
+ [-1010580540],
+ np.int32,
+ quantization=one_scale_quant,
),
)
rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0")
@@ -428,9 +391,7 @@ class SoftMax:
F2_one = create_const_tensor(
"F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
)
- two = create_const_tensor(
- "two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant
- )
+ two = create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
for i in range(3):
# PASS 13, 18, 23 - MUL
mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
@@ -448,7 +409,7 @@ class SoftMax:
one_minus_half_denominator_times_x.quantization = one_scale_quant
sub_op.set_output_tensor(one_minus_half_denominator_times_x)
# PASS 15, 20, 25 - MUL
- mul_op = Operation("MulAct", self.op.name + "_mul%d" %+ (15 + i * 5))
+ mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5))
mul_op.add_input_tensor(nr_x)
mul_op.add_input_tensor(one_minus_half_denominator_times_x)
to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index c4186018..9e415b51 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -54,19 +54,11 @@ class SupportedOperators:
self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",))
self.binary_elem_wise_shift_ops = set(("SHL", "SHR",))
self.binary_elem_wise_add_mul_sub = set(
- (
- "AddAct",
- "MulAct",
- "SubAct",
- "QuantizedAdd",
- "QuantizedSub",
- "QuantizedMul",
- "Mul",
- "Add",
- "Sub",
- )
+ ("AddAct", "MulAct", "SubAct", "QuantizedAdd", "QuantizedSub", "QuantizedMul", "Mul", "Add", "Sub",)
+ )
+ self.binary_elem_wise_main_ops = (
+ self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
)
- self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
self.activation_ops = set(
(
@@ -166,7 +158,10 @@ class SupportedOperators:
return False
if (
t.element_size() > 2
- and op.type not in set(("Requantize", "ReduceSum", "CLZ",)) | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
+ and op.type
+ not in set(("Requantize", "ReduceSum", "CLZ",))
+ | self.binary_elem_wise_add_mul_sub
+ | self.binary_elem_wise_shift_ops
):
return False
# check size
diff --git a/ethosu/vela/test/test_fp_math.py b/ethosu/vela/test/test_fp_math.py
new file mode 100644
index 00000000..2dde1e4b
--- /dev/null
+++ b/ethosu/vela/test/test_fp_math.py
@@ -0,0 +1,118 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Unit tests for fixed point math
+import numpy as np
+import pytest
+
+from ethosu.vela import fp_math
+from ethosu.vela.softmax import SoftMax
+
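+# Expected table that test_exp() compares, entry by entry, against
+# SoftMax.generate_exp_table(1.0, np.float32(0.05123165)).
+# Turn off black formatting for EXP_LUT to keep it compact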
+# fmt: off
+
+EXP_LUT = [
+ 0x000011c9, 0x000012b8, 0x000013b4, 0x000014bd, 0x000015d4, 0x000016fa, 0x0000182f, 0x00001975,
+ 0x00001acb, 0x00001c34, 0x00001daf, 0x00001f3f, 0x000020e3, 0x0000229e, 0x00002470, 0x0000265a,
+ 0x0000285e, 0x00002a7d, 0x00002cb9, 0x00002f13, 0x0000318c, 0x00003427, 0x000036e5, 0x000039c8,
+ 0x00003cd1, 0x00004004, 0x00004361, 0x000046ec, 0x00004aa6, 0x00004e93, 0x000052b4, 0x0000570d,
+ 0x00005ba1, 0x00006072, 0x00006583, 0x00006ada, 0x00007077, 0x00007661, 0x00007c9a, 0x00008327,
+ 0x00008a0c, 0x0000914d, 0x000098f1, 0x0000a0fb, 0x0000a971, 0x0000b259, 0x0000bbb9, 0x0000c597,
+ 0x0000cffa, 0x0000dae9, 0x0000e66b, 0x0000f288, 0x0000ff48, 0x00010cb3, 0x00011ad3, 0x000129b1,
+ 0x00013957, 0x000149d0, 0x00015b26, 0x00016d65, 0x0001809b, 0x000194d2, 0x0001aa1a, 0x0001c080,
+ 0x0001d814, 0x0001f0e4, 0x00020b03, 0x00022681, 0x00024371, 0x000261e7, 0x000281f7, 0x0002a3b5,
+ 0x0002c73b, 0x0002ec9e, 0x000313f8, 0x00033d64, 0x000368fd, 0x000396e1, 0x0003c72e, 0x0003fa05,
+ 0x00042f89, 0x000467dd, 0x0004a326, 0x0004e18e, 0x0005233d, 0x00056861, 0x0005b126, 0x0005fdbf,
+ 0x00064e5f, 0x0006a33c, 0x0006fc8e, 0x00075a93, 0x0007bd89, 0x000825b3, 0x00089356, 0x000906bd,
+ 0x00098035, 0x000a000f, 0x000a86a2, 0x000b1447, 0x000ba95f, 0x000c464e, 0x000ceb7c, 0x000d9959,
+ 0x000e505a, 0x000f10f9, 0x000fdbb9, 0x0010b120, 0x001191c0, 0x00127e2f, 0x0013770b, 0x00147cfc,
+ 0x001590b2, 0x0016b2e7, 0x0017e45d, 0x001925e1, 0x001a784c, 0x001bdc81, 0x001d536f, 0x001ede14,
+ 0x00207d77, 0x002232af, 0x0023fee4, 0x0025e349, 0x0027e125, 0x0029f9ce, 0x002c2ead, 0x002e813e,
+ 0x0030f30f, 0x003385c7, 0x00363b1f, 0x003914e9, 0x003c1510, 0x003f3d97, 0x004290a1, 0x00461066,
+ 0x0049bf41, 0x004d9fad, 0x0051b444, 0x0055ffc3, 0x005a850f, 0x005f4730, 0x0064495a, 0x00698eeb,
+ 0x006f1b6c, 0x0074f299, 0x007b185f, 0x008190de, 0x00886074, 0x008f8bae, 0x00971762, 0x009f08a2,
+ 0x00a764c2, 0x00b03164, 0x00b9746e, 0x00c3341b, 0x00cd76fa, 0x00d843ed, 0x00e3a23b, 0x00ef9983,
+ 0x00fc31d2, 0x010973a0, 0x011767d1, 0x012617cf, 0x01358d70, 0x0145d31c, 0x0156f3c1, 0x0168fadf,
+ 0x017bf4a0, 0x018fedb6, 0x01a4f394, 0x01bb145a, 0x01d25ee1, 0x01eae2e5, 0x0204b0c8, 0x021fd9ed,
+ 0x023c7091, 0x025a87f9, 0x027a343d, 0x029b8ac5, 0x02bea1ee, 0x02e3914d, 0x030a71c2, 0x03335d4e,
+ 0x035e6f8d, 0x038bc56a, 0x03bb7d57, 0x03edb77c, 0x04229573, 0x045a3ae4, 0x0494cd29, 0x04d2739e,
+ 0x051357c7, 0x0557a519, 0x059f8997, 0x05eb358d, 0x063adbcc, 0x068eb1ff, 0x06e6f049, 0x0743d21b,
+ 0x07a595d9, 0x080c7d29, 0x0878cd66, 0x08eacf1a, 0x0962cf07, 0x09e11dcc, 0x0a661032, 0x0af1ffea,
+ 0x0b854a9a, 0x0c20536f, 0x0cc3828e, 0x0d6f4584, 0x0e241040, 0x0ee25bb0, 0x0faaa7f2, 0x107d7b9e,
+ 0x115b64be, 0x1244f787, 0x133ad1c6, 0x143d9885, 0x154df999, 0x166cac7a, 0x179a70d5, 0x18d81262,
+ 0x1a266657, 0x1b864d4c, 0x1cf8b43e, 0x1e7e9316, 0x2018f0b9, 0x21c8e0b1, 0x238f851d, 0x256e1046,
+ 0x2765c287, 0x2977ef55, 0x2ba5fab4, 0x2df15b8a, 0x305b9d83, 0x32e65ea3, 0x35935539, 0x38644d75,
+ 0x3b5b2b74, 0x3e79eea7, 0x41c2addc, 0x45379f60, 0x48db159c, 0x4caf81fa, 0x50b7797f, 0x54f5af2b,
+ 0x596cfe46, 0x5e2066e8, 0x631310c8, 0x684852d8, 0x6dc3a909, 0x7388c43d, 0x799b84b7, 0x7fffffff,
+]
+# fmt: on
+
+
+# Checks the int32 min * int32 min saturation case, zero, and symmetric
+# positive/negative operands of fp_math.saturating_rounding_mul().
+def test_saturating_rounding_mul():
+ i32info = np.iinfo(np.int32)
+ shift = 22
+ multiplier = 1760306048
+ assert fp_math.saturating_rounding_mul(i32info.min, i32info.min) == i32info.max
+ # "*" binds tighter than "<<": -255 * 1 << shift is (-255 * 1) << shift
+ assert fp_math.saturating_rounding_mul(-255 * 1 << shift, multiplier) == -876714926
+ assert fp_math.saturating_rounding_mul(-128 * 1 << shift, multiplier) == -440076512
+ assert fp_math.saturating_rounding_mul(0, multiplier) == 0
+ assert fp_math.saturating_rounding_mul(128 * 1 << shift, multiplier) == 440076512
+ assert fp_math.saturating_rounding_mul(255 * 1 << shift, multiplier) == 876714926
+
+
+# Checks that fp_math.shift_left() saturates to the int32 limits when shifting
+# by the full bit width and is exact for a shift that still fits.
+def test_shift_left():
+ i32info = np.iinfo(np.int32)
+ assert fp_math.shift_left(np.int32(1), i32info.bits) == i32info.max
+ assert fp_math.shift_left(np.int32(-1), i32info.bits) == i32info.min
+ # "/" yields a float here; the comparison still holds because 2**30 is exact
+ assert fp_math.shift_left(np.int32(1), i32info.bits - 2) == (i32info.max + 1) / 2
+ assert fp_math.shift_left(np.int32(-1), i32info.bits - 2) == i32info.min // 2
+
+
+# Checks fp_math.rounding_divide_by_pot() with a divisor of 2**4 on exact
+# multiples, values just below a half and exact halves, for both signs.
+def test_rounding_divide_by_pot():
+ assert fp_math.rounding_divide_by_pot(1024, 4) == 64
+ assert fp_math.rounding_divide_by_pot(1031, 4) == 64
+ # 1032 / 16 == 64.5: an exact half rounds away from zero
+ assert fp_math.rounding_divide_by_pot(1032, 4) == 65
+ assert fp_math.rounding_divide_by_pot(1047, 4) == 65
+ assert fp_math.rounding_divide_by_pot(1048, 4) == 66
+ assert fp_math.rounding_divide_by_pot(1056, 4) == 66
+ assert fp_math.rounding_divide_by_pot(-1024, 4) == -64
+ assert fp_math.rounding_divide_by_pot(-1031, 4) == -64
+ assert fp_math.rounding_divide_by_pot(-1032, 4) == -65
+ assert fp_math.rounding_divide_by_pot(-1047, 4) == -65
+ assert fp_math.rounding_divide_by_pot(-1048, 4) == -66
+ assert fp_math.rounding_divide_by_pot(-1056, 4) == -66
+
+
+# Checks one in-range product and both saturation directions of
+# fp_math.saturating_rounding_multiply_by_pot().
+def test_saturating_rounding_multiply_by_pot():
+ i32info = np.iinfo(np.int32)
+ assert fp_math.saturating_rounding_multiply_by_pot(4, np.int32(1025)) == 16400
+ # 67108865 == 2**26 + 1, so multiplying by 2**5 no longer fits in int32
+ assert fp_math.saturating_rounding_multiply_by_pot(5, np.int32(67108865)) == i32info.max
+ assert fp_math.saturating_rounding_multiply_by_pot(5, np.int32(-67108865)) == i32info.min
+
+
+# Checks fp_math.rescale() for several source/destination integer bit counts
+# and that a destination with more integer bits than the source is rejected.
+def test_rescale():
+ assert fp_math.rescale(5, 0, np.int32(1025)) == 32800
+ assert fp_math.rescale(3, 0, np.int32(1025)) == 8200
+ assert fp_math.rescale(5, 1, np.int32(1025)) == 16400
+ assert fp_math.rescale(3, 1, np.int32(1025)) == 4100
+ # src < dst gives a negative exponent, which trips an assert further down
+ with pytest.raises(AssertionError):
+ fp_math.rescale(1, 3, np.int32(1024))
+
+
+# Compares the table generated for beta 1.0 and input scale 0.05123165 with
+# the expected EXP_LUT values defined in this module, entry by entry.
+def test_exp():
+ # SoftMax is constructed without an operation; only generate_exp_table() is used
+ sm = SoftMax(None)
+ for (expected, actual) in zip(EXP_LUT, sm.generate_exp_table(1.0, np.float32(0.05123165))):
+ assert actual == expected