diff options
author | Louis Verhaard <louis.verhaard@arm.com> | 2020-09-30 09:01:52 +0200 |
---|---|---|
committer | Louis Verhaard <louis.verhaard@arm.com> | 2020-10-08 16:29:29 +0200 |
commit | aee5d7537ff81ffda5ba222721b72f914ce50fb8 (patch) | |
tree | 495b9dfff2a188c6916f8ca2e390ee88f7da8ccc /ethosu/vela/high_level_command_stream_generator.py | |
parent | 36ad73a0fb46d3f844845c97c56d92de2a7a9b3d (diff) | |
download | ethos-u-vela-aee5d7537ff81ffda5ba222721b72f914ce50fb8.tar.gz |
MLBEDSW-3148: Refactor Operation
- op.type is now an enum instead of a string
- Removed unused operator codes
- Refactored some attributes like npu_block_type, fused_activation_function
- Refactored operator index calculation
- Refactored a number of operator sets
Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/high_level_command_stream_generator.py')
-rw-r--r-- | ethosu/vela/high_level_command_stream_generator.py | 26 |
1 files changed, 12 insertions, 14 deletions
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py index 8486dadc..dc52ae52 100644 --- a/ethosu/vela/high_level_command_stream_generator.py +++ b/ethosu/vela/high_level_command_stream_generator.py @@ -25,6 +25,7 @@ from .nn_graph import PassPlacement from .nn_graph import SchedulingStrategy from .numeric_util import round_up_divide from .operation import NpuBlockType +from .operation import Op from .tensor import TensorPurpose @@ -39,7 +40,7 @@ def match_tensor(source, derived): if source == derived: return True ops = derived.ops - return ops != [] and len(ops) == 1 and ops[0].type == "SplitSliceRead" and source == ops[0].inputs[0] + return ops != [] and len(ops) == 1 and ops[0].type == Op.SplitSliceRead and source == ops[0].inputs[0] def generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx): @@ -56,8 +57,8 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id ps.ifm_tensor, ps.ifm2_tensor = ps.ifm2_tensor, ps.ifm_tensor for op in ps.ops: - if op.type == "SplitSliceRead": - ps.primary_op.attrs["fused_memory_function"] = op.type + if op.type == Op.SplitSliceRead: + ps.primary_op.memory_function = op.type assert len(op.inputs) == 1 if match_tensor(ps.ifm_tensor, op.inputs[0]): split_offsets[0] = op.attrs["split_start"] @@ -68,10 +69,10 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id else: ifm_idx = 0 for op in ps.ops: - if op.type == "SplitSliceRead": + if op.type == Op.SplitSliceRead: assert ifm_idx < 2 split_offsets[ifm_idx] = op.attrs["split_start"] - ps.primary_op.attrs["fused_memory_function"] = op.type + ps.primary_op.memory_function = op.type ifm_idx += 1 ifm_tensor = ps.ifm_tensor @@ -89,19 +90,16 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id if ps.primary_op is not None: strides = ps.primary_op.attrs.get("strides", None) skirt = ps.primary_op.attrs.get("skirt", None) - if ps.primary_op.type == "Conv2DBackpropInputSwitchedBias": + if ps.primary_op.type == Op.Conv2DBackpropInputSwitchedBias: upscaling = ofm_tensor.shape[-3] // ifm_tensor.shape[-3] - elif ps.primary_op.type == "ResizeBilinear": + elif ps.primary_op.type == Op.ResizeBilinear: upscaling = round_up_divide(ofm_tensor.shape[-3], ifm_tensor.shape[-3]) concat_axis = 0 concat_offset = 0 - # Fusable activation functions - activation_ops = set(("Sigmoid", "Tanh", "Relu", "Relu6", "ReluN1To1")) - for op in ps.ops: - if op.type == "ConcatSliceWrite": + if op.type == Op.ConcatSliceWrite: concat_axis = op.attrs["concat_axis"] concat_start = op.attrs["concat_start"] concat_end = op.attrs["concat_end"] @@ -109,9 +107,9 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id ofm_start[concat_axis] = concat_start ofm_end[concat_axis] = concat_end concat_offset = concat_start - ps.primary_op.attrs["fused_memory_function"] = op.type - elif op.type in activation_ops: - ps.primary_op.attrs["fused_activation_function"] = op.type + ps.primary_op.memory_function = op.type + elif op.type.is_relu_op() or op.type in (Op.Tanh, Op.Sigmoid): + ps.primary_op.activation = op.type if strat == SchedulingStrategy.WeightStream: ofm_step = block_config[-1] |