diff options
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index 56dce14f..3ac78b25 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -702,7 +702,7 @@ class TFLiteSemantic: When IFM tensor is 3D or 4D: - Reduction in Batch axis is only supported if batch size is 1. - Reduction in both Height and Width axes is supported. - - Reduction in Depth axis is only supported if depth is 1.""" + - Reduction in Depth axis is supported if at least one of H,W,C are of size 1.""" input_shape = op.inputs[0].shape dims = len(input_shape) if op.inputs[1].shape == []: @@ -714,14 +714,22 @@ class TFLiteSemantic: for ax in axis: if ax < 0 or ax >= dims: return False, "Axis parameter is out of bounds. axis: {axis}, dims: {dims}. " - elif dims == 3: - # depth is only supported if size is 1 - if ax == 2 and input_shape[ax] != 1: + + # Batch is only supported if batch shape is 1 + if dims == 4 and ax == 0: + if input_shape[0] != 1: valid = False break - else: # 4D - # batch and depth are only supported if sizes are 1 - if ax in [0, 3] and input_shape[ax] != 1: + + # Depth is supported if any of h,w,c == 1 + if dims == 3: + if ax == 2 and not any([s == 1 for s in input_shape]): + valid = False + break + + # Depth is supported if any of h,w,c == 1 + if dims == 4: + if ax == 3 and not any([s == 1 for s in input_shape[1:]]): valid = False break |