diff options
Diffstat (limited to 'ethosu/vela/test/extapi/test_extapi_encode_weights.py')
-rw-r--r-- | ethosu/vela/test/extapi/test_extapi_encode_weights.py | 18 |
1 files changed, 5 insertions, 13 deletions
diff --git a/ethosu/vela/test/extapi/test_extapi_encode_weights.py b/ethosu/vela/test/extapi/test_extapi_encode_weights.py index 854d14c0..6367cb30 100644 --- a/ethosu/vela/test/extapi/test_extapi_encode_weights.py +++ b/ethosu/vela/test/extapi/test_extapi_encode_weights.py @@ -14,25 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # Description: -# Contains unit tests for encode_weights API for an external consumer +# Contains unit tests for npu_encode_weights API for an external consumer import numpy as np import pytest -from ethosu.vela import weight_compressor +from ethosu.vela.api import npu_encode_weights +from ethosu.vela.api import NpuAccelerator from ethosu.vela.api import NpuBlockTraversal -from ethosu.vela.architecture_features import Accelerator @pytest.mark.parametrize( - "arch", - [ - Accelerator.Ethos_U55_32, - Accelerator.Ethos_U55_64, - Accelerator.Ethos_U55_128, - Accelerator.Ethos_U55_256, - Accelerator.Ethos_U65_256, - Accelerator.Ethos_U65_512, - ], + "arch", list(NpuAccelerator), ) @pytest.mark.parametrize("dilation_x", [1, 2]) @pytest.mark.parametrize("dilation_y", [1, 2]) @@ -56,7 +48,7 @@ def test_encode_weights( block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST if depth_control == 3 else NpuBlockTraversal.DEPTH_FIRST dilation_xy = (dilation_x, dilation_y) - encoded_stream = weight_compressor.encode_weights( + encoded_stream = npu_encode_weights( accelerator=arch, weights_volume=weights_ohwi, dilation_xy=dilation_xy, |