diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/pruning.py')
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/pruning.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py index f629ba1..f1e2976 100644 --- a/src/mlia/nn/tensorflow/optimizations/pruning.py +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -29,8 +29,8 @@ class PruningConfiguration(OptimizerConfiguration): optimization_target: float layers_to_optimize: Optional[List[str]] = None - x_train: Optional[np.array] = None - y_train: Optional[np.array] = None + x_train: Optional[np.ndarray] = None + y_train: Optional[np.ndarray] = None batch_size: int = 1 num_epochs: int = 1 @@ -73,7 +73,7 @@ class Pruner(Optimizer): """Return string representation of the optimization config.""" return str(self.optimizer_configuration) - def _mock_train_data(self) -> Tuple[np.array, np.array]: + def _mock_train_data(self) -> Tuple[np.ndarray, np.ndarray]: # get rid of the batch_size dimension in input and output shape input_shape = tuple(x for x in self.model.input_shape if x is not None) output_shape = tuple(x for x in self.model.output_shape if x is not None) |