aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/src/pyarmnn/__init__.py
blob: c451479614c64ade6d6a78b71c06dccce0176a8e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright © 2019 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import inspect
import sys
import logging

from ._generated.pyarmnn_version import GetVersion, GetMajorVersion, GetMinorVersion

# Parsers
#
# Each model parser is an optional part of an ArmNN build. When the generated
# binding for a parser cannot be imported, a warning is logged and a stub with
# the same public name is installed instead, so ``import pyarmnn`` still
# succeeds and the failure only surfaces (as RuntimeError) when the stub is
# called.
logger = logging.getLogger(__name__)


def _unsupported_parser_stub(parser_name, message, err):
    """Log a skipped parser import and return a stand-in that raises on use.

    Args:
        parser_name: Public name the stub is installed under; used in the
            warning text and as the stub's ``__name__``/``__qualname__``.
        message: Text of the ``RuntimeError`` the stub raises when called.
        err: The ``ImportError`` caught while importing the parser; its text
            is logged at DEBUG level.

    Returns:
        A function accepting any arguments that always raises
        ``RuntimeError(message)``.
    """
    logger.warning(message + "Skipped " + parser_name + " import.")
    logger.debug(str(err))

    # ``message`` is captured per stub by this closure. Reading a shared
    # module-level variable at call time instead would make every stub report
    # whichever parser failed to import last.
    def stub(*args, **kwargs):
        raise RuntimeError(message)

    stub.__name__ = parser_name
    stub.__qualname__ = parser_name
    return stub


try:
    from ._generated.pyarmnn_caffeparser import ICaffeParser
except ImportError as err:
    ICaffeParser = _unsupported_parser_stub(
        "ICaffeParser",
        "Your ArmNN library instance does not support Caffe models parser functionality. ",
        err)

try:
    from ._generated.pyarmnn_onnxparser import IOnnxParser
except ImportError as err:
    IOnnxParser = _unsupported_parser_stub(
        "IOnnxParser",
        "Your ArmNN library instance does not support Onnx models parser functionality. ",
        err)

try:
    from ._generated.pyarmnn_tfparser import ITfParser
except ImportError as err:
    ITfParser = _unsupported_parser_stub(
        "ITfParser",
        "Your ArmNN library instance does not support TF models parser functionality. ",
        err)

try:
    from ._generated.pyarmnn_tfliteparser import ITfLiteParser
except ImportError as err:
    ITfLiteParser = _unsupported_parser_stub(
        "ITfLiteParser",
        "Your ArmNN library instance does not support TF lite models parser functionality. ",
        err)

# The helper is an implementation detail: drop it so it does not linger in the
# package namespace (the ``__all__`` scan at the bottom of this module would
# otherwise export it).
del _unsupported_parser_stub

# Network
from ._generated.pyarmnn import Optimize, OptimizerOptions, IOptimizedNetwork, IInputSlot, \
    IOutputSlot, IConnectableLayer, INetwork

# Backend
from ._generated.pyarmnn import BackendId
from ._generated.pyarmnn import IDeviceSpec

# Tensors
from ._generated.pyarmnn import TensorInfo, TensorShape

# Runtime
from ._generated.pyarmnn import IRuntime, CreationOptions, INetworkProperties

# Profiler
from ._generated.pyarmnn import IProfiler

# Types
from ._generated.pyarmnn import DataType_Float32, DataType_QuantisedAsymm8, DataType_Signed32, \
    DataType_QuantisedSymm16, DataType_Float16, DataType_QuantizedSymm8PerAxis, \
    DataType_QuantisedSymm8, DataType_Boolean
from ._generated.pyarmnn import DataLayout_NCHW, DataLayout_NHWC

from ._generated.pyarmnn import ActivationFunction_Abs, ActivationFunction_BoundedReLu, ActivationFunction_LeakyReLu, \
    ActivationFunction_Linear, ActivationFunction_ReLu, ActivationFunction_Sigmoid, ActivationFunction_SoftReLu, \
    ActivationFunction_Sqrt, ActivationFunction_Square, ActivationFunction_TanH, ActivationDescriptor
