aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-03-01 09:53:35 +0100
committerJohan Alfven <johan.alfven@arm.com>2023-03-06 12:02:27 +0100
commit3ac03be07f89655debd3cd4364d4ed9b22bfa507 (patch)
tree6314828cf62dce47c138c1bfb18debfdef4cb34c /ethosu/vela
parentc60b7e3a4cda84a196801baa407a0dcc5d39832b (diff)
downloadethos-u-vela-3ac03be07f89655debd3cd4364d4ed9b22bfa507.tar.gz
MLBEDSW-7396: MLCE: Add num elements constraint on reshape
Adding constraint for faulty reshape operators. Number of elements for IFM and OFM must be the same. Change-Id: I2e31e9d1e39b5aa3a0c595032a66e14374a0719e Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela')
-rw-r--r--ethosu/vela/tflite_model_semantic.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 0c2086c3..9f53a1e6 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -27,6 +27,7 @@ from .operation import Op
from .supported_operators_util import docstring_format_args
from .supported_operators_util import list_formatter
from .tensor import check_quantized_tens_scaling_equal
+from .tensor import shape_num_elements
from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
from .tflite_mapping import optype_to_builtintype
@@ -148,6 +149,7 @@ class TFLiteSemantic:
# Ops reshaping dimensions: Reshape, Squeeze and ExpandDims
for op_type in TFLiteSemantic.reshape_ops:
self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_quant)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_elements)
# Softmax specific checks:
self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_shapes)
@@ -664,6 +666,13 @@ class TFLiteSemantic:
return False, "IFM and OFM quantisation parameters are not equal."
return True, "IFM and OFM quantisation parameters matches."
+ @staticmethod
+ def constraint_matching_in_out_elements(op):
+ "Input and output number of elements must match."
+ if shape_num_elements(op.ifm.shape) != shape_num_elements(op.ofm.shape):
+ return False, f"IFM {op.ifm.shape} and OFM {op.ofm.shape} number of elements are not equal."
+ return True, "IFM and OFM number of elements are equal."
+
def tflite_semantic_checker(nng):
semantic_checker = TFLiteSemantic()