diff options
author | Richard Burton <richard.burton@arm.com> | 2020-04-08 16:39:05 +0100 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-04-10 16:11:09 +0000 |
commit | dc0c6ed9f8b993e63f492f203d7d7080ab4c835c (patch) | |
tree | ea8541990b13ebf1a038009aa6b8b4b1ea8c3f55 /python/pyarmnn/test/test_iconnectable.py | |
parent | fe5a24beeef6e9a41366e694f41093565e748048 (diff) | |
download | armnn-dc0c6ed9f8b993e63f492f203d7d7080ab4c835c.tar.gz |
Add PyArmNN to work with ArmNN API of 20.02
* Add Swig rules for generating python wrapper
* Add documentation
* Add tests and testing data
Change-Id: If48eda08931514fa21e72214dfead2835f07237c
Signed-off-by: Richard Burton <richard.burton@arm.com>
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Diffstat (limited to 'python/pyarmnn/test/test_iconnectable.py')
-rw-r--r-- | python/pyarmnn/test/test_iconnectable.py | 142 |
1 files changed, 142 insertions, 0 deletions
diff --git a/python/pyarmnn/test/test_iconnectable.py b/python/pyarmnn/test/test_iconnectable.py new file mode 100644 index 0000000000..0d15be5e73 --- /dev/null +++ b/python/pyarmnn/test/test_iconnectable.py @@ -0,0 +1,142 @@ +# Copyright © 2020 Arm Ltd. All rights reserved. +# SPDX-License-Identifier: MIT +import pytest + +import pyarmnn as ann + + +@pytest.fixture(scope="function") +def network(): + return ann.INetwork() + + +class TestIInputIOutputIConnectable: + + def test_input_slot(self, network): + # Create input, addition & output layer + input1 = network.AddInputLayer(0, "input1") + input2 = network.AddInputLayer(1, "input2") + add = network.AddAdditionLayer("addition") + output = network.AddOutputLayer(0, "output") + + # Connect the input/output slots for each layer + input1.GetOutputSlot(0).Connect(add.GetInputSlot(0)) + input2.GetOutputSlot(0).Connect(add.GetInputSlot(1)) + add.GetOutputSlot(0).Connect(output.GetInputSlot(0)) + + # Check IInputSlot GetConnection() + input_slot = add.GetInputSlot(0) + input_slot_connection = input_slot.GetConnection() + + assert isinstance(input_slot_connection, ann.IOutputSlot) + + del input_slot_connection + + assert input_slot.GetConnection() + assert isinstance(input_slot.GetConnection(), ann.IOutputSlot) + + del input_slot + + assert add.GetInputSlot(0) + + def test_output_slot(self, network): + + # Create input, addition & output layer + input1 = network.AddInputLayer(0, "input1") + input2 = network.AddInputLayer(1, "input2") + add = network.AddAdditionLayer("addition") + output = network.AddOutputLayer(0, "output") + + # Connect the input/output slots for each layer + input1.GetOutputSlot(0).Connect(add.GetInputSlot(0)) + input2.GetOutputSlot(0).Connect(add.GetInputSlot(1)) + add.GetOutputSlot(0).Connect(output.GetInputSlot(0)) + + # Check IInputSlot GetConnection() + add_get_input_connection = add.GetInputSlot(0).GetConnection() + output_get_input_connection = output.GetInputSlot(0).GetConnection() + + # Check IOutputSlot GetConnection() + add_get_output_connect = add.GetOutputSlot(0).GetConnection(0) + assert isinstance(add_get_output_connect.GetConnection(), ann.IOutputSlot) + + # Test IOutputSlot GetNumConnections() & CalculateIndexOnOwner() + assert add_get_input_connection.GetNumConnections() == 1 + assert len(add_get_input_connection) == 1 + assert add_get_input_connection[0] + assert add_get_input_connection.CalculateIndexOnOwner() == 0 + + # Check GetOwningLayerGuid(). Check that it is different for add and output layer + assert add_get_input_connection.GetOwningLayerGuid() != output_get_input_connection.GetOwningLayerGuid() + + # Set TensorInfo + test_tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32) + + # Check IsTensorInfoSet() + assert not add_get_input_connection.IsTensorInfoSet() + add_get_input_connection.SetTensorInfo(test_tensor_info) + assert add_get_input_connection.IsTensorInfoSet() + + # Check GetTensorInfo() + output_tensor_info = add_get_input_connection.GetTensorInfo() + assert 2 == output_tensor_info.GetNumDimensions() + assert 6 == output_tensor_info.GetNumElements() + + # Check Disconnect() + assert output_get_input_connection.GetNumConnections() == 1 # 1 connection to Outputslot0 from input1 + add.GetOutputSlot(0).Disconnect(output.GetInputSlot(0)) # disconnect add.OutputSlot0 from Output.InputSlot0 + assert output_get_input_connection.GetNumConnections() == 0 + + def test_output_slot__out_of_range(self, network): + # Create input layer to check output slot get item handling + input1 = network.AddInputLayer(0, "input1") + + outputSlot = input1.GetOutputSlot(0) + with pytest.raises(ValueError) as err: + outputSlot[1] + + assert "Invalid index 1 provided" in str(err.value) + + def test_iconnectable_guid(self, network): + + # Check IConnectable GetGuid() + # Note Guid can change based on which tests are run so + # checking here that each layer does not have the same guid + add_id = network.AddAdditionLayer().GetGuid() + output_id = network.AddOutputLayer(0).GetGuid() + assert add_id != output_id + + def test_iconnectable_layer_functions(self, network): + + # Create input, addition & output layer + input1 = network.AddInputLayer(0, "input1") + input2 = network.AddInputLayer(1, "input2") + add = network.AddAdditionLayer("addition") + output = network.AddOutputLayer(0, "output") + + # Check GetNumInputSlots(), GetName() & GetNumOutputSlots() + assert input1.GetNumInputSlots() == 0 + assert input1.GetName() == "input1" + assert input1.GetNumOutputSlots() == 1 + + assert input2.GetNumInputSlots() == 0 + assert input2.GetName() == "input2" + assert input2.GetNumOutputSlots() == 1 + + assert add.GetNumInputSlots() == 2 + assert add.GetName() == "addition" + assert add.GetNumOutputSlots() == 1 + + assert output.GetNumInputSlots() == 1 + assert output.GetName() == "output" + assert output.GetNumOutputSlots() == 0 + + # Check GetOutputSlot() + input1_get_output = input1.GetOutputSlot(0) + assert input1_get_output.GetNumConnections() == 0 + assert len(input1_get_output) == 0 + + # Check GetInputSlot() + add_get_input = add.GetInputSlot(0) + add_get_input.GetConnection() + assert isinstance(add_get_input, ann.IInputSlot) |