From df99510f04aef99d1b8e9be9bfcde8fc1738b65f Mon Sep 17 00:00:00 2001
From: Patrik Gustavsson
Date: Mon, 23 Aug 2021 15:33:59 +0200
Subject: TOSA: Added Depthwise support

This is mainly to add support for depthwise conv2d with
depth_multiplier = 1. (But there are no suitable test cases; all I have
sourced have depth_multiplier set to 2, which is not supported.)

-Added support for depthwise conv2d.
-Added support for removing Transpose of constant data
-Added support for removing reshape

Signed-off-by: Patrik Gustavsson
Change-Id: I143e6246becfa78fd9f7510af0bf0d6b3fbbf2c7
---
 ethosu/vela/graph_optimiser_util.py | 114 +++++++++++++++++++++++++++++++-
 ethosu/vela/operation.py | 2 +-
 ethosu/vela/tflite_graph_optimiser.py | 86 ++----------------------
 ethosu/vela/tosa_graph_optimiser.py | 60 ++++++++++++++++-
 ethosu/vela/tosa_mapping.py | 6 +-
 ethosu/vela/tosa_reader.py | 19 +++++-
 ethosu/vela/tosa_supported_operators.py | 41 ++++++++++--
 ethosu/vela/vela.py | 2 +-
 8 files changed, 235 insertions(+), 95 deletions(-)

diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 570c7244..d01d4a19 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -17,14 +17,17 @@
 # Common functions and definitions used during the graph optimization.
 from typing import Tuple
 
+import numpy as np
+
 from .data_type import DataType
 from .debug_database import DebugDatabase
+from .errors import UnsupportedFeatureError
 from .errors import VelaError
 from .operation import Op
+from .operation_util import create_avgpool_nop
 from .shape4d import Shape4D
 from .tensor import check_quantized_tens_scaling_equal
 
