aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/weight_compressor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/weight_compressor.py')
-rw-r--r-- ethosu/vela/weight_compressor.py 20
1 files changed, 13 insertions, 7 deletions
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index b0187b65..c07229fb 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -20,6 +20,7 @@ from collections import namedtuple
import numpy as np
+from .api import NpuBlockTraversal
from .architecture_features import Accelerator
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
@@ -53,7 +54,7 @@ def encode_weights(
ifm_bitdepth: int,
ofm_block_depth: int,
is_depthwise: bool,
- is_partkernel: bool,
+ block_traversal: NpuBlockTraversal,
):
"""
Public facing API to use the ethosu weight encoding.
@@ -64,7 +65,7 @@ def encode_weights(
:param ifm_bitdepth: the bitdepth of input feature map
:param ofm_block_depth: the depth of blocks for ethosu processing
:param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
- :param is_partkernel: a boolean indicating these weights are traversed on sub-kernal basis
+ :param block_traversal: indicates how these weights are traversed on a sub-kernel basis
:return: a bytearray of compressed weights
"""
@@ -75,13 +76,15 @@ def encode_weights(
assert isinstance(ifm_bitdepth, int)
assert isinstance(ofm_block_depth, int)
assert isinstance(is_depthwise, bool)
- assert isinstance(is_partkernel, bool)
+ assert isinstance(block_traversal, NpuBlockTraversal)
# Checks for weight layout
assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4"
# It cannot be both partkernel and depthwise
- assert not (is_depthwise and is_partkernel), "encode_weights :: partkernel and depthwise are mutually exclusive"
+ assert not (
+ is_depthwise and block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST
+ ), "encode_weights :: partkernel and depthwise are mutually exclusive"
# Check valid values for dilation
assert dilation_xy[0] in (1, 2), "encode_weights :: dilation x should be 1 or 2 not {}".format(dilation_xy[0])
@@ -95,7 +98,7 @@ def encode_weights(
brick_weights=weights_volume,
ofm_block_depth=ofm_block_depth,
is_depthwise=is_depthwise,
- is_partkernel=is_partkernel,
+ is_partkernel=block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST,
ifm_bitdepth=ifm_bitdepth,
dilation=dilation_xy,
)
@@ -335,7 +338,10 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth
tens.block_traversal = TensorBlockTraversal.DepthFirst
is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise
- is_partkernel = tens.block_traversal == TensorBlockTraversal.PartKernelFirst
+ if tens.block_traversal == TensorBlockTraversal.PartKernelFirst:
+ block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+ else:
+ block_traversal = NpuBlockTraversal.DEPTH_FIRST
if tens.consumer_list[0].type == Op.Conv2DBackpropInputSwitchedBias:
# Transpose Convoluion, reverse weights in H and W axes
@@ -370,7 +376,7 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth
ifm_bitdepth=ifm_bitdepth,
ofm_block_depth=block_depth,
is_depthwise=is_depthwise,
- is_partkernel=is_partkernel,
+ block_traversal=block_traversal,
)
encoded_stream.extend(encoded_substream)
substream_offsets.append(len(encoded_stream))