diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2020-06-03 15:43:31 +0200 |
---|---|---|
committer | Louis Verhaard <louis.verhaard@arm.com> | 2020-08-05 16:26:04 +0200 |
commit | a0c3624899edc601525a589643c802469003f89d (patch) | |
tree | 4fc52db04cd29901b3e5d4a7425a7a641e9647fb /ethosu/vela/high_level_command_stream_generator.py | |
parent | 9a03fdff316662be69a1adc4e391e43bc6519b08 (diff) | |
download | ethos-u-vela-a0c3624899edc601525a589643c802469003f89d.tar.gz |
[MLBEDSW-2335] SoftMax int16
Added graph rewrite of Softmax for int16.
Change-Id: Id7885af6056a23e8b8362fb61ae94283251eb398
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Diffstat (limited to 'ethosu/vela/high_level_command_stream_generator.py')
-rw-r--r-- | ethosu/vela/high_level_command_stream_generator.py | 27 |
1 files changed, 15 insertions, 12 deletions
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py index 6aa88d86..2297a3bf 100644 --- a/ethosu/vela/high_level_command_stream_generator.py +++ b/ethosu/vela/high_level_command_stream_generator.py @@ -143,18 +143,21 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id if ( intermediate is not None and intermediate.shape != [] - and intermediate.purpose == TensorPurpose.FeatureMap + and intermediate.purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT) ): - intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt( - strides, - skirt, - intermediate.shape, - npu_block_type, - concat_axis, - concat_offset, - split_offsets[0], - upscaling, - ) + if intermediate.purpose is TensorPurpose.FeatureMap: + intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt( + strides, + skirt, + intermediate.shape, + npu_block_type, + concat_axis, + concat_offset, + split_offsets[0], + upscaling, + ) + else: + intermediate_box = Box([0] * len(intermediate.shape), list(intermediate.shape)) yield from dma_if_necessary(ps, intermediate_box, intermediate) weight_box = None @@ -232,7 +235,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id ofm_box = Box(ofm_start, ofm_end) k_height = 1 - if npu_block_type == NpuBlockType.Pooling: + if npu_block_type == set((NpuBlockType.Pooling, NpuBlockType.ReduceSum)): if ps.primary_op is not None: k_height = ps.primary_op.attrs["ksize"][1] else: |