-
 memory_only_ops = (
     Op.Reshape,
     Op.Squeeze,
@@ -174,6 +177,41 @@ def set_ifm_ofm_op_shapes(op, arch, nng):
     return op
 
 
+def bypass_reshape_and_squeeze_ops(op):
+    assert op.type in (Op.Reshape, Op.Squeeze)
+    ofm = op.ofm
+    ifm = op.ifm
+    # Check if ifm/ofm are network ifm/ofm
+    ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
+    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
+    # Check if ifm/ofm is produced respectively consumed by CPU
+    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+
+    # This case should be handled prior to this function
+    assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
+
+    if ofm_is_sg_ofm or ofm_is_cpu_consumed:
+        # Bypassed by replacing ifm with ofm
+        ofm.ops = []
+        for prev_op in ifm.ops:
+            prev_op.outputs = [ofm]
+            ofm.ops.append(prev_op)
+
+        # All ifm consumers need to use ofm as input
+        for ifm_cons in ifm.consumer_list:
+            for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
+                if cons_ifm == ifm:
+                    ifm_cons.set_input_tensor(ofm, ifm_idx)
+    else:
+        # Bypassed by replacing ofm with ifm
+        for cons in ofm.consumer_list:
+            for ifm_idx, cons_ifm in enumerate(cons.inputs):
+                if cons_ifm == ofm:
+                    cons.set_input_tensor(ifm, ifm_idx)
+
+
 def check_reshapes(op, arch):
     if op.run_on_npu and op.type == Op.Reshape:
         ofm = op.ofm
@@ -186,3 +224,77 @@ def check_reshapes(op, arch):
 def record_optimised(op, arch):
     if op.type != Op.Const:
         DebugDatabase.add_optimised(op, op)
+
+
+def insert_copy_op_after_tens(tens):
+    tens_cons_list_copy = tens.consumer_list.copy()
+
+    # Create a avg_pool
nop op with ifm as input + copy_tens = tens.clone() + copy_op = create_avgpool_nop(tens.name + "_avgpool") + copy_op.add_input_tensor(tens) + copy_op.set_output_tensor(copy_tens) + copy_op.set_ifm_ofm_shapes() + copy_op.run_on_npu = True + + # Set copy_ifm consumers + for tens_cons in tens_cons_list_copy: + if tens_cons is not None: + for ifm_idx, cons_inp in enumerate(tens_cons.inputs): + if cons_inp == tens: + tens_cons.set_input_tensor(copy_tens, ifm_idx) + + DebugDatabase.add_optimised(tens.ops[0], copy_op) + + +def fix_sg_input_output(op, arch, nng): + if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze): + return op + + # For the Reshape/Squeeze operators we want to remove, tensors are removed. + # But in order to to do this, they cannot be outputs of the sg, + # this need to be fixed prior to the removal. + # Solution is to add a avgpool NOP, to maintain the original tensor. + # This is also valid when reshape ifm/ofm is produced respectively + # consumed by CPU + + # Check if operator ifm/ofm are sg ifm/ofm + ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) + ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list) + ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list) + # Check if ifm/ofm is produced respectively consumed by CPU + ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops) + ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list) + + if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed): + # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape/Squeeze + insert_copy_op_after_tens(op.ifm) + + return op + + +def convert_depthwise_to_conv(op, arch, nng): + # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and + # the ofm depth equals the depth multipler. 
+ # If those conditions are true, then we can perform a simple + # switch of the operator type (and weight order) + + if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1): + ifm_shape = op.ifm_shapes[0] + weight_tensor = op.inputs[1] + ofm_shape = op.ofm_shapes[0] + if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]): + # Change op type to Conv2d + op.type = Op.Conv2DBias + del op.attrs["channel_multiplier"] + del op.attrs["depth_multiplier"] + + weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2)) + weight_tensor.set_all_shapes(list(weight_tensor.values.shape)) + else: + raise UnsupportedFeatureError( + f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},", + f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}", + ) + DebugDatabase.add_optimised(op, op) + return op diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index ca833e21..80be228b 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -283,7 +283,7 @@ class Op(Enum): Tanh = OperatorInfo(indices=NNG_IFM_INDICES) Tile = OperatorInfo() TopKV2 = OperatorInfo() - Transpose = OperatorInfo() + Transpose = OperatorInfo(indices=NNG_IFM_INDICES) UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES) UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES) Unique = OperatorInfo() diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 255a1f5e..ef39aea3 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -30,7 +30,10 @@ from .data_type import DataType from .debug_database import DebugDatabase from .errors import UnsupportedFeatureError from .ethos_u55_regs.ethos_u55_regs import resampling_mode +from .graph_optimiser_util import bypass_reshape_and_squeeze_ops from .graph_optimiser_util import calc_explicit_padding +from .graph_optimiser_util import convert_depthwise_to_conv +from .graph_optimiser_util import fix_sg_input_output from .graph_optimiser_util import needed_total_padding from .graph_optimiser_util import set_ifm_ofm_op_shapes from .graph_optimiser_util import set_tensor_equivalence @@ -244,32 +247,6 @@ def insert_copy_op_after_tens(tens): DebugDatabase.add_optimised(tens.ops[0], copy_op) -def fix_sg_input_output(op, arch, nng): - if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze): - return op - - # For the Reshape/Squeeze operators we want to remove, tensors are removed. - # But in order to to do this, they cannot be outputs of the sg, - # this need to be fixed prior to the removal. - # Solution is to add a avgpool NOP, to maintain the original tensor. 
- # This is also valid when reshape ifm/ofm is produced respectively - # consumed by CPU - - # Check if operator ifm/ofm are sg ifm/ofm - ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) - ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list) - ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list) - # Check if ifm/ofm is produced respectively consumed by CPU - ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops) - ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list) - - if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed): - # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape/Squeeze - insert_copy_op_after_tens(op.ifm) - - return op - - def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding): k_w, k_h = kernel.dilated_wh() s_x, s_y = kernel.stride @@ -576,33 +553,6 @@ def add_padding_fields(op, arch, nng): return op -def convert_depthwise_to_conv(op, arch, nng): - # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and - # the ofm depth equals the depth multipler. - # If those conditions are true, then we can perform a simple - # switch of the operator type (and weight order) - - if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1): - ifm_shape = op.ifm_shapes[0] - weight_tensor = op.inputs[1] - ofm_shape = op.ofm_shapes[0] - if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]): - # Change op type to Conv2d - op.type = Op.Conv2DBias - del op.attrs["channel_multiplier"] - del op.attrs["depth_multiplier"] - - weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2)) - weight_tensor.set_all_shapes(list(weight_tensor.values.shape)) - else: - raise UnsupportedFeatureError( - f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},", - f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}", - ) - DebugDatabase.add_optimised(op, op) - return op - - def reorder_depthwise_weights(op, arch, nng): if op.type.is_depthwise_conv2d_op(): weight_tensor = op.inputs[1] @@ -1058,35 +1008,7 @@ def remove_reshape_and_squeeze_ops(op, arch): # or the reshape need to be replace with a NOP. 
return - # Check if ifm/ofm are network ifm/ofm - ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) - ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list) - ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list) - # Check if ifm/ofm is produced respectively consumed by CPU - ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops) - ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list) - - # This case should be handled prior to this function - assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed)) - - if ofm_is_sg_ofm or ofm_is_cpu_consumed: - # Bypassed by replacing ifm with ofm - ofm.ops = [] - for prev_op in ifm.ops: - prev_op.outputs = [ofm] - ofm.ops.append(prev_op) - - # All ifm consumers need to use ofm as input - for ifm_cons in ifm.consumer_list: - for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs): - if cons_ifm == ifm: - ifm_cons.set_input_tensor(ofm, ifm_idx) - else: - # Bypassed by replacing ofm with ifm - for cons in ofm.consumer_list: - for ifm_idx, cons_ifm in enumerate(cons.inputs): - if cons_ifm == ofm: - cons.set_input_tensor(ifm, ifm_idx) + bypass_reshape_and_squeeze_ops(op) def fuse_activation_function_with_prev(op, arch, nng): diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 44e0f8ec..169da40d 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -19,7 +19,10 @@ from . import rewrite_graph from .api import NpuRoundingMode from .data_type import DataType from .debug_database import DebugDatabase +from .graph_optimiser_util import bypass_reshape_and_squeeze_ops from .graph_optimiser_util import calc_explicit_padding +from .graph_optimiser_util import convert_depthwise_to_conv +from .graph_optimiser_util import fix_sg_input_output from .graph_optimiser_util import needed_total_padding from .graph_optimiser_util import set_ifm_ofm_op_shapes from .graph_optimiser_util import set_tensor_equivalence @@ -74,6 +77,43 @@ def add_padding_fields(op, arch, nng): return op +def remove_const_transpose(op, arch, nng): + if op.type == Op.Transpose: + removed = False + if len(op.ifm.ops) == 1: + prev_op = op.ifm.ops[0] + if prev_op.type == Op.Const: + # Transpose the Tensor and data and remove Transpose + # TODO move to Tensor? 
+ reorder = op.attrs["perms"] + shape = op.ifm.shape.copy() + tens = op.ifm + + tens.shape = [shape[idx] for idx in reorder] + tens.bandwidth_shape = tens.shape + tens.storage_shape = tens.shape + + if tens.values is not None: + tens.values = tens.values.transpose(reorder) + + op.ofm.values = tens.values + # Bypass the Transpose op + prev_op.set_output_tensor(op.ofm) + DebugDatabase.add_optimised(op, prev_op) + removed = True + + if not removed: + print("Cannot remove Transpose, and handling of Transpose is not supported") + assert False + + return op + + +def remove_reshapes(op, arch): + if op.run_on_npu and op.type == Op.Reshape: + bypass_reshape_and_squeeze_ops(op) + + def rewrite_activation(op, arch, nng): if op.type not in (Op.ReluN, Op.Clamp): return op @@ -206,6 +246,7 @@ def fixup_quantization(op, arch, nng): def supported_operator_check(op, arch, nng): op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op) + assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const) return op @@ -221,8 +262,25 @@ def tosa_optimise_graph(nng, arch): nng, sg, arch, [], pre_process_list, rewrite_unsupported=False, ) + # Removal of Transpose + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False, + ) + + # Handle sg input output + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False, + ) + + # Removal of reshapes + for sg in nng.subgraphs: + rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes]) + sg.refresh_after_modification() + # Rewite Operators step - op_rewrite_list = [set_tensor_equivalence, rewrite_rescale] + op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv] for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py index 75ca43ef..5d0dd33d 100644 --- a/ethosu/vela/tosa_mapping.py +++ b/ethosu/vela/tosa_mapping.py @@ -221,11 +221,9 @@ unsupported_tosa_operators = { TosaOp.REDUCE_SUM, TosaOp.CONCAT, TosaOp.PAD, - TosaOp.RESHAPE, TosaOp.REVERSE, TosaOp.SLICE, TosaOp.TILE, - TosaOp.TRANSPOSE, TosaOp.GATHER, TosaOp.SCATTER, TosaOp.RESIZE, @@ -304,11 +302,11 @@ tosa_operator_map = { # TODO TosaOp.REDUCE_SUM # TODO TosaOp.CONCAT # TODO TosaOp.PAD - # TODO TosaOp.RESHAPE + TosaOp.RESHAPE: (Op.Reshape, reshape_attrs, None, TOSA_IFM_INDICES), # TODO TosaOp.REVERSE # TODO TosaOp.SLICE # TODO TosaOp.TILE - # TODO TosaOp.TRANSPOSE + TosaOp.TRANSPOSE: (Op.Transpose, None, None, TOSA_IFM_IFM2_INDICES), # TODO TosaOp.GATHER # TODO TosaOp.SCATTER # TODO TosaOp.RESIZE diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index eb317169..268d43ce 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -33,9 +33,11 @@ from .tensor import QuantizationParameters from .tensor import shape_num_elements from .tensor import Tensor from .tflite_mapping import DataType +from .tosa.Op import Op as TosaOp from .tosa.TosaGraph import TosaGraph as TG from .tosa_mapping import datatype_map from .tosa_mapping import datatype_map_numpy +from .tosa_mapping import TOSA_IFM_INDICES from .tosa_mapping import tosa_operator_map from .tosa_mapping import unsupported_tosa_operators @@ -89,7 +91,7 @@ class TosaSubgraph: op_code = op_data.Op() if 
op_code in unsupported_tosa_operators: print("Unsupported Operator", op_code) - assert False + return op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code] inputs = [] @@ -104,6 +106,15 @@ class TosaSubgraph: outputs.append(output_tens) assert output_tens is not None + # Permutation attribute for TRANSPOSE is an input tensor in TOSA + # TODO In order to optimise Depthwise spawning from TFLite Support for removing + # Transpose of constant data. + # Moving permutation to an attribute, to match internal graph representation for now + perms = None + if op_code == TosaOp.TRANSPOSE: + perms = perms = inputs.pop(1) + indices = TOSA_IFM_INDICES + name = "unknown_op_name" if len(outputs): name = outputs[0].name @@ -148,6 +159,7 @@ class TosaSubgraph: stride = op.attrs["stride"] if len(stride) == 2: op.attrs["strides"] = (1, stride[0], stride[1], 1) + del op.attrs["stride"] else: # TODO CONV3D more to be done.... print("Unsupported kernel dimensions: ", len(stride)) @@ -167,6 +179,11 @@ class TosaSubgraph: # TODO CONV3D more to be done.... print("Unsupported kernel dimensions: ", len(kernel)) assert False + if op.type.is_depthwise_conv2d_op(): + op.attrs["depth_multiplier"] = op.weights.shape[3] + + elif op.type == Op.Transpose: + op.attrs["perms"] = perms.values if quant_serializer is not None: quant_info = quant_serializer.deserialize(op_data) diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py index 3b0e6b39..90d54687 100644 --- a/ethosu/vela/tosa_supported_operators.py +++ b/ethosu/vela/tosa_supported_operators.py @@ -28,19 +28,24 @@ class TosaSupportedOperators: # TODO currently sparsely populated # Categorised lists of supported operators convolution_ops = set((Op.Conv2DBias,)) - convolution_like_ops = convolution_ops + depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,)) + convolution_like_ops = convolution_ops | depthwise_convolution_ops + + # TODO depending on what will be committed max_pooling_ops = Op.op_set(Op.is_maxpool_op) avg_pooling_ops = Op.op_set(Op.is_avgpool_op) - pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops + pooling_ops = max_pooling_ops | avg_pooling_ops + fc_vector_products = set((Op.FullyConnected,)) - mac_main_ops = convolution_like_ops | pooling_ops + mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products + memory_only_ops = set((Op.Reshape, Op.Transpose,)) type_conversion_ops = set((Op.Rescale,)) relu_ops = set((Op.Clamp, Op.ReluN,)) activation_ops = relu_ops npu_post_ops = activation_ops - supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops + supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops # Supported data types # TODO will differ compared to TensorFlow Lite, currently set to the same @@ -54,6 +59,12 @@ class TosaSupportedOperators: # Setup specific constraints. 
Note: the order matters self.specific_constraints = defaultdict(list) + self.specific_constraints[Op.Transpose].append(TosaSupportedOperators.constraint_ifm_producer) + + # Depthwise Conv specific checks: + for op_type in TosaSupportedOperators.depthwise_convolution_ops: + self.specific_constraints[op_type].append(TosaSupportedOperators.constraint_depth_multiplier) + def is_operator_supported(self, op): ext_type = optype_to_tosa_op_type(op.type) if op.type not in TosaSupportedOperators.supported_operators: @@ -87,3 +98,25 @@ class TosaSupportedOperators: valid = False extra.append(f"Tensor '{tens.name}' has data type: {tens.dtype}") return valid, ", ".join(extra) + + @staticmethod + def constraint_ifm_producer(cls, op): + "Input must be constant data" + valid = op.ifm.ops and op.ifm.ops[0].type == Op.Const + return valid, "Op has ifm with non-constant data" + + # TODO duplicates tflite_supported operators, but support for depth multiplier should be added at a later stage + @staticmethod + def constraint_depth_multiplier(op): + "For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier" + depth_multiplier = op.attrs.get("depth_multiplier", 1) + if depth_multiplier > 1: + ifm_channels = op.ifm.shape[3] + ofm_channels = op.ofm.shape[3] + valid = (ifm_channels == 1) and (ofm_channels == depth_multiplier) + extra = ( + f"Op has ifm_channels={ifm_channels}, ofm_channels={ofm_channels}" + f" and depth_multiplier={depth_multiplier}" + ) + return valid, extra + return True, "Op has depth_multiplier=1" diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py index 94487499..6c9fbce2 100644 --- a/ethosu/vela/vela.py +++ b/ethosu/vela/vela.py @@ -85,7 +85,7 @@ def process(input_name, enable_debug_db, arch, model_reader_options, compiler_op ) output_tfl_filename = output_basename + "_vela.tflite" - if input_name.endswith(".tflite"): + if input_name.endswith(".tflite") or input_name.endswith(".tosa"): tflite_writer.write_tflite(nng, output_tfl_filename) if input_name.endswith(".tosa"): rawdata_writer.write_rawdata_output(nng, arch, output_basename) -- cgit v1.2.1