From aeae56770f3c19182d32cc63fd32396e061a7648 Mon Sep 17 00:00:00 2001 From: Louis Verhaard Date: Mon, 2 Nov 2020 18:04:27 +0100 Subject: MLBEDSW-3424: Expose API through separate file All external APIs are now exposed by api.py. Signed-off-by: Louis Verhaard Change-Id: I33f480e424692ac30e9c7d791f583199f31164a7 --- ethosu/vela/api.py | 74 +++++++++++++++++++++- ethosu/vela/architecture_features.py | 15 +++++ ethosu/vela/register_command_stream_generator.py | 6 +- ethosu/vela/test/extapi/test_extapi_encode_bias.py | 6 +- .../vela/test/extapi/test_extapi_encode_weights.py | 18 ++---- .../test/extapi/test_extapi_generate_commands.py | 20 +++--- ethosu/vela/weight_compressor.py | 11 ++-- 7 files changed, 116 insertions(+), 34 deletions(-) diff --git a/ethosu/vela/api.py b/ethosu/vela/api.py index 0799ab1c..f64a38fb 100644 --- a/ethosu/vela/api.py +++ b/ethosu/vela/api.py @@ -15,7 +15,7 @@ # limitations under the License. # # Description: -# Contains data types used in the external API for code generation +# Contains external APIs from enum import auto from enum import Enum from typing import List @@ -23,11 +23,26 @@ from typing import NamedTuple from typing import Optional from typing import Tuple +import numpy + API_version_major = 1 API_version_minor = 0 api_version = f"{API_version_major}.{API_version_minor}" +class NpuAccelerator(Enum): + """ + Supported accelerators + """ + + Ethos_U55_32 = auto() + Ethos_U55_64 = auto() + Ethos_U55_128 = auto() + Ethos_U55_256 = auto() + Ethos_U65_256 = auto() + Ethos_U65_512 = auto() + + class NpuElementWiseOp(Enum): """ Elementwise operation @@ -381,3 +396,60 @@ def npu_get_API_version(): """ version = (API_version_major << 16) | (API_version_minor & 0xFFFF) return version + + +def npu_encode_weights( + accelerator: NpuAccelerator, + weights_volume: numpy.ndarray, + dilation_xy: Tuple[int, int], + ifm_bitdepth: int, + ofm_block_depth: int, + is_depthwise: bool, + block_traversal: NpuBlockTraversal, +): + """ + Public facing API 
to use the Ethos-U weight encoding. + + :param accelerator: NpuAccelerator enum to pick the correct accelerator + :param weights_volume: numpy.ndarray in OHWI layout with a shape of four + :param dilation_xy: a two element tuple of dilation attributes in x,y dimension + :param ifm_bitdepth: the bitdepth of input feature map + :param ofm_block_depth: the depth of blocks for processing + :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal + :param block_traversal: indicates how these weights are traversed on sub-kernel basis + :return: a bytearray of compressed weights + """ + from .architecture_features import Accelerator + from . import weight_compressor + + acc = Accelerator.from_npu_accelerator(accelerator) + return weight_compressor.encode_weights( + acc, weights_volume, dilation_xy, ifm_bitdepth, ofm_block_depth, is_depthwise, block_traversal + ) + + +def npu_encode_bias(bias: numpy.int64, scale: int, shift: int): + """ + Public facing API to pack bias and scale values as required by the hardware + :param bias: 64-bit signed number that includes 40-bit signed bias + :param scale: 32-bit scale value + :param shift: 6-bit shift value + :return: packed 80-bit [0(2-bits),shift(6-bits),scale(32-bits),bias(40-bits)] + """ + from . import weight_compressor + + return weight_compressor.encode_bias(bias, scale, shift) + + +def npu_generate_register_command_stream(npu_op_list: List[NpuOperation], accelerator: NpuAccelerator) -> List[int]: + """ + Public facing API for generating an Ethos-U register command stream. + Calculates dependencies between commands and inserts wait operations if needed. + + :param npu_op_list: List[NpuOperation] list of high level NPU operations + :param accelerator: NpuAccelerator enum to pick the correct accelerator + :return: register commands, as a list of 32-bit integers + """ + from . 
import register_command_stream_generator + + return register_command_stream_generator.generate_register_command_stream(npu_op_list, accelerator) diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py index 7b6c3bed..18846cfd 100644 --- a/ethosu/vela/architecture_features.py +++ b/ethosu/vela/architecture_features.py @@ -21,6 +21,7 @@ from configparser import ConfigParser import numpy as np +from .api import NpuAccelerator from .errors import CliOptionError from .errors import ConfigOptionError from .ethos_u55_regs.ethos_u55_regs import resampling_mode @@ -131,6 +132,20 @@ class Accelerator(enum.Enum): def member_list(cls): return [e.value for e in cls] + @classmethod + def from_npu_accelerator(cls, npu_accelerator: NpuAccelerator) -> "Accelerator": + """Converts the given public API object to Accelerator (used internally)""" + accelerator_map = { + NpuAccelerator.Ethos_U55_32: cls.Ethos_U55_32, + NpuAccelerator.Ethos_U55_64: cls.Ethos_U55_64, + NpuAccelerator.Ethos_U55_128: cls.Ethos_U55_128, + NpuAccelerator.Ethos_U55_256: cls.Ethos_U55_256, + NpuAccelerator.Ethos_U65_256: cls.Ethos_U65_256, + NpuAccelerator.Ethos_U65_512: cls.Ethos_U65_512, + } + assert npu_accelerator in accelerator_map, f"Unsupported accelerator {npu_accelerator}" + return accelerator_map[npu_accelerator] + @enum.unique class MemPort(enum.Enum): diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py index e612c301..04f7072d 100644 --- a/ethosu/vela/register_command_stream_generator.py +++ b/ethosu/vela/register_command_stream_generator.py @@ -28,6 +28,7 @@ import numpy as np from . import numeric_util from . 
import scaling +from .api import NpuAccelerator from .api import NpuActivation from .api import NpuActivationOp from .api import NpuAddressRange @@ -1270,15 +1271,16 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False): print("command stream length in words", len(sg.register_command_stream)) -def generate_register_command_stream(npu_op_list: List[NpuOperation], accelerator: Accelerator) -> List[int]: +def generate_register_command_stream(npu_op_list: List[NpuOperation], npu_accelerator: NpuAccelerator) -> List[int]: """ - Public facing API for generating an Ethos-U register command stream. + Internal implementation of the public facing API for generating an Ethos-U register command stream. Calculates dependencies between commands and inserts wait operations if needed. :param npu_op_list: List[NpuOperation] list of high level NPU operations - :param accelerator: architecture_features.Accelerator enum to pick the correct Ethos-U accelerator + :param npu_accelerator: api.NpuAccelerator enum to pick the correct Ethos-U accelerator :return Ethos-U instructions, as a list of 32-bit integers """ + accelerator = Accelerator.from_npu_accelerator(npu_accelerator) emit = CommandStreamEmitter() arch = ArchitectureFeatures( vela_config_files=None, diff --git a/ethosu/vela/test/extapi/test_extapi_encode_bias.py b/ethosu/vela/test/extapi/test_extapi_encode_bias.py index ffdd3b0c..c0a4a9ab 100644 --- a/ethosu/vela/test/extapi/test_extapi_encode_bias.py +++ b/ethosu/vela/test/extapi/test_extapi_encode_bias.py @@ -14,12 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# Description: -# Contains unit tests for encode_biases API for an external consumer +# Contains unit tests for npu_encode_bias API for an external consumer import random import numpy as np -from ethosu.vela.weight_compressor import encode_bias +from ethosu.vela.api import npu_encode_bias def test_encode_bias(): @@ -34,6 +34,6 @@ def test_encode_bias(): bias = np.int64(random.randint(bias_lower_limit, bias_upper_limit)) scale = int(random.randint(scale_lower_limit, scale_upper_limit)) shift = int(random.randint(shift_lower_limit, shift_upper_limit)) - biases_enc = encode_bias(bias, scale, shift) + biases_enc = npu_encode_bias(bias, scale, shift) assert isinstance(biases_enc, bytearray) assert len(biases_enc) == 10 diff --git a/ethosu/vela/test/extapi/test_extapi_encode_weights.py b/ethosu/vela/test/extapi/test_extapi_encode_weights.py index 854d14c0..6367cb30 100644 --- a/ethosu/vela/test/extapi/test_extapi_encode_weights.py +++ b/ethosu/vela/test/extapi/test_extapi_encode_weights.py @@ -14,25 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# Description: -# Contains unit tests for encode_weights API for an external consumer +# Contains unit tests for npu_encode_weights API for an external consumer import numpy as np import pytest -from ethosu.vela import weight_compressor +from ethosu.vela.api import npu_encode_weights +from ethosu.vela.api import NpuAccelerator from ethosu.vela.api import NpuBlockTraversal -from ethosu.vela.architecture_features import Accelerator @pytest.mark.parametrize( - "arch", - [ - Accelerator.Ethos_U55_32, - Accelerator.Ethos_U55_64, - Accelerator.Ethos_U55_128, - Accelerator.Ethos_U55_256, - Accelerator.Ethos_U65_256, - Accelerator.Ethos_U65_512, - ], + "arch", list(NpuAccelerator), ) @pytest.mark.parametrize("dilation_x", [1, 2]) @pytest.mark.parametrize("dilation_y", [1, 2]) @@ -56,7 +48,7 @@ def test_encode_weights( block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST if depth_control == 3 else NpuBlockTraversal.DEPTH_FIRST dilation_xy = (dilation_x, dilation_y) - encoded_stream = weight_compressor.encode_weights( + encoded_stream = npu_encode_weights( accelerator=arch, weights_volume=weights_ohwi, dilation_xy=dilation_xy, diff --git a/ethosu/vela/test/extapi/test_extapi_generate_commands.py b/ethosu/vela/test/extapi/test_extapi_generate_commands.py index 49b24b2b..86ef804a 100644 --- a/ethosu/vela/test/extapi/test_extapi_generate_commands.py +++ b/ethosu/vela/test/extapi/test_extapi_generate_commands.py @@ -15,7 +15,9 @@ # limitations under the License. 
# # Description: -# Contains unit tests for generate_register_command_stream API for an external consumer +# Contains unit tests for npu_generate_register_command_stream API for an external consumer +from ethosu.vela.api import npu_generate_register_command_stream +from ethosu.vela.api import NpuAccelerator from ethosu.vela.api import NpuActivation from ethosu.vela.api import NpuActivationOp from ethosu.vela.api import NpuAddressRange @@ -35,11 +37,9 @@ from ethosu.vela.api import NpuPoolingOperation from ethosu.vela.api import NpuQuantization from ethosu.vela.api import NpuShape3D from ethosu.vela.api import NpuTileBox -from ethosu.vela.architecture_features import Accelerator from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd0 from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd1 from ethosu.vela.register_command_stream_generator import CmdMode -from ethosu.vela.register_command_stream_generator import generate_register_command_stream from ethosu.vela.register_command_stream_generator import get_address_ranges @@ -109,7 +109,7 @@ def test_conv2d(): # In this example we assume that the weights were compressed with ofm depth 16; # let vela choose suitable block width and height by setting these to -1 op.block_config = NpuShape3D(height=-1, width=-1, depth=16) - cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_128) + cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_128) check_cmd0(cmds, cmd0.NPU_SET_IFM_REGION, 1) check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE0, 512) check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE1, 0) @@ -203,7 +203,7 @@ def create_fully_connected_op() -> NpuConv2DOperation: def test_fully_connected(): """Tests command stream generation for a fully connected operation""" op = create_fully_connected_op() - cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_128) + cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_128) check_cmd0(cmds, cmd0.NPU_OP_CONV, 0) assert 
len(cmds) > 20 @@ -223,7 +223,7 @@ def test_depthwise(): op.weights = [weights_dest] op.biases = [NpuAddressRange(region=0, address=0, length=80)] op.block_config = NpuShape3D(height=-1, width=-1, depth=8) - cmds = generate_register_command_stream([dma_op, op], Accelerator.Ethos_U55_128) + cmds = npu_generate_register_command_stream([dma_op, op], NpuAccelerator.Ethos_U55_128) check_cmd0(cmds, cmd0.NPU_SET_DMA0_SRC_REGION, 0) check_cmd1(cmds, cmd1.NPU_SET_DMA0_SRC, 0x40) check_cmd0(cmds, cmd0.NPU_SET_DMA0_DST_REGION, 1) @@ -248,7 +248,7 @@ def test_mul_with_broadcast_and_relu(): op.activation = NpuActivation(NpuActivationOp.NONE_OR_RELU) op.activation.min = 0 # RELU # Do not set a block config, let vela choose one - cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_32) + cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_32) check_cmd1(cmds, cmd1.NPU_SET_OFM_SCALE, 1073741824, 30) check_cmd0(cmds, cmd0.NPU_SET_IFM_REGION, 1) check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE0, 32) @@ -337,7 +337,7 @@ def create_avg_pool_op() -> NpuPoolingOperation: def test_avg_pool(): """Tests average pool operation""" op = create_avg_pool_op() - cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_128) + cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_128) check_cmd0(cmds, cmd0.NPU_OP_POOL, 1) assert len(cmds) > 10 @@ -346,7 +346,7 @@ def test_two_operations(): """Tests code generation with 2 operations""" op1 = create_fully_connected_op() op2 = create_avg_pool_op() - cmds = generate_register_command_stream([op1, op2], Accelerator.Ethos_U55_64) + cmds = npu_generate_register_command_stream([op1, op2], NpuAccelerator.Ethos_U55_64) check_cmd0(cmds, cmd0.NPU_OP_POOL, 1) check_cmd0(cmds, cmd0.NPU_OP_CONV, 0) check_cmd0(cmds, cmd0.NPU_SET_BLOCKDEP, 0) @@ -363,7 +363,7 @@ def test_dma_op(): assert dest is not None src = NpuAddressRange(0, 0x24000, dest.length) dma_op = NpuDmaOperation(src, dest) - cmds = 
generate_register_command_stream([dma_op, pool_op], Accelerator.Ethos_U55_64) + cmds = npu_generate_register_command_stream([dma_op, pool_op], NpuAccelerator.Ethos_U55_64) check_cmd0(cmds, cmd0.NPU_OP_DMA_START, 0) # A DMA WAIT should have been inserted check_cmd0(cmds, cmd0.NPU_OP_DMA_WAIT, 0) diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py index 0eab1851..40ebcd04 100644 --- a/ethosu/vela/weight_compressor.py +++ b/ethosu/vela/weight_compressor.py @@ -17,6 +17,7 @@ # Compresses and pads the weigths. It also calculates the scales and packs with the biases. import math from collections import namedtuple +from typing import Tuple import numpy as np @@ -50,14 +51,14 @@ WeightCompressionConfig = namedtuple( def encode_weights( accelerator: Accelerator, weights_volume: np.ndarray, - dilation_xy: tuple, + dilation_xy: Tuple[int, int], ifm_bitdepth: int, ofm_block_depth: int, is_depthwise: bool, block_traversal: NpuBlockTraversal, ): """ - Public facing API to use the Ethos-U weight encoding. + Internal implementation of the public facing API to use weight encoding. 
:param accelerator: architecture_features.Accelerator enum to pick the correct Ethos-U accelerator :param weights_volume: numpy.ndarray in OHWI layout with a shape of four @@ -65,10 +66,10 @@ def encode_weights( :param ifm_bitdepth: the bitdepth of input feature map :param ofm_block_depth: the depth of blocks for Ethos-U processing :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal - :param block_traversal: indicates how these weights are traversed on sub-kernal basis + :param block_traversal: indicates how these weights are traversed on sub-kernel basis + :return: a bytearray of compressed weights """ - # Check arg types assert isinstance(accelerator, Accelerator) assert isinstance(weights_volume, np.ndarray) @@ -108,7 +109,7 @@ def encode_weights( def encode_bias(bias: np.int64, scale: int, shift: int): """ - Public facing API to pack bias and scale values as required by the Ethos-U + Internal implementation of the public facing API to pack bias and scale values as required by the Ethos-U :param bias: 64bit signed number that includes 40bit signed bias :param scale: 32bit scale value -- cgit v1.2.1