Source code for getml.feature_learning.relmt

# Copyright 2021 The SQLNet Company GmbH

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

"""
Feature learning based on Gradient Boosting.
"""

from dataclasses import dataclass

from typing import ClassVar, List, Union

from .fastprop import FastProp
from .feature_learner import _FeatureLearner
from .loss_functions import SquareLoss
from .validation import _validate_relboost_parameters

# --------------------------------------------------------------------


[docs]@dataclass(repr=False) class RelMT(_FeatureLearner): """Feature learning based on relational linear model trees. :class:`~getml.feature_learning.RelMT` automates feature learning for relational data and time series. It is based on a generalization of linear model trees to relational data, hence the name. A linear model tree is a decision tree with linear models on its leaves. Args: allow_avg (bool, optional): Whether to allow an AVG aggregation. Particularly for time series problems, AVG aggregations are not necessary and you can save some time by taking them out. delta_t (float, optional): Frequency with which lag variables will be explored in a time series setting. When set to 0.0, there will be no lag variables. For more information, please refer to :ref:`data_model_time_series`. Range: [0, :math:`\\infty`] gamma (float, optional): During the training of RelMT, which is based on gradient tree boosting, this value serves as the minimum improvement in terms of the `loss_function` required for a split of the tree to be applied. Larger `gamma` will lead to fewer partitions of the tree and a more conservative algorithm. Range: [0, :math:`\\infty`] loss_function (:class:`~getml.feature_learning.loss_functions`, optional): Objective function used by the feature learning algorithm to optimize your features. For regression problems use :class:`~getml.feature_learning.loss_functions.SquareLoss` and for classification problems use :class:`~getml.feature_learning.loss_functions.CrossEntropyLoss`. max_depth (int, optional): Maximum depth of the trees generated during the gradient tree boosting. Deeper trees will result in more complex models and increase the risk of overfitting. Range: [0, :math:`\\infty`] min_df (int, optional): Only relevant for columns with role :const:`~getml.data.roles.text`. The minimum number of fields (i.e. rows) in :const:`~getml.data.roles.text` column a given word is required to appear in to be included in the bag of words. Range: [1, :math:`\\infty`] min_num_samples (int, optional): Determines the minimum number of samples a subcondition should apply to in order for it to be considered. Higher values lead to less complex statements and less danger of overfitting. Range: [1, :math:`\\infty`] num_features (int, optional): Number of features generated by the feature learning algorithm. Range: [1, :math:`\\infty`] num_subfeatures (int, optional): The number of subfeatures you would like to extract in a subensemble (for snowflake data model only). See :ref:`data_model_snowflake_schema` for more information. Range: [1, :math:`\\infty`] num_threads (int, optional): Number of threads used by the feature learning algorithm. If set to zero or a negative value, the number of threads will be determined automatically by the getML engine. Range: [-:math:`\\infty`, :math:`\\infty`] propositionalization (:class:`~getml.feature_learning.FastProp`, optional): The feature learner used for joins which are flagged to be propositionalized (by setting a join's `relationship` parameter to :const:`getml.data.relationship.propositionalization`) reg_lambda (float, optional): L2 regularization on the weights in the gradient boosting routine. This is one of the most important hyperparameters in the :class:`~getml.feature_learning.RelMT` as it allows for the most direct regularization. Larger values will make the resulting model more conservative. Range: [0, :math:`\\infty`] sampling_factor (float, optional): RelMT uses a bootstrapping procedure (sampling with replacement) to train each of the features. The sampling factor is proportional to the share of the samples randomly drawn from the population table every time RelMT generates a new feature. A lower sampling factor (but still greater than 0.0), will lead to less danger of overfitting, less complex statements and faster training. When set to 1.0, roughly 20,000 samples are drawn from the population table. If the population table contains less than 20,000 samples, it will use standard bagging. When set to 0.0, there will be no sampling at all. Range: [0, :math:`\\infty`] seed (Union[int,None], optional): Seed used for the random number generator that underlies the sampling procedure to make the calculation reproducible. Internally, a `seed` of None will be mapped to 5543. Range: [0, :math:`\\infty`] shrinkage (float, optional): Since RelMT works using a gradient-boosting-like algorithm, `shrinkage` (or learning rate) scales down the weights and thus the impact of each new tree. This gives more room for future ones to improve the overall performance of the model in this greedy algorithm. It must be between 0.0 and 1.0 with higher values leading to more danger of overfitting. Range: [0, 1] silent (bool, optional): Controls the logging during training. vocab_size (int, optional): Determines the maximum number of words that are extracted in total from :const:`getml.data.roles.text` columns. This can be interpreted as the maximum size of the bag of words. Range: [0, :math:`\\infty`] """ # ---------------------------------------------------------------- allow_avg: bool = True delta_t: float = 0.0 gamma: float = 0.0 loss_function: Union[str, None] = None max_depth: int = 2 min_df: int = 30 min_num_samples: int = 1 num_features: int = 30 num_subfeatures: int = 30 num_threads: int = 0 propositionalization: FastProp = FastProp() reg_lambda: float = 0.0 sampling_factor: float = 1.0 seed: int = 5543 shrinkage: float = 0.1 silent: bool = True vocab_size: int = 500 # ----------------------------------------------------------------
[docs] def validate(self, params=None): """Checks both the types and the values of all instance variables and raises an exception if something is off. Args: params (dict, optional): A dictionary containing the parameters to validate. If not is passed, the own parameters will be validated. """ # ------------------------------------------------------------ if params is None: params = self.__dict__ else: params = {**self.__dict__, **params} # ------------------------------------------------------------ if not isinstance(params, dict): raise ValueError("params must be None or a dictionary!") # ------------------------------------------------------------ for kkey in params: if kkey not in type(self)._supported_params: raise KeyError( f"Instance variable '{kkey}' is not supported in {self.type}." ) # ------------------------------------------------------------ if not isinstance(params["silent"], bool): raise TypeError("'silent' must be of type bool") # ------------------------------------------------------------ _validate_relboost_parameters(**params)
# ------------------------------------------------------------