From 245d64c60d0ea30f5080ff53225b5169927e24d6 Mon Sep 17 00:00:00 2001 From: Matthew Bentham Date: Mon, 2 Dec 2019 12:59:43 +0000 Subject: Work in progress of python bindings for Arm NN Not built or tested in any way Signed-off-by: Matthew Bentham Change-Id: Ie7f92b529aa5087130f0c5cc8c17db1581373236 --- python/pyarmnn/src/pyarmnn/__init__.py | 138 +++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 python/pyarmnn/src/pyarmnn/__init__.py (limited to 'python/pyarmnn/src/pyarmnn/__init__.py') diff --git a/python/pyarmnn/src/pyarmnn/__init__.py b/python/pyarmnn/src/pyarmnn/__init__.py new file mode 100644 index 0000000000..c451479614 --- /dev/null +++ b/python/pyarmnn/src/pyarmnn/__init__.py @@ -0,0 +1,138 @@ +# Copyright © 2019 Arm Ltd. All rights reserved. +# SPDX-License-Identifier: MIT +import inspect +import sys +import logging + +from ._generated.pyarmnn_version import GetVersion, GetMajorVersion, GetMinorVersion + +# Parsers +try: + from ._generated.pyarmnn_caffeparser import ICaffeParser +except ImportError as err: + logger = logging.getLogger(__name__) + message = "Your ArmNN library instance does not support Caffe models parser functionality. " + logger.warning(message + "Skipped ICaffeParser import.") + logger.debug(str(err)) + + + def ICaffeParser(): + raise RuntimeError(message) + +try: + from ._generated.pyarmnn_onnxparser import IOnnxParser +except ImportError as err: + logger = logging.getLogger(__name__) + message = "Your ArmNN library instance does not support Onnx models parser functionality. " + logger.warning(message + "Skipped IOnnxParser import.") + logger.debug(str(err)) + + + def IOnnxParser(): + raise RuntimeError(message) + +try: + from ._generated.pyarmnn_tfparser import ITfParser +except ImportError as err: + logger = logging.getLogger(__name__) + message = "Your ArmNN library instance does not support TF models parser functionality. " + logger.warning(message + "Skipped ITfParser import.") + logger.debug(str(err)) + + + def ITfParser(): + raise RuntimeError(message) + +try: + from ._generated.pyarmnn_tfliteparser import ITfLiteParser +except ImportError as err: + logger = logging.getLogger(__name__) + message = "Your ArmNN library instance does not support TF lite models parser functionality. " + logger.warning(message + "Skipped ITfLiteParser import.") + logger.debug(str(err)) + + + def ITfLiteParser(): + raise RuntimeError(message) + +# Network +from ._generated.pyarmnn import Optimize, OptimizerOptions, IOptimizedNetwork, IInputSlot, \ + IOutputSlot, IConnectableLayer, INetwork + +# Backend +from ._generated.pyarmnn import BackendId +from ._generated.pyarmnn import IDeviceSpec + +# Tensors +from ._generated.pyarmnn import TensorInfo, TensorShape + +# Runtime +from ._generated.pyarmnn import IRuntime, CreationOptions, INetworkProperties + +# Profiler +from ._generated.pyarmnn import IProfiler + +# Types +from ._generated.pyarmnn import DataType_Float32, DataType_QuantisedAsymm8, DataType_Signed32, \ + DataType_QuantisedSymm16, DataType_Float16, DataType_QuantizedSymm8PerAxis, \ + DataType_QuantisedSymm8, DataType_Boolean +from ._generated.pyarmnn import DataLayout_NCHW, DataLayout_NHWC + +from ._generated.pyarmnn import ActivationFunction_Abs, ActivationFunction_BoundedReLu, ActivationFunction_LeakyReLu, \ + ActivationFunction_Linear, ActivationFunction_ReLu, ActivationFunction_Sigmoid, ActivationFunction_SoftReLu, \ + ActivationFunction_Sqrt, ActivationFunction_Square, ActivationFunction_TanH, ActivationDescriptor +from ._generated.pyarmnn import ArgMinMaxFunction_Max, ArgMinMaxFunction_Min, ArgMinMaxDescriptor +from ._generated.pyarmnn import BatchNormalizationDescriptor, BatchToSpaceNdDescriptor +from ._generated.pyarmnn import ComparisonDescriptor, ComparisonOperation_Equal, ComparisonOperation_Greater, \ + ComparisonOperation_GreaterOrEqual, ComparisonOperation_Less, \ + ComparisonOperation_LessOrEqual, ComparisonOperation_NotEqual +from ._generated.pyarmnn import Convolution2dDescriptor, DepthToSpaceDescriptor, DepthwiseConvolution2dDescriptor, \ + DetectionPostProcessDescriptor, FakeQuantizationDescriptor, FullyConnectedDescriptor, \ + InstanceNormalizationDescriptor, LstmDescriptor, L2NormalizationDescriptor, MeanDescriptor +from ._generated.pyarmnn import NormalizationAlgorithmChannel_Across, NormalizationAlgorithmChannel_Within, \ + NormalizationAlgorithmMethod_LocalBrightness, NormalizationAlgorithmMethod_LocalContrast, NormalizationDescriptor +from ._generated.pyarmnn import PadDescriptor +from ._generated.pyarmnn import PermutationVector, PermuteDescriptor +from ._generated.pyarmnn import OutputShapeRounding_Ceiling, OutputShapeRounding_Floor, \ + PaddingMethod_Exclude, PaddingMethod_IgnoreValue, PoolingAlgorithm_Average, PoolingAlgorithm_L2, \ + PoolingAlgorithm_Max, Pooling2dDescriptor +from ._generated.pyarmnn import ResizeMethod_Bilinear, ResizeMethod_NearestNeighbor, ResizeDescriptor, \ + ReshapeDescriptor, SliceDescriptor, SpaceToBatchNdDescriptor, SpaceToDepthDescriptor, StandInDescriptor, \ + StackDescriptor, StridedSliceDescriptor, SoftmaxDescriptor, TransposeConvolution2dDescriptor, \ + SplitterDescriptor +from ._generated.pyarmnn import ConcatDescriptor, CreateDescriptorForConcatenation + +from ._generated.pyarmnn import LstmInputParams + +# Public API +# Quantization +from ._quantization.quantize_and_dequantize import quantize, dequantize + +# Tensor +from ._tensor.tensor import Tensor +from ._tensor.const_tensor import ConstTensor +from ._tensor.workload_tensors import make_input_tensors, make_output_tensors, workload_tensors_to_ndarray + +# Utilities +from ._utilities.profiling_helper import ProfilerData, get_profiling_data + +from ._version import __version__, __arm_ml_version__ + +ARMNN_VERSION = GetVersion() + + +def __check_version(): + from ._version import check_armnn_version + check_armnn_version(ARMNN_VERSION) + + +__check_version() + +__all__ = [] + +__private_api_names = ['__check_version'] + +for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj) or inspect.isfunction(obj): + if name not in __private_api_names: + __all__.append(name) -- cgit v1.2.1