aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py38
1 files changed, 30 insertions, 8 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 444c04ad..56dce14f 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -696,14 +696,36 @@ class TFLiteSemantic:
@staticmethod
def constraint_mean_axis(op):
- "Axis indices must correspond to height and width axes"
- dims = len(op.inputs[0].shape)
- axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
- if dims == 2 or dims == 3:
- valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
- elif dims == 4:
- valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
- return valid, f"Axis is {axis}"
+ """Requirements for axis parameter:
+ When IFM tensor is 2D:
+ - Reduction in both axes is supported.
+ When IFM tensor is 3D or 4D:
+ - Reduction in Batch axis is only supported if batch size is 1.
+ - Reduction in both Height and Width axes is supported.
+ - Reduction in Depth axis is only supported if depth is 1."""
+ input_shape = op.inputs[0].shape
+ dims = len(input_shape)
+ if op.inputs[1].shape == []:
+ axis = [int(op.inputs[1].values)]
+ else:
+ axis = list(op.inputs[1].values)
+ valid = True
+
+ for ax in axis:
+ if ax < 0 or ax >= dims:
+ return False, "Axis parameter is out of bounds. axis: {axis}, dims: {dims}. "
+ elif dims == 3:
+ # depth is only supported if size is 1
+ if ax == 2 and input_shape[ax] != 1:
+ valid = False
+ break
+ else: # 4D
+ # batch and depth are only supported if sizes are 1
+ if ax in [0, 3] and input_shape[ax] != 1:
+ valid = False
+ break
+
+ return valid, f"Shape is {input_shape}, Axis is {axis}."
@staticmethod
def constraint_matching_in_out_quant(op):