from ._generated.pyarmnn import ArgMinMaxFunction_Max, ArgMinMaxFunction_Min, ArgMinMaxDescriptor
from ._generated.pyarmnn import BatchNormalizationDescriptor, BatchToSpaceNdDescriptor
from ._generated.pyarmnn import ComparisonDescriptor, ComparisonOperation_Equal, ComparisonOperation_Greater, \
    ComparisonOperation_GreaterOrEqual, ComparisonOperation_Less, \
    ComparisonOperation_LessOrEqual, ComparisonOperation_NotEqual
from ._generated.pyarmnn import Convolution2dDescriptor, DepthToSpaceDescriptor, DepthwiseConvolution2dDescriptor, \
    DetectionPostProcessDescriptor, FakeQuantizationDescriptor, FullyConnectedDescriptor, \
    InstanceNormalizationDescriptor, LstmDescriptor, L2NormalizationDescriptor, MeanDescriptor
from ._generated.pyarmnn import NormalizationAlgorithmChannel_Across, NormalizationAlgorithmChannel_Within, \
    NormalizationAlgorithmMethod_LocalBrightness, NormalizationAlgorithmMethod_LocalContrast, NormalizationDescriptor
from ._generated.pyarmnn import PadDescriptor
from ._generated.pyarmnn import PermutationVector, PermuteDescriptor
from ._generated.pyarmnn import OutputShapeRounding_Ceiling, OutputShapeRounding_Floor, \
    PaddingMethod_Exclude, PaddingMethod_IgnoreValue, PoolingAlgorithm_Average, PoolingAlgorithm_L2, \
    PoolingAlgorithm_Max, Pooling2dDescriptor
from ._generated.pyarmnn import ResizeMethod_Bilinear, ResizeMethod_NearestNeighbor, ResizeDescriptor, \
    ReshapeDescriptor, SliceDescriptor, SpaceToBatchNdDescriptor, SpaceToDepthDescriptor, StandInDescriptor, \
    StackDescriptor, StridedSliceDescriptor, SoftmaxDescriptor, TransposeConvolution2dDescriptor, \
    SplitterDescriptor
from ._generated.pyarmnn import ConcatDescriptor, CreateDescriptorForConcatenation

from ._generated.pyarmnn import LstmInputParams

# Public API
# Quantization
from ._quantization.quantize_and_dequantize import quantize, dequantize

# Tensor
from ._tensor.tensor import Tensor
from ._tensor.const_tensor import ConstTensor
from ._tensor.workload_tensors import make_input_tensors, make_output_tensors, workload_tensors_to_ndarray

# Utilities
from ._utilities.profiling_helper import ProfilerData, get_profiling_data

from ._version import __version__, __arm_ml_version__

# Value reported by GetVersion() for the native ArmNN library this package is
# bound to; exposed publicly and used by the import-time check below.
ARMNN_VERSION = GetVersion()


def __check_version():
    """Pass ``ARMNN_VERSION`` to ``_version.check_armnn_version``.

    The import happens inside the function body, so ``check_armnn_version``
    is never bound as a module attribute (and so is not picked up by the
    ``__all__`` scan). What the check does on a mismatch is defined in
    ``pyarmnn._version``.
    """
    from . import _version as version_module

    version_module.check_armnn_version(ARMNN_VERSION)


# Run the version check once, when the package is first imported.
__check_version()

# Helpers that live in this module but must not be exported as public API.
__private_api_names = ['__check_version']

# Public API: every class and plain Python function bound in this module,
# minus the private names above. ``inspect.getmembers`` returns members sorted
# by name, so ``__all__`` is alphabetical.
#
# A comprehension is used rather than a bare module-level ``for`` loop so the
# iteration variables do not leak into the package namespace as
# ``pyarmnn.name`` / ``pyarmnn.obj``.
__all__ = [
    name
    for name, obj in inspect.getmembers(sys.modules[__name__])
    if (inspect.isclass(obj) or inspect.isfunction(obj))
    and name not in __private_api_names
]