aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Alfvén <johan.alfven@arm.com>2022-09-12 17:44:25 +0200
committerJohan Alfvén <johan.alfven@arm.com>2022-09-12 17:50:59 +0200
commitb393251346c7a5eef1496bcc92c722ba92c73fd9 (patch)
treec9068b98fef8d77c3f351f12e7cbccea2ba7d52f
parentb1872864b20588cc84280e5449158bb2168cc58b (diff)
downloadethos-u-vela-b393251346c7a5eef1496bcc92c722ba92c73fd9.tar.gz
MLBEDSW-6863: Cleanup the constraint for concat
Removed duplicate code and moved constraint to the correct file. Signed-off-by: Johan Alfven <johan.alfven@arm.com> Change-Id: I2da3c5b88e1af351751c481217b8183b5948f0f8
-rw-r--r--ethosu/vela/tflite_model_semantic.py21
-rw-r--r--ethosu/vela/tflite_supported_operators.py46
2 files changed, 21 insertions, 46 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index abda886c..7a0e234d 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -126,6 +126,7 @@ class TFLiteSemantic:
self.specific_constraints[op_type].append(TFLiteSemantic.constraint_axis_valid)
self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_dimensionality)
self.specific_constraints[op_type].append(TFLiteSemantic.constraint_valid_dimensions)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_valid_dimensions_axis)
# Element-wise checks:
for op_type in TFLiteSemantic.elem_wise_main_ops:
@@ -447,6 +448,26 @@ class TFLiteSemantic:
return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
@staticmethod
+ def constraint_valid_dimensions_axis(op):
+ """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute"""
+ valid = True
+ extra = []
+ ofm_shape = op.ofm.shape
+ ofm_dim = len(ofm_shape)
+ axis = op.attrs["axis"]
+ axis += ofm_dim if axis < 0 else 0
+
+ sum_ifm_axis = 0
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ sum_ifm_axis += tens.shape[axis]
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+
+ valid = sum_ifm_axis == ofm_shape[axis]
+ extra = ", ".join(extra)
+ return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
+
+ @staticmethod
def constraint_stridedslice_input_count(op):
"Exactly 4 Input tensors are required"
inputs = len(op.inputs)
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 24cc26e0..be86e9a3 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -306,13 +306,6 @@ class TFLiteSupportedOperators:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_before_mean)
- # Concat specific checks:
- for op_type in (Op.Concat, Op.ConcatTFLite):
- self.specific_constraints[op_type].append(
- TFLiteSupportedOperators.constraint_concat_valid_dimensions_non_axis
- )
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_concat_valid_dimensions_axis)
-
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -874,42 +867,3 @@ class TFLiteSupportedOperators:
if next_op is not None and next_op.type == Op.Mean:
return False, ""
return True, ""
-
- @staticmethod
- def constraint_concat_valid_dimensions_non_axis(op):
- """All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"""
- valid = True
- extra = []
- ofm_shape = op.ofm.shape
- ofm_dim = len(ofm_shape)
- axis = op.attrs["axis"]
- axis += ofm_dim if axis < 0 else 0
-
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
- valid = False
- extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
-
- extra = ", ".join(extra)
- return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
-
- @staticmethod
- def constraint_concat_valid_dimensions_axis(op):
- """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute"""
- valid = True
- extra = []
- ofm_shape = op.ofm.shape
- ofm_dim = len(ofm_shape)
- axis = op.attrs["axis"]
- axis += ofm_dim if axis < 0 else 0
-
- sum_ifm_axis = 0
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- sum_ifm_axis += tens.shape[axis]
- extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
-
- valid = sum_ifm_axis == ofm_shape[axis]
- extra = ", ".join(extra)
- return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"