Source code for mlprimitives.adapters.networkx

# -*- coding: utf-8 -*-

import logging

import numpy as np
import pandas as pd

from mlprimitives.utils import import_object

LOGGER = logging.getLogger(__name__)


def graph_pairs_feature_extraction(X, functions, node_columns, graph=None):
    """Add one feature column per function, computed over pairs of nodes.

    Every function is called as ``function(graph, pairs)``, where ``pairs``
    is the 2D array of values found in the ``node_columns`` of ``X``. It must
    return an iterable of items whose third element (index 2) is the score of
    the corresponding pair (presumably ``(u, v, score)`` tuples such as the
    ones yielded by networkx link prediction functions -- confirm with the
    callers).

    Args:
        X (pandas.DataFrame):
            Table containing the node columns. It is not modified.
        functions (list):
            Callables, or names that ``import_object`` can resolve into
            callables.
        node_columns (list):
            Names of the columns of ``X`` that hold the nodes of each pair.
            The first two names are used to build the new column names,
            ``<function name>_<first column>_<second column>``.
        graph (object):
            Passed unchanged as the first argument of every function.
            Defaults to ``None``.

    Returns:
        pandas.DataFrame:
            A copy of ``X`` with one additional column per function.
    """
    functions = [
        function if callable(function) else import_object(function)
        for function in functions
    ]

    X = X.copy()
    pairs = X[node_columns].values

    for function in functions:
        try:
            # Keep only the score of each item. Building the array from the
            # whole tuples would coerce the scores to text whenever the node
            # ids are strings, and would fail on an empty list of pairs.
            scores = np.array([item[2] for item in function(graph, pairs)])
        except ZeroDivisionError:
            # Best effort: a function that cannot be computed on this graph
            # contributes an all-zeros feature instead of aborting.
            LOGGER.warning("ZeroDivisionError captured running %s", function)
            scores = np.zeros(len(pairs))

        name = '{}_{}_{}'.format(function.__name__, *node_columns)
        X[name] = scores

    return X


def graph_feature_extraction(X, functions, graphs):
    """Add per-node features and node attributes from one or more graphs.

    For every ``(node_column, graph)`` entry in ``graphs``, each function is
    called as ``function(graph)`` and must return a mapping from node to
    value. The values are left-merged into ``X`` on ``node_column`` as a
    column named ``<function name>_<node_column>``. Afterwards the attributes
    stored in ``graph.nodes`` are left-merged into ``X`` as well, one column
    per attribute.

    Args:
        X (pandas.DataFrame):
            Table containing the node columns. It is not modified; merged
            copies are returned instead.
        functions (list):
            Callables, or names that ``import_object`` can resolve into
            callables.
        graphs (dict):
            Maps the name of a column of ``X`` to the graph whose nodes that
            column refers to. Graphs must expose a ``nodes`` mapping of node
            to attribute dict (presumably networkx graphs -- confirm with
            the callers).

    Returns:
        pandas.DataFrame:
            ``X`` with the feature and attribute columns appended.

    Raises:
        KeyError:
            If a function returns a mapping that lacks one of the graph nodes.
    """
    functions = [
        function if callable(function) else import_object(function)
        for function in functions
    ]

    for node_column, graph in graphs.items():
        nodes = list(graph.nodes)

        # The graph node ids are cast to the type found in X so that both
        # sides of the merge are comparable. An empty X has no value to
        # inspect, so fall back to the dtype of the column.
        column = X[node_column]
        index_type = type(column.values[0]) if len(column) else column.dtype

        features = pd.DataFrame(index=nodes)
        features.index = features.index.astype(index_type)

        for function in functions:
            values = function(graph)
            name = '{}_{}'.format(function.__name__, node_column)
            # Look every node up by key instead of trusting that the mapping
            # is ordered like graph.nodes, which would silently misalign.
            features[name] = [values[node] for node in nodes]

        X = X.merge(features, left_on=node_column, right_index=True, how='left')

        # One row per node, one column per node attribute.
        graph_data = pd.DataFrame(dict(graph.nodes.items())).T
        graph_data.index = graph_data.index.astype(index_type)

        X = X.merge(graph_data, left_on=node_column, right_index=True, how='left')

    return X