Diffstat (limited to 'ethosu/vela/pass_packing.py')
-rw-r--r--	ethosu/vela/pass_packing.py	99
1 file changed, 35 insertions(+), 64 deletions(-)
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index f49f9813..35e1b143 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -22,6 +22,7 @@ from .nn_graph import Pass
 from .nn_graph import PassPlacement
 from .operation import create_avgpool_nop
 from .operation import NpuBlockType
+from .operation import Op
 from .tensor import TensorPurpose
 
 
@@ -40,81 +41,57 @@ class PassFlags(enum.Flag):
     PostFusingLimited = 8192
 
 
-npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead",))
+npu_pre_ops = set((Op.SplitSliceRead,))
 
 mac_main_ops = set(
     (
         # convolutions
-        "Conv2DBiasAct",
-        "Conv2D",
-        "QuantizedConv2D",
-        "Conv2DBackpropInputSwitchedBias",
+        Op.Conv2DBias,
+        Op.Conv2D,
+        Op.QuantizedConv2D,
+        Op.Conv2DBackpropInputSwitchedBias,
         # depth-wise convolutions
-        "DepthwiseConv2dBiasAct",
-        "DepthwiseConv2dNative",
-        "QuantizedDepthwiseConv2D",
+        Op.DepthwiseConv2DBias,
         # FC layers
-        "QuantizedMatMul",
-        "MatMul",
-        "FullyConnectedAct",
+        Op.QuantizedMatMul,
+        Op.MatMul,
+        Op.FullyConnected,
         # RNN/LSTM/GRU
-        "BlockLSTM",
+        Op.BlockLSTM,
         # pooling
-        "QuantizedMaxPool",
-        "QuantizedAvgPool",
-        "AvgPool",
-        "MaxPool",
-        "AvgPoolAct",
-        "MaxPoolAct",
-        "ReduceSum",
+        Op.QuantizedMaxPool,
+        Op.QuantizedAvgPool,
+        Op.AvgPool,
+        Op.MaxPool,
+        Op.ReduceSum,
         # deconvolution
-        "ResizeBilinear",
+        Op.ResizeBilinear,
     )
 )
 
-binary_elem_wise_main_ops = set(
-    (
-        # binary element-wise
-        "AddAct",
-        "MulAct",
-        "SubAct",
-        "QuantizedAdd",
-        "QuantizedSub",
-        "QuantizedMul",
-        "Mul",
-        "Add",
-        "Sub",
-        "Minimum",
-        "Maximum",
-        "SHL",
-        "SHR",
-    )
-)
+binary_elem_wise_main_ops = Op.op_set(Op.is_binary_elementwise_op)
 
-unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))  # Unary element-wise operations
+unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)  # Unary element-wise operations
 
 elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
 
-activation_ops = set(("QuantizedRelu", "QuantizedRelu1", "QuantizedRelu6", "Relu", "Relu6", "ReluN1To1"))
-npu_post_ops = activation_ops | set(
-    # Bias-add operations: Get rid of these once we have rewrites from Conv2D + BiasAdd + Activation to Conv2DBiasAct.
-    ("Mul", "Add", "QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm")
-)
+activation_ops = Op.op_set(Op.is_relu_op)
+npu_post_ops = activation_ops
 
 npu_post_fuse_limited_ops = set(
     # Set of post operators that should not be fused with main/elementwise ops
-    ("ConcatSliceWrite", "Sigmoid", "Tanh", "Quantize")
+    (Op.ConcatSliceWrite, Op.Sigmoid, Op.Tanh, Op.Quantize)
 )
 
-elem_wise_ops = elem_wise_main_ops | activation_ops | set(("Sigmoid", "Tanh"))
+elem_wise_ops = elem_wise_main_ops | activation_ops | set((Op.Sigmoid, Op.Tanh))
 
 
-quantization_ops = set(("Dequantize", "QuantizeV2", "Max", "Min"))
-cpu_ops = set(("Softmax", "QuantizedSoftmax", "LRN", "Shape", "QuantizedPad", "Pad", "AddN")) | quantization_ops
+quantization_ops = set((Op.Dequantize, Op.Max, Op.Min))
+cpu_ops = set((Op.Softmax, Op.LRN, Op.Shape, Op.Pad, Op.AddN)) | quantization_ops
 
-npu_dma_ops = set(("DMA",))
-startup_init_ops = set(("Const", "VariableV2", "Placeholder", "SubgraphInput"))
-memory_only_ops = set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims",))
+npu_dma_ops = set((Op.DMA,))
+startup_init_ops = set((Op.Const, Op.Placeholder, Op.SubgraphInput))
+memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,))
 
 
 test_sequence = [
@@ -234,10 +211,6 @@ test_sequence = [
 for (operation_set, incompatible_pack_flags, flags_to_set, flags_to_clear) in test_sequence:
     assert not flags_to_clear & flags_to_set
 
-    if operation_set is not None:
-        for op in operation_set:
-            assert len(op) > 1  # This is to avoid string literals being decomposed
-
 
 def pack_into_passes(nng, arch, verbose_packing=False):
     def visit_op(op, ignored):
@@ -254,7 +227,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
         if op.type in startup_init_ops:
             startup_list.append(op)
         else:
-            _, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
+            ofm_tensor = op.ofm
             if ofm_tensor is None:
                 ofm_tensor = op.outputs[0]
             build_pass((op,), ofm_tensor)
@@ -287,7 +260,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
                 continue
 
             reverse_ops_list.append(curr_op)
-            new_block_type = curr_op.attrs.get("npu_block_type", NpuBlockType.Default)
+            new_block_type = curr_op.type.npu_block_type
             if new_block_type != NpuBlockType.Default:
                 assert npu_block_type == NpuBlockType.Default
                 npu_block_type = new_block_type  # Only one major block type per pass
@@ -302,10 +275,8 @@ def pack_into_passes(nng, arch, verbose_packing=False):
                 PassFlags.Mac | PassFlags.ElementWise | PassFlags.Post | PassFlags.PostFusingLimited
             ):
                 assert len(curr_op.inputs) >= 1
-                if curr_op.type == "BlockLSTM":
-                    ifm_tensor = curr_op.inputs[3]
-                else:
-                    ifm_tensor = curr_op.inputs[0]
+                ifm_tensor = curr_op.ifm
+                assert ifm_tensor is not None
                 assert ifm_tensor.purpose == TensorPurpose.FeatureMap
 
             if flags_to_set & PassFlags.Dma:
@@ -377,7 +348,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
         primary_op = create_primary_op(ops_list)
         if primary_op is not None:
             visit_tensor_refcount[primary_op.inputs[0]] += 1
-            npu_block_type = primary_op.attrs["npu_block_type"]
+            npu_block_type = primary_op.type.npu_block_type
             for input_tens in primary_op.inputs:
                 if input_tens not in input_set:
                     input_set.add(input_tens)
@@ -394,7 +365,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
             for inp in primary_op.inputs:
                 if inp is None:
                     continue
-                if len(inp.ops) == 1 and inp.ops[0].type == "DMA" and inp.purpose == TensorPurpose.FeatureMap:
+                if len(inp.ops) == 1 and inp.ops[0].type == Op.DMA and inp.purpose == TensorPurpose.FeatureMap:
                     src_op = inp.ops[0]
                     if src_op in input_ops_list:
                         inp = src_op.inputs[0]
@@ -408,7 +379,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
             add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list)
 
         name = ops_list[0].name
-        non_dma_ops = [op for op in ops_list if op.type != "DMA"]
+        non_dma_ops = [op for op in ops_list if op.type != Op.DMA]
         if non_dma_ops:
             name = non_dma_ops[0].name
         ps = Pass(name, placement, is_element_wise, npu_block_type)
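
The substance of this change is the switch from free-form operator name strings to the Op enum imported from ethosu/vela/operation.py: each operator becomes an enum member that carries its own metadata (e.g. the type.npu_block_type lookups above), and groupings such as binary_elem_wise_main_ops are derived from predicates via Op.op_set instead of hand-maintained string tuples. A minimal sketch of that pattern, with illustrative member names and a single "category" field standing in for Vela's richer per-operator info:

    import enum

    class Op(enum.Enum):
        # Illustrative members only; values are (id, category) tuples so
        # that members sharing a category do not alias each other.
        Add = (0, "binary")
        Mul = (1, "binary")
        Abs = (2, "unary")
        Relu = (3, "relu")

        def __init__(self, _id, category):
            self.category = category

        def is_binary_elementwise_op(self):
            return self.category == "binary"

        def is_relu_op(self):
            return self.category == "relu"

        @classmethod
        def op_set(cls, predicate):
            # Collect every member for which the (unbound) predicate holds,
            # mirroring Op.op_set(Op.is_binary_elementwise_op) in the diff.
            return {op for op in cls if predicate(op)}

    binary_ops = Op.op_set(Op.is_binary_elementwise_op)  # {Op.Add, Op.Mul}
    activation_ops = Op.op_set(Op.is_relu_op)            # {Op.Relu}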
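The deleted guard in the test_sequence loop (assert len(op) > 1, "to avoid string literals being decomposed") defended against a slip that string-typed op sets made easy: omitting the inner tuple iterated over the characters of the name. Enum members are not iterable, so the same mistake now fails immediately and the guard is redundant. A runnable illustration (the single-member Op here is a stand-in, not Vela's):

    import enum

    Op = enum.Enum("Op", ["DMA"])

    print(set("DMA"))      # {'D', 'M', 'A'} -- three bogus one-character "ops"
    print(set(("DMA",)))   # {'DMA'}         -- what was intended
    print(set((Op.DMA,)))  # {<Op.DMA: 1>}   -- the enum equivalent
    # set(Op.DMA) raises TypeError: 'Op' object is not iterable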
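In the same spirit, op.ofm and curr_op.ifm replace both the get_ifm_ifm2_weights_ofm() unpacking and the inline BlockLSTM special case: knowledge of which input holds the feature map (inputs[3] for BlockLSTM, per the removed branch) moves behind accessors on Operation. A hedged sketch of how such properties can work; the index logic is illustrative, not Vela's actual implementation:

    import enum

    Op = enum.Enum("Op", ["Conv2DBias", "BlockLSTM"])  # stand-in for vela's Op

    class Operation:
        def __init__(self, op_type, inputs, outputs):
            self.type = op_type
            self.inputs = inputs
            self.outputs = outputs

        @property
        def ifm(self):
            # BlockLSTM keeps its feature map at inputs[3]; everything else
            # here uses inputs[0], so call sites need no special-casing.
            idx = 3 if self.type == Op.BlockLSTM else 0
            return self.inputs[idx] if len(self.inputs) > idx else None

        @property
        def ofm(self):
            return self.outputs[0] if self.outputs else None

    lstm = Operation(Op.BlockLSTM, ["seq_len", "x", "cs_prev", "fm_in"], ["fm_out"])
    assert lstm.ifm == "fm_in"  # the BlockLSTM quirk stays inside the accessor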