Diffstat (limited to 'ethosu/vela/pass_packing.py')
-rw-r--r--  ethosu/vela/pass_packing.py | 99
1 file changed, 35 insertions(+), 64 deletions(-)
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index f49f9813..35e1b143 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -22,6 +22,7 @@ from .nn_graph import Pass
from .nn_graph import PassPlacement
from .operation import create_avgpool_nop
from .operation import NpuBlockType
+from .operation import Op
from .tensor import TensorPurpose
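The hunks below lean on this new Op import: op types become enum members instead of raw name strings. A minimal sketch of the idea, assuming Op is a standard Python enum (the real class in .operation carries many more members plus per-op metadata):

    import enum

    class Op(enum.Enum):
        Conv2D = enum.auto()
        MaxPool = enum.auto()
        DMA = enum.auto()
        Sigmoid = enum.auto()

    # Membership tests read the same as before, but a misspelled member
    # is an AttributeError at import time rather than a string that
    # silently never matches.
    mac_like = {Op.Conv2D, Op.MaxPool}
    assert Op.Conv2D in mac_like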
@@ -40,81 +41,57 @@ class PassFlags(enum.Flag):
PostFusingLimited = 8192
-npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead",))
+npu_pre_ops = set((Op.SplitSliceRead,))
mac_main_ops = set(
(
# convolutions
- "Conv2DBiasAct",
- "Conv2D",
- "QuantizedConv2D",
- "Conv2DBackpropInputSwitchedBias",
+ Op.Conv2DBias,
+ Op.Conv2D,
+ Op.QuantizedConv2D,
+ Op.Conv2DBackpropInputSwitchedBias,
# depth-wise convolutions
- "DepthwiseConv2dBiasAct",
- "DepthwiseConv2dNative",
- "QuantizedDepthwiseConv2D",
+ Op.DepthwiseConv2DBias,
# FC layers
- "QuantizedMatMul",
- "MatMul",
- "FullyConnectedAct",
+ Op.QuantizedMatMul,
+ Op.MatMul,
+ Op.FullyConnected,
# RNN/LSTM/GRU
- "BlockLSTM",
+ Op.BlockLSTM,
# pooling
- "QuantizedMaxPool",
- "QuantizedAvgPool",
- "AvgPool",
- "MaxPool",
- "AvgPoolAct",
- "MaxPoolAct",
- "ReduceSum",
+ Op.QuantizedMaxPool,
+ Op.QuantizedAvgPool,
+ Op.AvgPool,
+ Op.MaxPool,
+ Op.ReduceSum,
# deconvolution
- "ResizeBilinear",
+ Op.ResizeBilinear,
)
)
-binary_elem_wise_main_ops = set(
- (
- # binary element-wise
- "AddAct",
- "MulAct",
- "SubAct",
- "QuantizedAdd",
- "QuantizedSub",
- "QuantizedMul",
- "Mul",
- "Add",
- "Sub",
- "Minimum",
- "Maximum",
- "SHL",
- "SHR",
- )
-)
+binary_elem_wise_main_ops = Op.op_set(Op.is_binary_elementwise_op)
-unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",)) # Unary element-wise operations
+unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op) # Unary element-wise operations
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
-activation_ops = set(("QuantizedRelu", "QuantizedRelu1", "QuantizedRelu6", "Relu", "Relu6", "ReluN1To1"))
-npu_post_ops = activation_ops | set(
- # Bias-add operations: Get rid of these once we have rewrites from Conv2D + BiasAdd + Activation to Conv2DBiasAct.
- ("Mul", "Add", "QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm")
-)
+activation_ops = Op.op_set(Op.is_relu_op)
+npu_post_ops = activation_ops
npu_post_fuse_limited_ops = set(
# Set of post operators that should not be fused with main/elementwise ops
- ("ConcatSliceWrite", "Sigmoid", "Tanh", "Quantize")
+ (Op.ConcatSliceWrite, Op.Sigmoid, Op.Tanh, Op.Quantize)
)
-elem_wise_ops = elem_wise_main_ops | activation_ops | set(("Sigmoid", "Tanh"))
+elem_wise_ops = elem_wise_main_ops | activation_ops | set((Op.Sigmoid, Op.Tanh))
-quantization_ops = set(("Dequantize", "QuantizeV2", "Max", "Min"))
-cpu_ops = set(("Softmax", "QuantizedSoftmax", "LRN", "Shape", "QuantizedPad", "Pad", "AddN")) | quantization_ops
+quantization_ops = set((Op.Dequantize, Op.Max, Op.Min))
+cpu_ops = set((Op.Softmax, Op.LRN, Op.Shape, Op.Pad, Op.AddN)) | quantization_ops
-npu_dma_ops = set(("DMA",))
-startup_init_ops = set(("Const", "VariableV2", "Placeholder", "SubgraphInput"))
-memory_only_ops = set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims",))
+npu_dma_ops = set((Op.DMA,))
+startup_init_ops = set((Op.Const, Op.Placeholder, Op.SubgraphInput))
+memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,))
test_sequence = [
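In the hunk above, the hand-written element-wise lists are replaced by Op.op_set calls that build the sets from a predicate. A plausible sketch of that helper, assuming the predicate is passed unbound so that predicate(op) evaluates op.is_binary_elementwise_op() (member names and the predicate body here are illustrative, not vela's actual definitions):

    import enum

    class Op(enum.Enum):
        Add = enum.auto()
        Mul = enum.auto()
        Abs = enum.auto()

        def is_binary_elementwise_op(self):
            # Hypothetical check; vela keeps this in per-op metadata.
            return self in (Op.Add, Op.Mul)

        @classmethod
        def op_set(cls, predicate):
            # Collect every member for which the predicate holds.
            return {op for op in cls if predicate(op)}

    assert Op.op_set(Op.is_binary_elementwise_op) == {Op.Add, Op.Mul}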
@@ -234,10 +211,6 @@ test_sequence = [
for (operation_set, incompatible_pack_flags, flags_to_set, flags_to_clear) in test_sequence:
assert not flags_to_clear & flags_to_set
- if operation_set is not None:
- for op in operation_set:
- assert len(op) > 1 # This is to avoid string literals being decomposed
-
def pack_into_passes(nng, arch, verbose_packing=False):
def visit_op(op, ignored):
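The len(op) > 1 assertion deleted above existed to catch a Python tuple pitfall that enum members make impossible:

    set(("DMA"))   # ("DMA") is just a str, so this is {'D', 'M', 'A'}
    set(("DMA",))  # the trailing comma makes a tuple: {'DMA'}

    # An Op member is not an iterable string, so the same slip now fails
    # loudly instead of silently decomposing:
    # set((Op.DMA)) -> TypeError: 'Op' object is not iterable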
@@ -254,7 +227,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
if op.type in startup_init_ops:
startup_list.append(op)
else:
- _, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
+ ofm_tensor = op.ofm
if ofm_tensor is None:
ofm_tensor = op.outputs[0]
build_pass((op,), ofm_tensor)
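In the hunk above, the four-value unpacking from get_ifm_ifm2_weights_ofm() collapses to an ofm property. A minimal sketch of such a property, assuming it lives on vela's Operation class (the real one resolves per-op-type output indices and can return None, which is why the caller keeps the op.outputs[0] fallback):

    class Operation:
        """Illustrative stand-in for vela's Operation."""

        def __init__(self, op_type, outputs):
            self.type = op_type
            self.outputs = outputs

        @property
        def ofm(self):
            # Return the output feature map, or None if there is none.
            return self.outputs[0] if self.outputs else None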
@@ -287,7 +260,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
continue
reverse_ops_list.append(curr_op)
- new_block_type = curr_op.attrs.get("npu_block_type", NpuBlockType.Default)
+ new_block_type = curr_op.type.npu_block_type
if new_block_type != NpuBlockType.Default:
assert npu_block_type == NpuBlockType.Default
npu_block_type = new_block_type # Only one major block type per pass
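The per-instance attrs.get("npu_block_type", ...) lookup replaced above becomes a property of the op type itself. A sketch of that relocation, assuming a class-level table keyed by Op member (table contents are illustrative):

    import enum

    class NpuBlockType(enum.Enum):
        Default = enum.auto()
        ConvolutionMxN = enum.auto()
        Pooling = enum.auto()

    class Op(enum.Enum):
        Conv2D = enum.auto()
        MaxPool = enum.auto()
        Reshape = enum.auto()

        @property
        def npu_block_type(self):
            # Fall back to Default, mirroring the old attrs.get() default.
            return _BLOCK_TYPE.get(self, NpuBlockType.Default)

    _BLOCK_TYPE = {
        Op.Conv2D: NpuBlockType.ConvolutionMxN,
        Op.MaxPool: NpuBlockType.Pooling,
    }

    assert Op.Reshape.npu_block_type is NpuBlockType.Default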
@@ -302,10 +275,8 @@ def pack_into_passes(nng, arch, verbose_packing=False):
PassFlags.Mac | PassFlags.ElementWise | PassFlags.Post | PassFlags.PostFusingLimited
):
assert len(curr_op.inputs) >= 1
- if curr_op.type == "BlockLSTM":
- ifm_tensor = curr_op.inputs[3]
- else:
- ifm_tensor = curr_op.inputs[0]
+ ifm_tensor = curr_op.ifm
+ assert ifm_tensor is not None
assert ifm_tensor.purpose == TensorPurpose.FeatureMap
if flags_to_set & PassFlags.Dma:
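The BlockLSTM special case removed above (its IFM sits at input index 3) is assumed to be handled centrally by the new curr_op.ifm accessor. A minimal sketch of that idea (the index table is illustrative, not vela's actual mechanism):

    import enum

    class Op(enum.Enum):
        BlockLSTM = enum.auto()
        Conv2D = enum.auto()

    # Hypothetical per-type IFM input index; most ops use input 0.
    _IFM_INDEX = {Op.BlockLSTM: 3}

    class Operation:
        def __init__(self, op_type, inputs):
            self.type = op_type
            self.inputs = inputs

        @property
        def ifm(self):
            idx = _IFM_INDEX.get(self.type, 0)
            return self.inputs[idx] if idx < len(self.inputs) else None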
@@ -377,7 +348,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
primary_op = create_primary_op(ops_list)
if primary_op is not None:
visit_tensor_refcount[primary_op.inputs[0]] += 1
- npu_block_type = primary_op.attrs["npu_block_type"]
+ npu_block_type = primary_op.type.npu_block_type
for input_tens in primary_op.inputs:
if input_tens not in input_set:
input_set.add(input_tens)
@@ -394,7 +365,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
for inp in primary_op.inputs:
if inp is None:
continue
- if len(inp.ops) == 1 and inp.ops[0].type == "DMA" and inp.purpose == TensorPurpose.FeatureMap:
+ if len(inp.ops) == 1 and inp.ops[0].type == Op.DMA and inp.purpose == TensorPurpose.FeatureMap:
src_op = inp.ops[0]
if src_op in input_ops_list:
inp = src_op.inputs[0]
@@ -408,7 +379,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list)
name = ops_list[0].name
- non_dma_ops = [op for op in ops_list if op.type != "DMA"]
+ non_dma_ops = [op for op in ops_list if op.type != Op.DMA]
if non_dma_ops:
name = non_dma_ops[0].name
ps = Pass(name, placement, is_element_wise, npu_block_type)