from typing import Callable
from toolz.curried import curry, take, first
from fklearn.tuning.utils import get_best_performing_log, get_avg_metric_from_extractor, get_used_features
from fklearn.types import ExtractorFnType, ListLogListType
StopFnType = Callable[[ListLogListType], bool]
[docs]def aggregate_stop_funcs(*stop_funcs: StopFnType) -> StopFnType:
"""
Aggregate stop functions
Parameters
----------
stop_funcs: list of function list of dict -> bool
Returns
-------
l: function logs -> bool
Function that performs the Or logic of all stop_fn applied to the
logs
"""
def p(logs: ListLogListType) -> bool:
return any([stop_fn(logs) for stop_fn in stop_funcs])
return p
[docs]@curry
def stop_by_iter_num(logs: ListLogListType,
iter_limit: int = 50) -> bool:
"""
Checks for logs to see if feature selection should stop
Parameters
----------
logs : list of list of dict
A list of log-like lists of dictionaries evaluations.
iter_limit: int (default 50)
Limit of Iterations
Returns
----------
stop: bool
A boolean whether to stop recursion or not
"""
return len(logs) >= iter_limit
[docs]@curry
def stop_by_no_improvement(logs: ListLogListType,
extractor: ExtractorFnType,
metric_name: str,
early_stop: int = 3,
threshold: float = 0.001) -> bool:
"""
Checks for logs to see if feature selection should stop
Parameters
----------
logs : list of list of dict
A list of log-like lists of dictionaries evaluations.
extractor: function str -> float
A extractor that take a string and returns the value of that string on a dict
metric_name: str
String with the name of the column that refers to the metric column to be extracted
early_stop: int (default 3)
Number of iteration without improval before stopping
threshold: float (default 0.001)
Threshold for model performance comparison
Returns
----------
stop: bool
A boolean whether to stop recursion or not
"""
if len(logs) < early_stop:
return False
limited_logs = list(take(early_stop, logs))
curr_auc = get_avg_metric_from_extractor(limited_logs[-1], extractor, metric_name)
return all(
[(curr_auc - get_avg_metric_from_extractor(log, extractor, metric_name)) <= threshold
for log in limited_logs[:-1]]
)
[docs]@curry
def stop_by_no_improvement_parallel(logs: ListLogListType,
extractor: ExtractorFnType,
metric_name: str,
early_stop: int = 3,
threshold: float = 0.001) -> bool:
"""
Checks for logs to see if feature selection should stop
Parameters
----------
logs : list of list of dict
A list of log-like lists of dictionaries evaluations.
extractor: function str -> float
A extractor that take a string and returns the value of that string on a dict
metric_name: str
String with the name of the column that refers to the metric column to be extracted
early_stop: int (default 3)
Number of iterations without improvements before stopping
threshold: float (default 0.001)
Threshold for model performance comparison
Returns
----------
stop: bool
A boolean whether to stop recursion or not
"""
if len(logs) < early_stop:
return False
log_list = [get_best_performing_log(log, extractor, metric_name) for log in logs]
limited_logs = list(take(early_stop, log_list))
curr_auc = get_avg_metric_from_extractor(limited_logs[-1], extractor, metric_name)
return all(
[(curr_auc - get_avg_metric_from_extractor(log, extractor, metric_name)) <= threshold
for log in limited_logs[:-1]])
[docs]@curry
def stop_by_num_features(logs: ListLogListType,
min_num_features: int = 50) -> bool:
"""
Checks for logs to see if feature selection should stop
Parameters
----------
logs : list of list of dict
A list of log-like lists of dictionaries evaluations.
min_num_features: int (default 50)
The minimun number of features the model can have before stopping
Returns
-------
stop: bool
A boolean whether to stop recursion or not
"""
return len(get_used_features(first(logs))) <= min_num_features
[docs]@curry
def stop_by_num_features_parallel(logs: ListLogListType,
extractor: ExtractorFnType,
metric_name: str,
min_num_features: int = 50) -> bool:
"""
Selects the best log out of a list to see if feature selection should stop
Parameters
----------
logs : list of list of list of dict
A list of log-like lists of dictionaries evaluations.
extractor: function str -> float
A extractor that take a string and returns the value of that string on a dict
metric_name: str
String with the name of the column that refers to the metric column to be extracted
min_num_features: int (default 50)
The minimun number of features the model can have before stopping
Returns
----------
stop: bool
A boolean whether to stop recursion or not
"""
best_log = get_best_performing_log(first(logs), extractor, metric_name)
return stop_by_num_features([best_log], min_num_features)