diff options
author | Johan Gunnarsson <johan.gunnarsson@arm.com> | 2023-08-04 17:16:29 +0200 |
---|---|---|
committer | Rickard Bolin <rickard.bolin@arm.com> | 2023-08-09 11:21:24 +0000 |
commit | 81b765df02d7c7cae5f1084eec998824b68c00ab (patch) | |
tree | b7dc181f531fccc4106787c68fefe07e7b82b4c7 /ethosu/vela/tflite_model_semantic.py | |
parent | cd03504cfc29767d33d37b5c587116ab90752d74 (diff) | |
download | ethos-u-vela-81b765df02d7c7cae5f1084eec998824b68c00ab.tar.gz |
MLBEDSW-7626: Add constraint for PAD op paddings
PAD input tensor shape plus paddings must equal output tensor shape.
Change-Id: Icc5dea9bf6a8f6e1c8402f4d9af4d9796e8ef1aa
Signed-off-by: Johan Gunnarsson <johan.gunnarsson@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index 62648914..ea7ef4a3 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -184,6 +184,7 @@ class TFLiteSemantic: # Pad specific checks: self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_input_count) self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_constant) + self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_output_shape) # HardSwish specific checks: self.specific_constraints[Op.HardSwish].append(TFLiteSemantic.constraint_input_8bit) @@ -586,6 +587,16 @@ class TFLiteSemantic: return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}" @staticmethod + def constraint_pad_output_shape(op): + "Shape of output tensor must equal to size of input tensor plus padding" + input_shape = op.inputs[0].shape + expected_output_shape = op.outputs[0].shape + pad_tensor = op.inputs[1].values + actual_output_shape = input_shape + pad_tensor.T[0] + pad_tensor.T[1] + valid = np.array_equal(actual_output_shape, expected_output_shape) + return valid, f"Op has wrong output tensor shape: {expected_output_shape}, has shape: {actual_output_shape}" + + @staticmethod def constraint_stridedslice_inputs_const(op): "Begin, End and Stride Input tensors must be constant" valid = True |