diff options
Diffstat (limited to 'ethosu/vela/weight_compressor.py')
-rw-r--r-- | ethosu/vela/weight_compressor.py | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py index b0187b65..c07229fb 100644 --- a/ethosu/vela/weight_compressor.py +++ b/ethosu/vela/weight_compressor.py @@ -20,6 +20,7 @@ from collections import namedtuple import numpy as np +from .api import NpuBlockTraversal from .architecture_features import Accelerator from .architecture_features import ArchitectureFeatures from .data_type import DataType @@ -53,7 +54,7 @@ def encode_weights( ifm_bitdepth: int, ofm_block_depth: int, is_depthwise: bool, - is_partkernel: bool, + block_traversal: NpuBlockTraversal, ): """ Public facing API to use the ethosu weight encoding. @@ -64,7 +65,7 @@ def encode_weights( :param ifm_bitdepth: the bitdepth of input feature map :param ofm_block_depth: the depth of blocks for ethosu processing :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal - :param is_partkernel: a boolean indicating these weights are traversed on sub-kernal basis + :param block_traversal: indicates how these weights are traversed on sub-kernal basis :return: a bytearray of compressed weights """ @@ -75,13 +76,15 @@ def encode_weights( assert isinstance(ifm_bitdepth, int) assert isinstance(ofm_block_depth, int) assert isinstance(is_depthwise, bool) - assert isinstance(is_partkernel, bool) + assert isinstance(block_traversal, NpuBlockTraversal) # Checks for weight layout assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4" # It cannot be both partkernel and depthwise - assert not (is_depthwise and is_partkernel), "encode_weights :: partkernel and depthwise are mutually exclusive" + assert not ( + is_depthwise and block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST + ), "encode_weights :: partkernel and depthwise are mutually exclusive" # Check valid values for dilation assert dilation_xy[0] in (1, 2), "encode_weights :: dilation x should be 1 or 2 not {}".format(dilation_xy[0]) @@ -95,7 +98,7 @@ def encode_weights( brick_weights=weights_volume, ofm_block_depth=ofm_block_depth, is_depthwise=is_depthwise, - is_partkernel=is_partkernel, + is_partkernel=block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST, ifm_bitdepth=ifm_bitdepth, dilation=dilation_xy, ) @@ -335,7 +338,10 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth tens.block_traversal = TensorBlockTraversal.DepthFirst is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise - is_partkernel = tens.block_traversal == TensorBlockTraversal.PartKernelFirst + if tens.block_traversal == TensorBlockTraversal.PartKernelFirst: + block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST + else: + block_traversal = NpuBlockTraversal.DEPTH_FIRST if tens.consumer_list[0].type == Op.Conv2DBackpropInputSwitchedBias: # Transpose Convoluion, reverse weights in H and W axes @@ -370,7 +376,7 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth ifm_bitdepth=ifm_bitdepth, ofm_block_depth=block_depth, is_depthwise=is_depthwise, - is_partkernel=is_partkernel, + block_traversal=block_traversal, ) encoded_stream.extend(encoded_substream) substream_offsets.append(len(encoded_stream)) |