Source code for fklearn.training.pipeline

from collections import defaultdict
from inspect import Parameter, signature
from typing import Dict

import pandas as pd
import toolz as fp

from fklearn.types import LearnerFnType, LearnerReturnType, PredictFnType


[docs]def build_pipeline(*learners: LearnerFnType, has_repeated_learners: bool = False) -> LearnerFnType: """ Builds a pipeline of different chained learners functions with the possibility of using keyword arguments in the predict functions of the pipeline. Say you have two learners, you create a pipeline with `pipeline = build_pipeline(learner1, learner2)`. Those learners must be functions with just one unfilled argument (the dataset itself). Then, you train the pipeline with `predict_fn, transformed_df, logs = pipeline(df)`, which will be like applying the learners in the following order: `learner2(learner1(df))`. Finally, you predict on different datasets with `pred_df = predict_fn(new_df)`, with optional kwargs. For example, if you have XGBoost or LightGBM, you can get SHAP values with `predict_fn(new_df, apply_shap=True)`. Parameters ---------- learners : partially-applied learner functions. has_repeated_learners : bool Boolean value indicating wheter the pipeline contains learners with the same name or not. Returns ---------- p : function pandas.DataFrame, **kwargs -> pandas.DataFrame A function that when applied to a DataFrame will apply all learner functions in sequence, with optional kwargs. new_df : pandas.DataFrame A DataFrame that is the result of applying all learner function in sequence. log : dict A log-like Dict that stores information of all learner functions. """ def _has_one_unfilled_arg(learner: LearnerFnType) -> None: no_default_list = [p for p, a in signature(learner).parameters.items() if a.default == '__no__default__'] if len(no_default_list) > 1: raise ValueError("Learner {0} has more than one unfilled argument: {1}\n" "Make sure all learners are curried properly and only require one argument," " which is the dataset (usually `df`)." .format(learner.__name__, ', '.join(no_default_list))) def _no_variable_args(learner: LearnerFnType, predict_fn: PredictFnType) -> None: invalid_parameter_kinds = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) var_args = [p for p, a in signature(predict_fn).parameters.items() if a.kind in invalid_parameter_kinds] if len(var_args) != 0: raise ValueError("Predict function of learner {0} contains variable length arguments: {1}\n" "Make sure no predict function uses arguments like *args or **kwargs." .format(learner.__name__, ', '.join(var_args))) # Check for unfilled arguments of learners for learner in learners: _has_one_unfilled_arg(learner) def pipeline(data: pd.DataFrame) -> LearnerReturnType: current_data = data.copy() features = list(data.columns) fns = [] logs = [] pipeline = [] serialisation = defaultdict(list) # type: dict for learner in learners: learner_fn, new_data, learner_log = learner(current_data) # Check for invalid predict fn arguments _no_variable_args(learner, learner_fn) learner_name = learner.__name__ fns.append(learner_fn) pipeline.append(learner_name) current_data = new_data model_objects = {} if learner_log.get("obj"): model_objects["obj"] = learner_log.pop("obj") serialisation[learner_name].append({"fn": learner_fn, "log": learner_log, **model_objects}) logs.append(learner_log) merged_logs = fp.merge(logs) def predict_fn(df: pd.DataFrame, **kwargs: Dict) -> pd.DataFrame: # Get the proper arguments for each predict function (based on their signature) fns_args = [{k: v for k, v in kwargs.items() if k in signature(f).parameters} for f in fns] # Partially apply the arguments to the predict functions when applicable fns_with_args = [fp.curry(fn)(**args) if len(args) > 0 else fn for fn, args in zip(fns, fns_args)] return fp.pipe(df, *fns_with_args) serialisation_logs = {k: v if has_repeated_learners else v[-1] for k, v in serialisation.items()} merged_logs["__fkml__"] = {"pipeline": pipeline, "output_columns": list(current_data.columns), "features": features, "learners": {**serialisation_logs}} return predict_fn, current_data, merged_logs return pipeline