aboutsummaryrefslogtreecommitdiff
path: root/ethosu
diff options
context:
space:
mode:
authorJohan Gunnarsson <johan.gunnarsson@arm.com>2023-08-04 17:16:29 +0200
committerRickard Bolin <rickard.bolin@arm.com>2023-08-09 11:21:24 +0000
commit81b765df02d7c7cae5f1084eec998824b68c00ab (patch)
treeb7dc181f531fccc4106787c68fefe07e7b82b4c7 /ethosu
parentcd03504cfc29767d33d37b5c587116ab90752d74 (diff)
downloadethos-u-vela-81b765df02d7c7cae5f1084eec998824b68c00ab.tar.gz
MLBEDSW-7626: Add constraint for PAD op paddings
PAD input tensor shape plus paddings must equal output tensor shape. Change-Id: Icc5dea9bf6a8f6e1c8402f4d9af4d9796e8ef1aa Signed-off-by: Johan Gunnarsson <johan.gunnarsson@arm.com>
Diffstat (limited to 'ethosu')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py12
-rw-r--r--ethosu/vela/tflite_model_semantic.py11
2 files changed, 23 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index e7fd3073..7ca1bbda 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -356,6 +356,18 @@ def test_constraint_pad_input_count():
assert not semantic_checker.is_operator_semantic_valid(op)
+def test_constraint_pad_output_shape():
+ # Incorrect output tensor shape
+ op = create_pad_op(
+ in_shape=[1, 1, 1, 1],
+ out_shape=[1, 3, 3, 1],
+ padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+ )
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.outputs[0].shape = [1, 1, 1, 1]
+ assert not semantic_checker.is_operator_semantic_valid(op)
+
+
def create_strided_slice():
# Creates a valid strided slice operator with some valid inputs/outputs
op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 62648914..ea7ef4a3 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -184,6 +184,7 @@ class TFLiteSemantic:
# Pad specific checks:
self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_input_count)
self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_constant)
+ self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_output_shape)
# HardSwish specific checks:
self.specific_constraints[Op.HardSwish].append(TFLiteSemantic.constraint_input_8bit)
@@ -586,6 +587,16 @@ class TFLiteSemantic:
return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"
@staticmethod
+ def constraint_pad_output_shape(op):
+ "Shape of output tensor must equal to size of input tensor plus padding"
+ input_shape = op.inputs[0].shape
+ expected_output_shape = op.outputs[0].shape
+ pad_tensor = op.inputs[1].values
+ actual_output_shape = input_shape + pad_tensor.T[0] + pad_tensor.T[1]
+ valid = np.array_equal(actual_output_shape, expected_output_shape)
+ return valid, f"Op has wrong output tensor shape: {expected_output_shape}, has shape: {actual_output_shape}"
+
+ @staticmethod
def constraint_stridedslice_inputs_const(op):
"Begin, End and Stride Input tensors must be constant"
valid = True