aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py55
1 files changed, 49 insertions, 6 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 1915d43b..f01a6690 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -304,6 +304,7 @@ class TFLiteSupportedOperators:
# Reshape specific checks:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
+ self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_before_mean)
# Concat specific checks:
for op_type in (Op.Concat, Op.ConcatTFLite):
@@ -795,10 +796,9 @@ class TFLiteSupportedOperators:
max_prod = cls.mean_kernel_product
return h * w <= max_prod, f"Product of height and width is {h * w}"
- @classmethod
- @docstring_format_args([mean_kernel_product_int8])
- def constraint_mean_height_width_product_int8(cls, op):
- """Product of IFM height and width must be no greater than {} when:
+ @staticmethod
+ def constraint_mean_height_width_product_int8(op):
+ """Number of IFM height and width elements might cause accumulator saturation when;
The IFM shape has 4 dimensions; and
The axis indices specify reduction across 2 dimensions; and
The axis indices correspond to the width and height dimensions of the IFM; and
@@ -817,8 +817,43 @@ class TFLiteSupportedOperators:
return True, ""
h = shape[-3]
w = shape[-2]
- max_prod = cls.mean_kernel_product_int8
- return h * w <= max_prod, f"Product of height and width is {h * w}"
+
+ ifmq, ofmq = op.ifm.quantization, op.ofm.quantization
+
+ # Scale factor
+ real_scale = ifmq.scale_f32 / ofmq.scale_f32
+
+ # Min and max value
+ ifm_min_val = np.iinfo(np.int8).min - ifmq.zero_point
+ ifm_max_val = np.iinfo(np.int8).max - ifmq.zero_point
+
+ # Accumulator limits
+ min_acc_limit = np.iinfo(np.int16).min
+ max_acc_limit = np.iinfo(np.int16).max
+
+ # Theoretical max/min values that the accumulator needs to store
+ min_acc_sum = h * w * ifm_min_val * real_scale + ofmq.zero_point
+ max_acc_sum = h * w * ifm_max_val * real_scale + ofmq.zero_point
+
+ # Max product of height and width that will not saturate the accumulator
+ ifm_min_val = 1 if ifm_min_val == 0 else ifm_min_val
+ ifm_max_val = 1 if ifm_max_val == 0 else ifm_max_val
+ if max_acc_sum > abs(min_acc_sum):
+ max_hw = int((max_acc_limit - ofmq.zero_point) / real_scale / ifm_max_val)
+ else:
+ max_hw = int((min_acc_limit - ofmq.zero_point) / real_scale / ifm_min_val)
+
+ extra = []
+
+ extra.append(f" Possible accumulator range is ({min_acc_sum} - {max_acc_sum})\n")
+ extra.append(f" Maximum accumulator range is ({min_acc_limit} - {max_acc_limit})\n")
+ extra.append(
+ f" Based on the IFM and OFM quantization the IFM height and width must be no greater than {max_hw}"
+ )
+
+ extra = "".join(extra)
+
+ return (min_acc_sum >= min_acc_limit and max_acc_sum <= max_acc_limit, f"\n{extra}")
@classmethod
@docstring_format_args([filter_height_range[1], dilated_height_range[1]])
@@ -867,6 +902,14 @@ class TFLiteSupportedOperators:
return valid, f"Op has non-const input(s): {extra}"
@staticmethod
+ def constraint_reshape_before_mean(op):
+ "Reshape on NPU not supported before MEAN operator"
+ for next_op in op.outputs[0].consumers():
+ if next_op is not None and next_op.type == Op.Mean:
+ return False, ""
+ return True, ""
+
+ @staticmethod
def constraint_concat_valid_dimensions_non_axis(op):
"""All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"""
valid = True