# Source code for mlprimitives.cli

# -*- coding: utf-8 -*-

"""MLPrimitives Command Line Interface module."""

import argparse
import logging
import os
import sys
import warnings

import pandas as pd
from mlblocks import add_primitives_path, get_primitives_paths

from mlprimitives.evaluation import score_pipeline

LOGGER = logging.getLogger(__name__)


def _logging_setup(verbosity=1):
    """Send root logger records to the console at a verbosity-driven level.

    The numeric level is ``(3 - verbosity) * 10``, so ``0`` maps to
    ``WARNING`` (30), ``1`` to ``INFO`` (20) and ``2`` to ``DEBUG`` (10).
    Each call attaches one more ``StreamHandler`` to the root logger.

    Args:
        verbosity (int): How verbose the console output should be.
    """
    level = 10 * (3 - verbosity)

    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
    )
    handler.setLevel(level)

    root = logging.getLogger()
    root.setLevel(level)
    root.propagate = False
    root.addHandler(handler)


def _test(args):
    """Score every pipeline in ``args.pipeline`` and print a summary table.

    Each pipeline is passed to ``score_pipeline`` together with
    ``args.splits``, ``args.random_state`` and ``args.dataset``. The two
    values it returns are reported under the ``mean`` and ``std`` columns.
    If anything raises while scoring or reporting a pipeline, the exception
    is stored in the ``error`` column and the loop moves on to the next one.

    A ``KeyboardInterrupt`` stops the loop early, but the results gathered
    so far are still printed.

    Args:
        args (argparse.Namespace): Parsed arguments exposing ``pipeline``
            (iterable of pipeline paths), ``splits``, ``random_state`` and
            ``dataset``.
    """
    columns = ['pipeline', 'mean', 'std', 'error']

    # Rows are accumulated in a plain list and turned into a DataFrame once
    # at the end: ``DataFrame.append`` was removed in pandas 2.0, and
    # appending row by row copies the whole frame on every iteration.
    rows = []
    try:
        for pipeline in args.pipeline:
            print('Scoring pipeline: {}'.format(pipeline))
            pipeline_name = os.path.basename(pipeline)
            try:
                score, std = score_pipeline(
                    pipeline,
                    args.splits,
                    args.random_state,
                    args.dataset
                )

                print('Obtained Score: {:.4f} +/- {:.4f}'.format(score, std))
                rows.append({
                    'pipeline': pipeline_name,
                    'mean': score,
                    'std': std,
                })

            except Exception as ex:
                # Deliberately broad: one failing pipeline must not prevent
                # the remaining ones from being scored.
                rows.append({
                    'pipeline': pipeline_name,
                    'error': ex,
                })

    except KeyboardInterrupt:
        # Allow the user to stop early and still see the partial results.
        pass

    results = pd.DataFrame(rows, columns=columns)
    print(results.to_string(index=False))


def _get_primitives(pattern):
    """Collect the names of the JSON primitives matching ``pattern``.

    Every existing path returned by ``get_primitives_paths`` is listed, and
    each ``.json`` file whose name contains ``pattern`` is kept, with the
    extension stripped.

    Args:
        pattern (str): Substring that the filename must contain.

    Returns:
        list: Matching primitive names, sorted.
    """
    suffix = '.json'
    names = [
        entry[:-len(suffix)]
        for folder in get_primitives_paths()
        if os.path.exists(folder)
        for entry in os.listdir(folder)
        if pattern in entry and entry.endswith(suffix)
    ]
    names.sort()

    return names


def _list(args):
    """Print, one per line, the primitives whose name contains ``args.pattern``.

    Args:
        args (argparse.Namespace): Parsed arguments exposing ``pattern``.
    """
    matches = _get_primitives(args.pattern)
    print('\n'.join(matches))


class ArgumentParser(argparse.ArgumentParser):
    """``argparse.ArgumentParser`` that prints the full help on usage errors."""

    def error(self, message):
        """Report a usage error, print the help and exit with status 2.

        The error line is written to ``sys.stderr``; the help text is
        printed through ``print_help`` with its default output stream.

        Args:
            message (str): Description of the usage error.
        """
        sys.stderr.write('\nERROR: {}\n\n'.format(message))
        self.print_help()
        sys.exit(2)
def _get_parser():
    """Build the command line parser with its ``test`` and ``list`` actions.

    The selected sub-command stores its handler function in the ``action``
    attribute of the parsed namespace; ``action`` is ``None`` when no
    sub-command is given.

    Returns:
        ArgumentParser: The configured parser.
    """
    parser = ArgumentParser(
        description='MLPrimitives Command Line Interface')

    # Options shared by every action.
    parser.add_argument(
        '-p', '--primitives-path', action='append',
        help=(
            'Path where primitives should be looked for. Use multiple '
            'times in order to add multiple directories'
        )
    )
    parser.add_argument('-v', '--verbose', action='count', default=0)

    subparsers = parser.add_subparsers(title='action', help='Action to perform')
    parser.set_defaults(action=None)

    # "test" action
    subparser = subparsers.add_parser('test', help='Test a single pipeline.')
    subparser.set_defaults(action=_test)
    subparser.add_argument('-s', '--splits', default=1, type=int,
                           help='Number of splits to use for Cross Validation')
    subparser.add_argument('-r', '--random-state', default=0, type=int,
                           help='Random State to use for Cross Validation')
    subparser.add_argument('-d', '--dataset', help='Dataset to validate with.')
    subparser.add_argument('pipeline', nargs='+')

    # "list" action
    subparser = subparsers.add_parser('list', help='List available primitives')
    subparser.set_defaults(action=_list)
    subparser.add_argument('pattern', nargs='?', default='')

    return parser


def _add_primitives_paths(paths):
    """Register each given path through ``add_primitives_path``.

    Args:
        paths (list or None): Paths to register. Nothing is done when this
            is ``None`` or empty.
    """
    if paths:
        for path in paths:
            add_primitives_path(path)


def _process_common_args(args):
    """Apply the options shared by all actions.

    Registers the paths in ``args.primitives_path`` and sets up logging
    using ``args.verbose`` as the verbosity.

    Args:
        args (argparse.Namespace): Parsed command line arguments.
    """
    _add_primitives_paths(args.primitives_path)
    _logging_setup(args.verbose)
def main():
    """Run the MLPrimitives command line interface.

    Parses the command line arguments and calls the function stored in
    ``args.action`` by the selected sub-command. When no sub-command is
    given, the help is printed and the process exits with status 0.
    """
    warnings.filterwarnings('ignore', category=DeprecationWarning)

    parser = _get_parser()
    args = parser.parse_args()

    # No sub-command selected: show the help instead of failing.
    if not args.action:
        parser.print_help()
        sys.exit(0)

    _process_common_args(args)
    args.action(args)