diff options
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index 3ac78b25..62648914 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -110,6 +110,10 @@ class TFLiteSemantic: # Conv-like checks: for op_type in TFLiteSemantic.convolution_like_ops: self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type) + if op_type in TFLiteSemantic.convolution_ops: + # Only Conv has groups + self.specific_constraints[op_type].append(TFLiteSemantic.constraint_conv_groups_ifm_depth) + self.specific_constraints[op_type].append(TFLiteSemantic.constraint_conv_groups_num_filters) if op_type not in TFLiteSemantic.transpose_convolution_ops: # Transpose Conv does not contain dilation self.specific_constraints[op_type].append(TFLiteSemantic.constraint_dilation_type) @@ -373,6 +377,36 @@ class TFLiteSemantic: return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}" @staticmethod + def constraint_conv_groups_ifm_depth(op): + """IFM depth must be a whole multiple of the filter kernel depth""" + ifm_depth = op.ifm.shape[-1] # nhwc + kernel_ic = op.weights.shape[-2] # hwio + num_conv_groups = ifm_depth // kernel_ic + + if ifm_depth % kernel_ic == 0: + op.attrs["num_conv_groups"] = num_conv_groups + valid = True + else: + valid = False + + return valid, f"IFM depth = {ifm_depth} and filter kernel depth = {kernel_ic}" + + @staticmethod + def constraint_conv_groups_num_filters(op): + """Number of filter kernels must be equally divisible by the number of convolution groups""" + ifm_depth = op.ifm.shape[-1] # nhwc + kernel_ic = op.weights.shape[-2] # hwio + kernel_oc = op.weights.shape[-1] # hwio + num_conv_groups = ifm_depth // kernel_ic + + if kernel_oc % num_conv_groups == 0: + valid = True + else: + valid = False + + return valid, f"Filter kernels = {kernel_oc} and convolution groups = {num_conv_groups}" + + @staticmethod def constraint_dilation_type(op): "Dilation factor values for both width and height must be integer types" w, h = op.get_kernel_dilation() |