aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test/test_tensor_shape.py
blob: 604e9b1ca44d746d984a1bdc9d2ee93419188efc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright © 2019 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import pytest
import pyarmnn as ann


def test_tensor_shape_tuple():
    """A shape built from a 3-tuple reports 3 dimensions and 1*2*3 elements."""
    shape = ann.TensorShape((1, 2, 3))

    assert shape.GetNumDimensions() == 3
    assert shape.GetNumElements() == 6


def test_tensor_shape_one():
    """A single-element tuple yields a rank-1 shape holding that many elements."""
    shape = ann.TensorShape((10,))
    assert shape.GetNumDimensions() == 1
    assert shape.GetNumElements() == 10


@pytest.mark.skip(reason="This will segfault before it reaches SWIG wrapper. ???")
def test_tensor_shape_empty():
    """Construct a TensorShape from an empty tuple.

    Skipped: per the skip reason, this crashes the process before the SWIG
    wrapper can turn the failure into a Python exception.
    """
    ann.TensorShape(tuple())


def test_tensor_shape_tuple_mess():
    """Tuple elements of mixed types (int, numeric str, float) are accepted.

    Each element is read as a number, so the shape behaves like (1, 2, 3).
    """
    shape = ann.TensorShape((1, "2", 3.0))

    assert shape.GetNumDimensions() == 3
    assert shape.GetNumElements() == 6


def test_tensor_shape_list():
    """Only tuples are accepted; passing a list raises TypeError."""
    with pytest.raises(TypeError) as exc_info:
        ann.TensorShape([1, 2, 3])

    message = str(exc_info.value)
    assert "Argument is not a tuple" in message


def test_tensor_shape_tuple_mess_fail():
    """A non-numeric string element ("two") is rejected with TypeError."""
    with pytest.raises(TypeError) as exc_info:
        ann.TensorShape((1, "two", 3.0))

    message = str(exc_info.value)
    assert "All elements must be numbers" in message


def test_tensor_shape_varags():
    """Dimensions must be packed in one tuple, not passed as varargs."""
    with pytest.raises(TypeError) as exc_info:
        ann.TensorShape(1, 2, 3)

    # Substring match so the qualified-name prefix newer Pythons add
    # ("TensorShape.__init__() ...") still satisfies the check.
    expected = "__init__() takes 2 positional arguments but 4 were given"
    assert expected in str(exc_info.value)


def test_tensor_shape__get_item_out_of_bounds():
    """Reading one index past the last dimension raises ValueError.

    In-range reads are done outside ``pytest.raises`` so that a failure on a
    valid index is reported as such. Otherwise the context manager would
    swallow it, and the in-range values would never be checked.
    """
    tensor_shape = ann.TensorShape((1, 2, 3))

    # Valid indices 0..2 must succeed and return the dimensions passed in.
    assert [tensor_shape[i] for i in range(3)] == [1, 2, 3]

    # Only the single out-of-range access belongs inside the raises block.
    with pytest.raises(ValueError) as err:
        _ = tensor_shape[3]

    assert "Invalid dimension index: 3 (number of dimensions is 3)" in str(err.value)


def test_tensor_shape__set_item_out_of_bounds():
    """Assigning to one index past the last dimension raises ValueError.

    In-range assignments are done (and verified) outside ``pytest.raises``.
    That way only the out-of-range write is expected to fail, and the valid
    writes are actually checked to have taken effect.
    """
    tensor_shape = ann.TensorShape((1, 2, 3))

    # Valid indices 0..2 must accept the assignment and store the new value.
    for i in range(3):
        tensor_shape[i] = 1
    assert [tensor_shape[i] for i in range(3)] == [1, 1, 1]

    # Only the single out-of-range write belongs inside the raises block.
    with pytest.raises(ValueError) as err:
        tensor_shape[3] = 1

    assert "Invalid dimension index: 3 (number of dimensions is 3)" in str(err.value)

def test_tensor_shape___str__():
    """str() renders the dimensions, rank and element count of the shape."""
    expected = "TensorShape{Shape(1, 2, 3), NumDimensions: 3, NumElements: 6}"
    rendered = str(ann.TensorShape((1, 2, 3)))

    assert rendered == expected