diff options
Diffstat (limited to 'ethosu/vela/high_level_command_stream_generator.py')
-rw-r--r-- | ethosu/vela/high_level_command_stream_generator.py | 27 |
1 files changed, 15 insertions, 12 deletions
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py index 6aa88d86..2297a3bf 100644 --- a/ethosu/vela/high_level_command_stream_generator.py +++ b/ethosu/vela/high_level_command_stream_generator.py @@ -143,18 +143,21 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id if ( intermediate is not None and intermediate.shape != [] - and intermediate.purpose == TensorPurpose.FeatureMap + and intermediate.purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT) ): - intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt( - strides, - skirt, - intermediate.shape, - npu_block_type, - concat_axis, - concat_offset, - split_offsets[0], - upscaling, - ) + if intermediate.purpose is TensorPurpose.FeatureMap: + intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt( + strides, + skirt, + intermediate.shape, + npu_block_type, + concat_axis, + concat_offset, + split_offsets[0], + upscaling, + ) + else: + intermediate_box = Box([0] * len(intermediate.shape), list(intermediate.shape)) yield from dma_if_necessary(ps, intermediate_box, intermediate) weight_box = None @@ -232,7 +235,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id ofm_box = Box(ofm_start, ofm_end) k_height = 1 - if npu_block_type == NpuBlockType.Pooling: + if npu_block_type == set((NpuBlockType.Pooling, NpuBlockType.ReduceSum)): if ps.primary_op is not None: k_height = ps.primary_op.attrs["ksize"][1] else: |