aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-05-07 13:12:37 +0200
committerRickard Bolin <rickard.bolin@arm.com>2023-06-14 12:53:30 +0000
commit8e525ca8aad4a52b80c1986c5067b9b74fb3e321 (patch)
tree85f0a781d011aaf1c7631d34acea4398502a7a9f
parentc2c3063d05494f3968a7fead1b3118602fe100b9 (diff)
downloadethos-u-vela-8e525ca8aad4a52b80c1986c5067b9b74fb3e321.tar.gz
MLBEDSW-7748: Add RSQRT support
- Added RSQRT int8 support, implemented as LUT. - Added test to supported operators - Updated SUPPORTED_OPS.md Change-Id: I34904772e044be8d22a6dfe426edf85358a205b7 Signed-off-by: Johan Alfven <johan.alfven@arm.com>
-rw-r--r--SUPPORTED_OPS.md11
-rw-r--r--ethosu/vela/lut.py79
-rw-r--r--ethosu/vela/operation.py2
-rw-r--r--ethosu/vela/test/test_tflite_supported_operators.py12
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py4
-rw-r--r--ethosu/vela/tflite_mapping.py2
-rw-r--r--ethosu/vela/tflite_supported_operators.py10
7 files changed, 117 insertions, 3 deletions
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 0c3e1e0..0d42d9c 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -19,7 +19,7 @@ limitations under the License.
# Supported Ops
This file was automatically generated by Vela using the `--supported-ops-report` parameter.
-Vela version: `3.8.0`
+Vela version: `3.8.1.dev3+gc66541d.d20230613`
This file complies with
[**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -64,6 +64,7 @@ Please check the supported operator list for your chosen runtime for further inf
| RESHAPE | [Generic](#tflite-generic-constraints), [Specific](#tflite-reshape-constraints) |
| RESIZE_BILINEAR | [Generic](#tflite-generic-constraints), [Specific](#tflite-resize_bilinear-constraints) |
| RESIZE_NEAREST_NEIGHBOR | [Generic](#tflite-generic-constraints), [Specific](#tflite-resize_nearest_neighbor-constraints) |
+| RSQRT | [Generic](#tflite-generic-constraints), [Specific](#tflite-rsqrt-constraints) |
| SHAPE | [Generic](#tflite-generic-constraints) |
| SLICE | [Generic](#tflite-generic-constraints) |
| SOFTMAX | [Generic](#tflite-generic-constraints), [Specific](#tflite-softmax-constraints) |
@@ -316,6 +317,14 @@ This is a list of constraints that the RESIZE_NEAREST_NEIGHBOR operator must sat
- The size tensor must match the output tensor shape
- Both align_corners and half_pixel_centers can't be True
+### TFLite RSQRT Constraints
+
+This is a list of constraints that the RSQRT operator must satisfy in order to be scheduled on the NPU.
+
+- At least one Input's shape must match the OFM's shape
+- IFM and OFM data types must match
+- IFM must be int8
+
### TFLite SOFTMAX Constraints
This is a list of constraints that the SOFTMAX operator must satisfy in order to be scheduled on the NPU.
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index c8fb7bc..e8759d9 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -20,6 +20,7 @@ import uuid
import numpy as np
+from . import fp_math
from . import numeric_util
from .data_type import DataType
from .debug_database import DebugDatabase
@@ -27,6 +28,7 @@ from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import round_away_zero
from .operation import Op
+from .scaling import quantise_scale
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
@@ -242,3 +244,80 @@ def create_lut_int16_op(op, lut_fn, fn_name):
lut[i] = slope + base
return convert_to_lut(op, lut, fn_name)
+
+
+def create_lut_rsqrt_int8_op(op):
+ # Turn off black formatting for the LUT tables to keep them compact
+ # fmt: off
+
+ # RSQRT_LUT has been generated by printing the output from the reference.
+ # These values are always the same but for some unknown reason it is not being
+ # implemented as a LUT in the reference.
+ # So based on the input range (-128, 127) the reference produces the following output:
+ RSQRT_LUT = [
+ 0x00000000, 0x00100000, 0x000b504e, 0x00093cd4, 0x00080000, 0x000727c9, 0x0006882f, 0x00060c24,
+ 0x0005a827, 0x00055555, 0x00050f45, 0x0004d2fe, 0x00049e6a, 0x00047007, 0x000446b4, 0x00042195,
+ 0x00040000, 0x0003e16d, 0x0003c570, 0x0003abb0, 0x000393e5, 0x00037dd2, 0x00036945, 0x00035613,
+ 0x00034418, 0x00033333, 0x0003234b, 0x00031447, 0x00030612, 0x0002f89c, 0x0002ebd3, 0x0002dfaa,
+ 0x0002d414, 0x0002c906, 0x0002be75, 0x0002b45a, 0x0002aaab, 0x0002a161, 0x00029875, 0x00028fe3,
+ 0x000287a2, 0x00027fb0, 0x00027807, 0x000270a2, 0x0002697f, 0x00026298, 0x00025bec, 0x00025577,
+ 0x00024f35, 0x00024925, 0x00024343, 0x00023d8e, 0x00023803, 0x000232a1, 0x00022d65, 0x0002284e,
+ 0x0002235a, 0x00021e87, 0x000219d5, 0x00021541, 0x000210cb, 0x00020c70, 0x00020831, 0x0002040c,
+ 0x00020000, 0x0001fc0c, 0x0001f82f, 0x0001f468, 0x0001f0b7, 0x0001ed1a, 0x0001e991, 0x0001e61b,
+ 0x0001e2b8, 0x0001df67, 0x0001dc26, 0x0001d8f7, 0x0001d5d8, 0x0001d2c8, 0x0001cfc8, 0x0001ccd6,
+ 0x0001c9f2, 0x0001c71c, 0x0001c454, 0x0001c198, 0x0001bee9, 0x0001bc46, 0x0001b9af, 0x0001b723,
+ 0x0001b4a3, 0x0001b22d, 0x0001afc2, 0x0001ad61, 0x0001ab0a, 0x0001a8bc, 0x0001a678, 0x0001a43e,
+ 0x0001a20c, 0x00019fe3, 0x00019dc2, 0x00019baa, 0x0001999a, 0x00019791, 0x00019590, 0x00019397,
+ 0x000191a5, 0x00018fbb, 0x00018dd7, 0x00018bfa, 0x00018a23, 0x00018853, 0x0001868a, 0x000184c6,
+ 0x00018309, 0x00018152, 0x00017fa0, 0x00017df4, 0x00017c4e, 0x00017aad, 0x00017911, 0x0001777b,
+ 0x000175e9, 0x0001745d, 0x000172d6, 0x00017153, 0x00016fd5, 0x00016e5b, 0x00016ce7, 0x00016b76,
+ 0x00016a0a, 0x000168a2, 0x0001673e, 0x000165de, 0x00016483, 0x0001632b, 0x000161d7, 0x00016087,
+ 0x00015f3b, 0x00015df2, 0x00015cad, 0x00015b6b, 0x00015a2d, 0x000158f2, 0x000157bb, 0x00015686,
+ 0x00015555, 0x00015427, 0x000152fd, 0x000151d5, 0x000150b0, 0x00014f8f, 0x00014e70, 0x00014d54,
+ 0x00014c3b, 0x00014b24, 0x00014a11, 0x00014900, 0x000147f1, 0x000146e5, 0x000145dc, 0x000144d5,
+ 0x000143d1, 0x000142cf, 0x000141d0, 0x000140d3, 0x00013fd8, 0x00013ee0, 0x00013de9, 0x00013cf5,
+ 0x00013c03, 0x00013b14, 0x00013a26, 0x0001393b, 0x00013851, 0x0001376a, 0x00013684, 0x000135a1,
+ 0x000134bf, 0x000133e0, 0x00013302, 0x00013226, 0x0001314c, 0x00013074, 0x00012f9e, 0x00012ec9,
+ 0x00012df6, 0x00012d25, 0x00012c55, 0x00012b87, 0x00012abb, 0x000129f1, 0x00012928, 0x00012860,
+ 0x0001279a, 0x000126d6, 0x00012613, 0x00012552, 0x00012492, 0x000123d4, 0x00012317, 0x0001225c,
+ 0x000121a2, 0x000120e9, 0x00012032, 0x00011f7c, 0x00011ec7, 0x00011e14, 0x00011d62, 0x00011cb1,
+ 0x00011c02, 0x00011b54, 0x00011aa7, 0x000119fb, 0x00011950, 0x000118a7, 0x000117ff, 0x00011758,
+ 0x000116b3, 0x0001160e, 0x0001156b, 0x000114c8, 0x00011427, 0x00011387, 0x000112e8, 0x0001124a,
+ 0x000111ad, 0x00011111, 0x00011076, 0x00010fdc, 0x00010f44, 0x00010eac, 0x00010e15, 0x00010d7f,
+ 0x00010cea, 0x00010c56, 0x00010bc4, 0x00010b32, 0x00010aa0, 0x00010a10, 0x00010981, 0x000108f3,
+ 0x00010865, 0x000107d9, 0x0001074d, 0x000106c2, 0x00010638, 0x000105af, 0x00010527, 0x0001049f,
+ 0x00010419, 0x00010393, 0x0001030e, 0x0001028a, 0x00010206, 0x00010183, 0x00010102, 0x00010080
+ ]
+
+ # Transform the above LUT so it gets the correct quantization (following the reference)
+ ifm_scale = op.ifm.quantization.scale_f32
+ ofm_scale = op.ofm.quantization.scale_f32
+ zp_in = op.ifm.quantization.zero_point
+ zp_out = op.ofm.quantization.zero_point
+
+ # Make sure zero point is valid
+ assert (-128 - zp_in) >= 0, f"Rsqrt is only defined for positive values, zeropoint is {zp_in}"
+
+ scale = np.double(1) / np.double(np.sqrt(ifm_scale) * ofm_scale)
+ output_multiplier, output_shift = quantise_scale(scale)
+
+ # Shift modification (value used in reference but Vela has opposite sign)
+ kshift = -20
+
+ ix = range(-128, 128)
+ quantized_min = min(ix)
+ quantized_max = max(ix)
+
+ # Any value close to 0 (zero index in LUT) is mapped to the max output value
+ values = [quantized_max]
+ for x in ix:
+ if x == -128:
+ # Value already populated above
+ continue
+ x_real = x - zp_in
+ val = RSQRT_LUT[x_real]
+ val = fp_math.multiply_by_quantized_multiplier(val, output_multiplier, output_shift - kshift) + zp_out
+ lut_result = min(quantized_max, max(quantized_min, val))
+ values.append(lut_result)
+
+ return convert_to_lut(op, values, "rsqrt")
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 52f06cf..da92311 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -263,7 +263,7 @@ class Op(Enum):
ReverseV2 = OperatorInfo()
Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
Round = OperatorInfo()
- Rsqrt = OperatorInfo()
+ Rsqrt = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) # NPU specific operation
SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) # NPU specific operation
ScatterNd = OperatorInfo()
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index cedf87a..4aca00d 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -662,3 +662,15 @@ def test_lstm_support():
op.inputs[23] = None
# Test restored valid configuration
assert support.is_operator_supported(op)
+
+
+def test_rsqrt_support():
+ # Test supported op (int8)
+ op = testutil.create_elemwise_op(Op.Rsqrt, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
+ assert support.is_operator_supported(op)
+ # Test not supported op (uint8)
+ op = testutil.create_elemwise_op(Op.Rsqrt, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint8)
+ assert not support.is_operator_supported(op)
+ # Test not supported op (int16)
+ op = testutil.create_elemwise_op(Op.Rsqrt, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
+ assert not support.is_operator_supported(op)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index daaca8d..99ac24e 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -45,6 +45,7 @@ from .lstm import Lstm
from .lut import convert_to_lut
from .lut import create_lut_8bit_op
from .lut import create_lut_int16_op
+from .lut import create_lut_rsqrt_int8_op
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
@@ -2048,6 +2049,9 @@ def convert_ops_to_lut(op, arch, nng):
# Should already be catched in tflite supported ops
assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}"
+ if op.type == Op.Rsqrt:
+ return create_lut_rsqrt_int8_op(op)
+
return op
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index dda418c..83c55bb 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -793,7 +793,7 @@ builtin_operator_map = {
BuiltinOperator.LOG: (Op.Log, None, TFLITE_NO_INDICES),
BuiltinOperator.SUM: (Op.Sum, OptionsSerializer("ReducerOptions", ("keep_dims",)), TFLITE_NO_INDICES),
BuiltinOperator.SQRT: (Op.Sqrt, None, TFLITE_NO_INDICES),
- BuiltinOperator.RSQRT: (Op.Rsqrt, None, TFLITE_NO_INDICES),
+ BuiltinOperator.RSQRT: (Op.Rsqrt, None, TFLITE_IFM_INDICES),
BuiltinOperator.SHAPE: (
Op.Shape,
OptionsSerializer("ShapeOptions", (("out_type", datatype_deserialize, datatype_serialize),)),
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index b47104d..0dfdc66 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -326,6 +326,9 @@ class TFLiteSupportedOperators:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights)
+ # Rsqrt specific checks
+ self.specific_constraints[Op.Rsqrt].append(TFLiteSupportedOperators.constraint_rsqrt_input_int8)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -911,3 +914,10 @@ class TFLiteSupportedOperators:
"All input and recurrent weights must be available"
valid = None not in op.inputs[1:9]
return valid, "Op has missing weights"
+
+ @staticmethod
+ def constraint_rsqrt_input_int8(op):
+ "IFM must be int8"
+ ifm_dtype = op.ifm.dtype
+ valid = ifm_dtype == DataType.int8
+ return valid, f"Op has ifm_dtype={ifm_dtype}"