Source code for autobazaar.utils

# -*- coding: utf-8 -*-

import os
import tempfile
from collections import defaultdict
from datetime import datetime

import numpy as np
from sklearn.preprocessing import LabelEncoder


def encode_score(scorer, expected, observed):
    """Score predictions, label-encoding them first if the targets are objects.

    When ``expected`` has ``object`` dtype, both ``expected`` and ``observed``
    are converted to integer codes with a single ``LabelEncoder`` before being
    handed to ``scorer``. Otherwise both are passed through unchanged.

    Args:
        scorer (callable): Called as ``scorer(expected, observed)``.
        expected: Ground truth values. Must expose a ``dtype`` attribute.
        observed: Predicted values.

    Returns:
        Whatever ``scorer`` returns.
    """
    if expected.dtype == 'object':
        le = LabelEncoder()

        # Fit on the labels of BOTH arrays. Fitting on ``expected`` alone makes
        # ``transform(observed)`` raise ValueError as soon as a prediction
        # contains a label that is absent from the ground truth. When every
        # observed label also appears in ``expected`` the resulting encoding is
        # exactly the one obtained by fitting on ``expected`` only.
        le.fit(np.concatenate((np.ravel(expected), np.ravel(observed))))

        expected = le.transform(expected)
        observed = le.transform(observed)

    return scorer(expected, observed)
def ensure_dir(directory):
    """Create directory (and any missing parents) if it does not exist yet.

    Nothing happens if the path already exists.

    Args:
        directory (str): Path of the directory to create.

    Raises:
        OSError: If the directory could not be created and the path still
            does not exist afterwards.
    """
    # Attempt the creation first instead of checking beforehand: a separate
    # ``os.path.exists`` check leaves a window in which another process can
    # create the directory, making ``os.makedirs`` fail.
    try:
        os.makedirs(directory)
    except OSError:
        # An already existing path is the one failure that is tolerated.
        if not os.path.exists(directory):
            raise
def _dumpable_value(value, datetimes):
    """Convert a single param value into its JSON friendly counterpart."""
    if isinstance(value, np.integer):
        return int(value)

    if isinstance(value, np.floating):
        return float(value)

    if isinstance(value, np.ndarray):
        return value.tolist()

    if isinstance(value, np.bool_):
        return bool(value)

    # Work around https://github.com/HDI-Project/BTB/issues/79
    if value == 'None':
        return None

    if datetimes and isinstance(value, datetime):
        return value.isoformat()

    return value


def make_dumpable(params, datetimes=False):
    """Turn a flat params dict into nested dicts that can be dumped as JSON.

    Each ``(block, param)`` key becomes ``result[block][param]``. Along the way
    numpy integers, floats, booleans and arrays are converted to the equivalent
    Python primitives or lists, and the ``'None'`` string is replaced by
    ``None`` (see https://github.com/HDI-Project/BTB/issues/79).
    Optionally, datetimes are dumped in ISO format.

    Args:
        params (dict): Params dictionary with ``(block, param)`` tuples as keys.
        datetimes (bool): Whether to convert datetimes to ISO strings or not.

    Returns:
        dict: Dumpable params as a tree of dicts and nested sub-dicts.
    """
    tree = defaultdict(dict)
    for key, raw_value in params.items():
        block_name, param_name = key
        tree[block_name][param_name] = _dumpable_value(raw_value, datetimes)

    return tree
def _walk(document, transform):
    """Rebuild a document applying ``transform`` to every dict entry in it.

    Dicts are rebuilt entry by entry, children first, so ``transform`` always
    receives an already processed value. Lists are rebuilt element by element
    at any nesting depth, so dicts stored inside lists of lists are processed
    too. Any other value is returned as is.

    Args:
        document: The dict, list or plain value to process.
        transform (callable): Called as ``transform(key, value)`` for each dict
            entry. Must return the ``(new_key, new_value)`` pair to store.

    Returns:
        A new dict or list when ``document`` is a dict or a list, otherwise
        ``document`` itself.
    """
    if isinstance(document, list):
        # Recursing here, rather than only one level down from the enclosing
        # dict, is what makes lists nested within lists reachable.
        return [_walk(item, transform) for item in document]

    if not isinstance(document, dict):
        return document

    new_doc = dict()
    for key, value in document.items():
        new_key, new_value = transform(key, _walk(value, transform))
        new_doc[new_key] = new_value

    return new_doc
def remove_dots(document):
    """Replace every dot with a dash in the keys of the given dictionary.

    The document is traversed with ``_walk``, so the replacement is applied
    to each key that ``_walk`` visits. Values are left as they are.
    """
    def _dots_to_dashes(key, value):
        return key.replace('.', '-'), value

    return _walk(document, _dots_to_dashes)
def restore_dots(document):
    """Replace every dash with a dot in the keys of the given dictionary.

    The document is traversed with ``_walk``, so the replacement is applied
    to each key that ``_walk`` visits. Values are left as they are. Every
    dash is replaced, not only the ones that originally were dots.
    """
    def _dashes_to_dots(key, value):
        return key.replace('-', '.'), value

    return _walk(document, _dashes_to_dots)
def make_keras_picklable():
    """Make the keras models picklable.

    Installs ``__getstate__`` and ``__setstate__`` on ``keras.models.Model``.
    The state is a dict with a single ``model_str`` entry holding the contents
    of the ``.hdf5`` file written by ``keras.models.save_model``; restoring
    writes those contents back to a temporary file and loads the model from it.
    """
    import keras.models  # noqa: lazy import slow dependencies

    def __getstate__(self):
        model_str = ""
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            # keras writes to the file by name; ``fd`` only reserves the path.
            # NOTE(review): reopening a NamedTemporaryFile by name while it is
            # still open is platform dependent -- confirm for Windows.
            keras.models.save_model(self, fd.name, overwrite=True)

            # Nothing was written through ``fd`` itself, so this reads the file
            # from its start.
            model_str = fd.read()

        return {'model_str': model_str}

    def __setstate__(self, state):
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])

            # Flush before keras reopens the file by name.
            fd.flush()

            try:
                model = keras.models.load_model(fd.name)
            except ValueError:
                # Retry with the mobilenet ``relu6`` and ``DepthwiseConv2D``
                # objects in scope. Presumably the ValueError comes from keras
                # not recognizing them by name -- TODO confirm.
                from keras.applications import mobilenet
                from keras.utils.generic_utils import CustomObjectScope
                scope = {
                    'relu6': mobilenet.relu6,
                    'DepthwiseConv2D': mobilenet.DepthwiseConv2D
                }
                with CustomObjectScope(scope):
                    model = keras.models.load_model(fd.name)

        # Adopt the attributes of the freshly loaded model.
        self.__dict__ = model.__dict__

    # Patch the class itself, so every Model instance gets both methods.
    cls = keras.models.Model
    cls.__getstate__ = __getstate__
    cls.__setstate__ = __setstate__