From d83d2e11d3dff5031fec513ca2aa22c19c9ea4d8 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Mon, 20 Jul 2020 12:05:32 +0100 Subject: [EXTAPI] refactor weight compression to be used by an external consumer *lint *added unit tests *added typecheck *added docstring for the api Change-Id: Ibd4bc40d4381ac40ad2ea3d500b26c4ec565ab07 Signed-off-by: Manupa Karunaratne --- ethosu/vela/architecture_features.py | 46 +++++++++--- ethosu/vela/errors.py | 18 +++++ .../vela/test/extapi/test_extapi_encode_weights.py | 73 +++++++++++++++++++ ethosu/vela/vela.py | 2 +- ethosu/vela/weight_compressor.py | 85 ++++++++++++++++++---- 5 files changed, 197 insertions(+), 27 deletions(-) create mode 100644 ethosu/vela/test/extapi/test_extapi_encode_weights.py (limited to 'ethosu') diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py index 6460c527..43b32109 100644 --- a/ethosu/vela/architecture_features.py +++ b/ethosu/vela/architecture_features.py @@ -120,6 +120,19 @@ class SharedBufferArea(enum.IntEnum): Size = Accumulators + 1 +class Accelerator(enum.Enum): + Ethos_U55_32 = "ethos-u55-32" + Ethos_U55_64 = "ethos-u55-64" + Ethos_U55_128 = "ethos-u55-128" + Ethos_U55_256 = "ethos-u55-256" + Yoda_256 = "yoda-256" + Yoda_512 = "yoda-512" + + @classmethod + def member_list(cls): + return [e.value for e in cls] + + class ArchitectureFeatures: """This class is a container for various parameters of the Ethos-U55 core and system configuration that can be tuned, either by command line @@ -136,15 +149,28 @@ Note the difference between ArchitectureFeatures and CompilerOptions "ArchitectureConfig", "macs cores ofm_ublock ifm_ublock shram_banks shram_granules elem_units" ) accelerator_configs = { - "yoda-512": ArchitectureConfig(256, 2, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8), - "yoda-256": ArchitectureConfig(256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8), - "ethos-u55-256": ArchitectureConfig(256, 1, 
Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8), - "ethos-u55-128": ArchitectureConfig(128, 1, Block(2, 1, 8), Block(2, 2, 8), 24, [4, 4, 4, 4, 4, 8, 12], 4), - "ethos-u55-64": ArchitectureConfig(64, 1, Block(1, 1, 8), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 8], 2), - "ethos-u55-32": ArchitectureConfig(32, 1, Block(1, 1, 4), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 4], 1), + Accelerator.Yoda_512: ArchitectureConfig( + 256, 2, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8 + ), + Accelerator.Yoda_256: ArchitectureConfig( + 256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8 + ), + Accelerator.Ethos_U55_256: ArchitectureConfig( + 256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8 + ), + Accelerator.Ethos_U55_128: ArchitectureConfig( + 128, 1, Block(2, 1, 8), Block(2, 2, 8), 24, [4, 4, 4, 4, 4, 8, 12], 4 + ), + Accelerator.Ethos_U55_64: ArchitectureConfig( + 64, 1, Block(1, 1, 8), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 8], 2 + ), + Accelerator.Ethos_U55_32: ArchitectureConfig( + 32, 1, Block(1, 1, 4), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 4], 1 + ), } OFMSplitDepth = 16 + SubKernelMax = Block(8, 8, 65536) def __init__( self, @@ -159,20 +185,18 @@ Note the difference between ArchitectureFeatures and CompilerOptions ): accelerator_config = accelerator_config.lower() self.vela_config = vela_config - self.accelerator_config = accelerator_config - if self.accelerator_config not in ArchitectureFeatures.accelerator_configs: + if accelerator_config not in Accelerator.member_list(): raise OptionError("--accelerator-config", self.accelerator_config, "Unknown accelerator configuration") + self.accelerator_config = Accelerator(accelerator_config) accel_config = ArchitectureFeatures.accelerator_configs[self.accelerator_config] self.config = accel_config self.system_config = system_config - - self.is_yoda_system = "yoda-" in self.accelerator_config + self.is_yoda_system = self.accelerator_config in 
(Accelerator.Yoda_256, Accelerator.Yoda_512) self.ncores = accel_config.cores self.ofm_ublock = accel_config.ofm_ublock self.ifm_ublock = accel_config.ifm_ublock - self.subkernel_max = Block(8, 8, 65536) self.ofm_block_max = Block(64, 32, 128) self.override_block_config = override_block_config self.block_config_limit = block_config_limit diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py index 2c93fbc6..59740aac 100644 --- a/ethosu/vela/errors.py +++ b/ethosu/vela/errors.py @@ -15,6 +15,7 @@ # limitations under the License. # Description: # Defines custom exceptions. +import inspect import sys from .operation import Operation @@ -121,3 +122,20 @@ def TensorError(tens, msg): print("Error: {}".format(data)) sys.exit(1) + + +def typecheck(func): + def wrapper(*args, **kwargs): + fsig = inspect.signature(func) + args_zipped = zip(kwargs.values(), fsig.parameters.keys()) + for actual, expected in args_zipped: + expected_type = fsig.parameters[expected].annotation + actual_type = type(actual) + if expected_type is inspect.Parameter.empty: + raise TypeError("Please provide type info for {}, hint = {}".format(expected, actual_type)) + if expected_type is not actual_type: + raise TypeError("expected : {}, but got {}".format(expected_type, actual_type)) + # Actual execution + return func(*args, **kwargs) + + return wrapper diff --git a/ethosu/vela/test/extapi/test_extapi_encode_weights.py b/ethosu/vela/test/extapi/test_extapi_encode_weights.py new file mode 100644 index 00000000..47ca02b8 --- /dev/null +++ b/ethosu/vela/test/extapi/test_extapi_encode_weights.py @@ -0,0 +1,73 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Description: +# Contains unit tests for encode_weights API for an external consumer +import numpy as np +import pytest + +from ethosu.vela import weight_compressor +from ethosu.vela.architecture_features import Accelerator + + +@pytest.mark.parametrize( + "arch", + [ + Accelerator.Ethos_U55_32, + Accelerator.Ethos_U55_64, + Accelerator.Ethos_U55_128, + Accelerator.Ethos_U55_256, + Accelerator.Yoda_256, + Accelerator.Yoda_512, + ], +) +@pytest.mark.parametrize("dilation_x", [1, 2]) +@pytest.mark.parametrize("dilation_y", [1, 2]) +@pytest.mark.parametrize("ifm_bitdepth", [8, 16]) +@pytest.mark.parametrize("depth_control", [1, 2, 3]) +@pytest.mark.parametrize("weights_shape_and_block_depth", [((16, 16, 16, 16), 8), ((3, 3, 25, 16), 8)]) +def test_encode_weights( + arch, weights_shape_and_block_depth, dilation_x, dilation_y, ifm_bitdepth, depth_control, +): + """ + This unit test checks the interface of the API function but not the functionality. + Functional correctness is tested at a system level. 
+ """ + + weights_shape = weights_shape_and_block_depth[0] + ofm_block_depth = weights_shape_and_block_depth[1] + val_max = np.iinfo(np.uint8).max + weights_hwio = np.random.randint(val_max, size=weights_shape, dtype=np.uint8) + weights_ohwi = np.transpose(weights_hwio, (3, 0, 1, 2)) + is_depthwise = True if depth_control == 2 else False + is_partkernel = True if depth_control == 3 else False + dilation_xy = (dilation_x, dilation_y) + + encoded_stream = weight_compressor.encode_weights( + accelerator=arch, + weights_volume=weights_ohwi, + dilation_xy=dilation_xy, + ifm_bitdepth=ifm_bitdepth, + ofm_block_depth=ofm_block_depth, + is_depthwise=is_depthwise, + is_partkernel=is_partkernel, + ) + assert type(encoded_stream) == bytearray + + +if __name__ == "__main__": + # two test candidates for debugging purposes + test_encode_weights(Accelerator.Ethos_U55_256, ((3, 3, 25, 16), 8), 1, 1, 8, 0) + test_encode_weights(Accelerator.Ethos_U55_256, ((16, 16, 16, 16), 8), 1, 1, 8, 0) diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py index 20bc525b..1766750e 100644 --- a/ethosu/vela/vela.py +++ b/ethosu/vela/vela.py @@ -170,7 +170,7 @@ def main(args=None): "--accelerator-config", type=str, default="ethos-u55-256", - choices=list(architecture_features.ArchitectureFeatures.accelerator_configs.keys()), + choices=list(architecture_features.Accelerator.member_list()), help="Accelerator configuration to use (default: %(default)s)", ) parser.add_argument( diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py index 8ebd7511..687a0805 100644 --- a/ethosu/vela/weight_compressor.py +++ b/ethosu/vela/weight_compressor.py @@ -20,7 +20,10 @@ from collections import namedtuple import numpy as np +from .architecture_features import Accelerator +from .architecture_features import ArchitectureFeatures from .data_type import DataType +from .errors import typecheck from .errors import UnsupportedFeatureError from .nn_graph import SchedulingStrategy from .numeric_util 
import round_up @@ -42,6 +45,55 @@ WeightCompressionConfig = namedtuple( ) + +@typecheck +def encode_weights( + accelerator: Accelerator, + weights_volume: np.ndarray, + dilation_xy: tuple, + ifm_bitdepth: int, + ofm_block_depth: int, + is_depthwise: bool, + is_partkernel: bool, +): + """ + Public facing API to use the ethosu weight encoding. + + :param accelerator: architecture_features.Accelerator enum to pick the correct ethosu accelerator + :param weights_volume: numpy.ndarray in OHWI layout with four dimensions + :param dilation_xy: a two element tuple of dilation attributes in x,y dimension + :param ifm_bitdepth: the bitdepth of input feature map + :param ofm_block_depth: the depth of blocks for ethosu processing + :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal + :param is_partkernel: a boolean indicating these weights are traversed on sub-kernel basis + :return: a bytearray of compressed weights + """ + + # Checks for weight layout + assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4" + + # It cannot be both partkernel and depthwise + assert not (is_depthwise and is_partkernel), "encode_weights :: partkernel and depthwise are mutually exclusive" + + # Check valid values for dilation + assert dilation_xy[0] in (1, 2), "encode_weights :: dilation x should be 1 or 2 not {}".format(dilation_xy[0]) + assert dilation_xy[1] in (1, 2), "encode_weights :: dilation y should be 1 or 2 not {}".format(dilation_xy[1]) + + ifm_ublock = ArchitectureFeatures.accelerator_configs[accelerator].ifm_ublock + ofm_ublock = ArchitectureFeatures.accelerator_configs[accelerator].ofm_ublock + raw_stream = generate_brick( + ifm_ublock=ifm_ublock, + ofm_ublock=ofm_ublock, + brick_weights=weights_volume, + ofm_block_depth=ofm_block_depth, + is_depthwise=is_depthwise, + is_partkernel=is_partkernel, + ifm_bitdepth=ifm_bitdepth, + dilation=dilation_xy, + ) + encoded_stream = encode(raw_stream) + return encoded_stream + + 
def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation): # Note: for an ofm block only its depth is used in weight compression. # And block depth > ofm depth gives same result as block depth == ofm depth @@ -93,13 +145,12 @@ def encode(weight_stream): return compressed -def generate_brick(arch, brick_weights, ofm_block_depth, block_traversal, ifm_bitdepth, dilation): - is_depthwise = block_traversal == TensorBlockTraversal.DepthWise - is_partkernel = block_traversal == TensorBlockTraversal.PartKernelFirst - decomp_h = arch.subkernel_max.height // dilation[0] - decomp_w = arch.subkernel_max.width // dilation[1] - ofm_ublock = arch.ofm_ublock - ifm_ublock = arch.ifm_ublock +def generate_brick( + ifm_ublock, ofm_ublock, brick_weights, ofm_block_depth, is_depthwise, is_partkernel, ifm_bitdepth, dilation +): + + decomp_h = ArchitectureFeatures.SubKernelMax.height // dilation[0] + decomp_w = ArchitectureFeatures.SubKernelMax.width // dilation[1] # Expect weights formatted OHWI ofm_depth = brick_weights.shape[-4] ifm_depth = brick_weights.shape[-1] @@ -245,6 +296,9 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth else: tens.block_traversal = TensorBlockTraversal.DepthFirst + is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise + is_partkernel = tens.block_traversal == TensorBlockTraversal.PartKernelFirst + if tens.consumer_list[0].type == "Conv2DBackpropInputSwitchedBias": # Transpose Convoluion, reverse weights in H and W axes weights = np.flip(weights, axis=(0, 1)) @@ -262,7 +316,6 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth substream_offsets = [0] encoded_stream = [] - raw_size = 0 # For each core, deinterleave weights from the larger volume # and generate separate compressed streams. 
@@ -270,15 +323,17 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth core_weights = core_deinterleave(brick_weights, core, arch.ncores) block_depth = (ofm_block_depth + arch.ncores - 1 - core) // arch.ncores + encoded_substream = [] if block_depth != 0: - raw_stream = generate_brick( - arch, core_weights, block_depth, tens.block_traversal, ifm_bitdepth, dilation + encoded_substream = encode_weights( + accelerator=arch.accelerator_config, + weights_volume=core_weights, + dilation_xy=dilation, + ifm_bitdepth=ifm_bitdepth, + ofm_block_depth=block_depth, + is_depthwise=is_depthwise, + is_partkernel=is_partkernel, ) - else: - raw_stream = [] - - raw_size += len(raw_stream) - encoded_substream = encode(raw_stream) encoded_stream.extend(encoded_substream) substream_offsets.append(len(encoded_stream)) -- cgit v1.2.1