diff options
Diffstat (limited to 'ethosu/vela/test/test_register_command_generator.py')
-rw-r--r-- | ethosu/vela/test/test_register_command_generator.py | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_register_command_generator.py b/ethosu/vela/test/test_register_command_generator.py new file mode 100644 index 00000000..f2a16097 --- /dev/null +++ b/ethosu/vela/test/test_register_command_generator.py @@ -0,0 +1,104 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Description: +# Contains unit tests for register command stream generator +from ethosu.vela.api import NpuAddressRange +from ethosu.vela.api import NpuDataType +from ethosu.vela.api import NpuFeatureMap +from ethosu.vela.api import NpuLayout +from ethosu.vela.api import NpuShape3D +from ethosu.vela.api import NpuTileBox +from ethosu.vela.register_command_stream_generator import get_address_ranges +from ethosu.vela.register_command_stream_generator import get_strides + + +def test_get_fm_strides(): + """Tests calculation of feature map strides""" + fm = NpuFeatureMap() + fm.layout = NpuLayout.NHCWB16 + fm.data_type = NpuDataType.INT16 + fm.shape = NpuShape3D(height=7, width=10, depth=24) + assert get_strides(fm) == NpuShape3D(height=640, width=32, depth=320) + fm.layout = NpuLayout.NHWC + assert get_strides(fm) == NpuShape3D(height=480, width=48, depth=2) + fm.data_type = NpuDataType.UINT8 + assert get_strides(fm) == NpuShape3D(height=240, width=24, depth=1) + + +def test_get_address_ranges_one_tile(): + """Tests calculation of feature map address ranges, with 1 tile used""" + fm = NpuFeatureMap() + fm.region = 4 + fm.layout = NpuLayout.NHWC + fm.data_type = NpuDataType.INT16 + fm.shape = NpuShape3D(height=50, width=40, depth=3) + fm.tiles = NpuTileBox(height_0=50, height_1=50, width_0=40, addresses=[8000, 0, 0, 0]) + ranges = get_address_ranges(fm) + assert ranges == [NpuAddressRange(region=4, address=8000, length=12000), None, None, None] + + +def test_get_address_ranges_horizontal_tiles(): + """Tests calculation of feature map address ranges, with 2 horizontal tiles used""" + fm = NpuFeatureMap() + fm.region = 6 + fm.layout = NpuLayout.NHWC + fm.data_type = NpuDataType.INT16 + fm.shape = NpuShape3D(height=50, width=10, depth=20) + fm.tiles = NpuTileBox(height_0=20, height_1=30, width_0=10, addresses=[256, 0, 16000, 0]) + ranges = get_address_ranges(fm) + assert ranges == [ + NpuAddressRange(region=6, address=256, length=8000), + None, + NpuAddressRange(region=6, address=16000, length=12000), + None, + ] + + +def test_get_address_ranges_vertical_tiles(): + """Tests calculation of feature map address ranges, with 2 vertical tiles used""" + fm = NpuFeatureMap() + fm.region = 6 + fm.layout = NpuLayout.NHWC + fm.data_type = NpuDataType.INT8 + # Set strides explicitly + fm.shape = NpuShape3D(height=50, width=10, depth=20) + fm.strides = NpuShape3D(height=100, width=20, depth=1) + fm.tiles = NpuTileBox(height_0=50, height_1=50, width_0=5, addresses=[16, 32000, 0, 0]) + ranges = get_address_ranges(fm) + assert ranges == [ + NpuAddressRange(region=6, address=16, length=5000), + NpuAddressRange(region=6, address=32000, length=5000), + None, + None, + ] + + +def test_get_address_ranges_4_tiles(): + """Tests calculation of feature map address ranges, with 4 tiles used""" + fm = NpuFeatureMap() + fm.region = 6 + fm.layout = NpuLayout.NHCWB16 + fm.data_type = NpuDataType.INT16 + fm.shape = NpuShape3D(height=50, width=10, depth=20) + fm.tiles = NpuTileBox(height_0=30, height_1=10, width_0=3, addresses=[16, 32000, 8000, 16000]) + ranges = get_address_ranges(fm) + assert ranges == [ + NpuAddressRange(region=6, address=16, length=18952), + NpuAddressRange(region=6, address=32000, length=6280), + NpuAddressRange(region=6, address=8000, length=12552), + NpuAddressRange(region=6, address=28800, length=12680), + ] |