Source code for mlprimitives.utils

# -*- coding: utf-8 -*-

import importlib
import logging
import math

import numpy as np

LOGGER = logging.getLogger(__name__)


def import_object(object_name):
    """Import an object from its Fully Qualified Name.

    The longest dotted prefix of ``object_name`` that can be imported as a
    module is imported, and the remaining dotted parts are resolved as
    attributes on it, so names like ``package.module.Class.method`` work at
    any nesting depth. A name without any dot is imported as a module.

    Args:
        object_name (str or object):
            Fully qualified name of the object to import. Anything that is
            not a ``str`` is returned unchanged.

    Returns:
        object:
            The imported object, or ``object_name`` itself if it is not a
            ``str``.

    Raises:
        ImportError:
            If no dotted prefix of ``object_name`` can be imported.
        AttributeError:
            If a module was imported but one of the remaining attribute
            names does not exist on it.
    """
    if not isinstance(object_name, str):
        return object_name

    parts = object_name.split('.')
    if len(parts) == 1:
        # Nothing is left to resolve as an attribute: it must be a module.
        return importlib.import_module(object_name)

    last_error = None
    # Start with the common ``module.attribute`` split and move the split
    # point towards the root until an import succeeds.
    for split_at in range(len(parts) - 1, 0, -1):
        try:
            obj = importlib.import_module('.'.join(parts[:split_at]))
        except ImportError as error:
            # Keep a reference: ``error`` is unbound when the except ends.
            last_error = error
            continue

        for attribute in parts[split_at:]:
            obj = getattr(obj, attribute)

        return obj

    # Every prefix failed to import: surface the real ImportError instead of
    # an unrelated unpacking error.
    raise last_error
def image_transform(X, function, reshape_before=False, reshape_after=False,
                    width=None, height=None, **kwargs):
    """Apply a function image by image.

    Args:
        X (array-like):
            Sequence of images. Each image is either a flat 1d array or an
            already-shaped array.
        function (str or callable):
            Function applied to every image, or its fully qualified name,
            which is imported with ``import_object``.
        reshape_before (bool):
            Whether 1d images need to be reshaped to 2d ``(height, width)``
            arrays before calling ``function``. Ignored when the images are
            not 1d.
        reshape_after (bool):
            Whether the value returned by ``function`` for each image needs
            to be flattened back to a 1d array.
        width (int):
            Image width used to rebuild the 2d images. Required if the image
            is not square.
        height (int):
            Image height used to rebuild the 2d images. Required if the image
            is not square.
        **kwargs:
            Extra keyword arguments passed to ``function``.

    Returns:
        numpy.ndarray:
            The per-image results combined with ``np.array``. An empty
            ``X`` yields an empty array.

    Raises:
        ValueError:
            If ``function`` is neither a ``str`` nor a callable, or if flat
            non-square images must be reshaped without both ``width`` and
            ``height`` being given.
    """
    if isinstance(function, str):
        function = import_object(function)

    # Checked after the import so that a name resolving to a non-callable
    # is rejected too.
    if not callable(function):
        raise ValueError("function must be a str or a callable")

    if len(X) == 0:
        return np.array([])

    flat_image = len(X[0].shape) == 1

    if reshape_before and flat_image:
        if not (width and height):
            side_length = math.sqrt(X[0].shape[0])
            if side_length.is_integer():
                side_length = int(side_length)
                width = side_length
                height = side_length

            else:
                raise ValueError("Image sizes must be given for non-square images")
    else:
        reshape_before = False

    new_X = []
    for image in X:
        if reshape_before:
            # numpy shapes are (rows, columns), i.e. (height, width).
            image = image.reshape((height, width))

        features = function(image, **kwargs)

        if reshape_after:
            # Flatten whatever came back, even if its size differs from the
            # input image size.
            features = np.reshape(features, -1)

        new_X.append(features)

    return np.array(new_X)
# Named numpy reductions, selected by ``np_aggregate`` when ``skipna`` is false.
NUMPY_AGGREGATIONS = dict(
    min=np.min,
    max=np.max,
    sum=np.sum,
    prod=np.prod,
    mean=np.mean,
    median=np.median,
    std=np.std,
    var=np.var,
)

# NaN-ignoring counterparts: every name maps to ``np.nan<name>``
# (``np.nanmin``, ``np.nanmean``, ...), keeping the same keys and order.
NUMPY_NAN_AGGREGATIONS = {
    name: getattr(np, 'nan' + name)
    for name in NUMPY_AGGREGATIONS
}
def np_aggregate(array, aggregation, skipna=True, *args, **kwargs):
    """Reduce ``array`` with the numpy aggregation called ``aggregation``.

    Args:
        array (array-like):
            Values to aggregate.
        aggregation (str):
            Name of the aggregation, looked up in ``NUMPY_NAN_AGGREGATIONS``
            or ``NUMPY_AGGREGATIONS``.
        skipna (bool):
            If truthy, the lookup uses ``NUMPY_NAN_AGGREGATIONS``; otherwise
            ``NUMPY_AGGREGATIONS``. Because ``*args`` follows this parameter,
            a third positional argument binds here, so pass numpy options by
            keyword.
        *args, **kwargs:
            Forwarded unchanged to the selected numpy function.

    Returns:
        Whatever the selected numpy function returns for ``array``.

    Raises:
        ValueError:
            If ``aggregation`` is not a known aggregation name.
    """
    if skipna:
        table = NUMPY_NAN_AGGREGATIONS
    else:
        table = NUMPY_AGGREGATIONS

    reducer = table.get(aggregation)
    if reducer is not None:
        return reducer(array, *args, **kwargs)

    raise ValueError('Unknown aggregation: {}'.format(aggregation))