diff options
author | Louis Verhaard <louis.verhaard@arm.com> | 2020-11-02 18:04:27 +0100 |
---|---|---|
committer | Louis Verhaard <louis.verhaard@arm.com> | 2020-11-13 14:10:33 +0100 |
commit | e8a5a78dd16ec979c7a7bb1f5bd87e9b2909c32d (patch) | |
tree | 0829808f5ce047b12e1813ca382ac73c3300da91 /ethosu/vela/weight_compressor.py | |
parent | dda21afda93f3732491efdcf89af2b14396c683f (diff) | |
download | ethos-u-vela-e8a5a78dd16ec979c7a7bb1f5bd87e9b2909c32d.tar.gz |
MLBEDSW-839: Code generation using external API (tag: 2.0.0.rc1)
Added external API to generate register command streams.
Existing code generation has been refactored to make
use of this API.
Change-Id: Ibb4c2b167809869f16470b14da24f08a65c82b7b
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/weight_compressor.py')
-rw-r--r-- | ethosu/vela/weight_compressor.py | 20 |
1 files changed, 13 insertions, 7 deletions
```diff
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index b0187b65..c07229fb 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -20,6 +20,7 @@ from collections import namedtuple
 
 import numpy as np
 
+from .api import NpuBlockTraversal
 from .architecture_features import Accelerator
 from .architecture_features import ArchitectureFeatures
 from .data_type import DataType
@@ -53,7 +54,7 @@ def encode_weights(
     ifm_bitdepth: int,
     ofm_block_depth: int,
     is_depthwise: bool,
-    is_partkernel: bool,
+    block_traversal: NpuBlockTraversal,
 ):
     """
     Public facing API to use the ethosu weight encoding.
@@ -64,7 +65,7 @@ def encode_weights(
     :param ifm_bitdepth: the bitdepth of input feature map
     :param ofm_block_depth: the depth of blocks for ethosu processing
     :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
-    :param is_partkernel: a boolean indicating these weights are traversed on sub-kernal basis
+    :param block_traversal: indicates how these weights are traversed on sub-kernal basis
     :return: a bytearray of compressed weights
     """
 
@@ -75,13 +76,15 @@ def encode_weights(
     assert isinstance(ifm_bitdepth, int)
     assert isinstance(ofm_block_depth, int)
     assert isinstance(is_depthwise, bool)
-    assert isinstance(is_partkernel, bool)
+    assert isinstance(block_traversal, NpuBlockTraversal)
 
     # Checks for weight layout
     assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4"
 
     # It cannot be both partkernel and depthwise
-    assert not (is_depthwise and is_partkernel), "encode_weights :: partkernel and depthwise are mutually exclusive"
+    assert not (
+        is_depthwise and block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST
+    ), "encode_weights :: partkernel and depthwise are mutually exclusive"
 
     # Check valid values for dilation
     assert dilation_xy[0] in (1, 2), "encode_weights :: dilation x should be 1 or 2 not {}".format(dilation_xy[0])
@@ -95,7 +98,7 @@ def encode_weights(
         brick_weights=weights_volume,
         ofm_block_depth=ofm_block_depth,
         is_depthwise=is_depthwise,
-        is_partkernel=is_partkernel,
+        is_partkernel=block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST,
         ifm_bitdepth=ifm_bitdepth,
         dilation=dilation_xy,
     )
@@ -335,7 +338,10 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth
         tens.block_traversal = TensorBlockTraversal.DepthFirst
 
     is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise
-    is_partkernel = tens.block_traversal == TensorBlockTraversal.PartKernelFirst
+    if tens.block_traversal == TensorBlockTraversal.PartKernelFirst:
+        block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+    else:
+        block_traversal = NpuBlockTraversal.DEPTH_FIRST
 
     if tens.consumer_list[0].type == Op.Conv2DBackpropInputSwitchedBias:
         # Transpose Convoluion, reverse weights in H and W axes
@@ -370,7 +376,7 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth
                 ifm_bitdepth=ifm_bitdepth,
                 ofm_block_depth=block_depth,
                 is_depthwise=is_depthwise,
-                is_partkernel=is_partkernel,
+                block_traversal=block_traversal,
             )
             encoded_stream.extend(encoded_substream)
             substream_offsets.append(len(encoded_stream))
```