# Source code for getml.data.split.time

# Copyright 2021 The SQLNet Company GmbH

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

"""
Splits data into subsets at predefined points in time.
"""

import numbers

import numpy as np

from getml.data.columns import FloatColumn, FloatColumnView, StringColumnView
from getml.data.columns.from_value import from_value
from getml.data.data_frame import DataFrame
from getml.data.helpers import _is_typed_list
from getml.data.view import View


def time(population, time_stamp, validation=None, test=None, **kwargs):
    """
    Returns a :class:`~getml.data.columns.StringColumnView` that can be used
    to divide data into training, testing, validation or other sets.

    The arguments are :code:`key=value` pairs of names (:code:`key`) and
    starting points (:code:`value`). The starting point defines the left
    endpoint of the subset. Intervals are left closed and right open, such
    that :math:`[value, next value)`. The (unnamed) subset left from the
    first named starting point, i.e. :math:`[0, first value)`, is always
    considered to be the training set.

    Args:
        population (:class:`~getml.DataFrame` or :class:`~getml.data.View`):
            The population table you would like to split.

        time_stamp (str, FloatColumn or FloatColumnView):
            The name of the time stamp column in the population table
            you want to use. Ideally, the role of said column would be
            :const:`~getml.data.roles.time_stamp`. If you want to split
            on the rowid, then pass "rowid" to `time_stamp`.
            A float column (view) may be passed directly instead of a
            name, in which case it is used as is.

        validation (float, optional):
            The start date of the validation set.

        test (float, optional):
            The start date of the test set.

        kwargs (float, optional):
            Any other sets you would like to assign.
            You can name these sets whatever you want to (in our example,
            we called it 'other').

    Returns:
        :class:`~getml.data.columns.StringColumnView`:
            A column that is "train" by default and carries the name of
            a set for every row whose time stamp is greater than or equal
            to that set's starting point (the latest such starting point
            wins).

    Raises:
        ValueError:
            If `population` is not a DataFrame or View, if `time_stamp`
            is not a string or float column, if no starting point was
            supplied at all, or if any starting point is not a real
            number.

    Example:
        .. code-block:: python

            validation_begin = getml.data.time.datetime(2010, 1, 1)
            test_begin = getml.data.time.datetime(2011, 1, 1)
            other_begin = getml.data.time.datetime(2012, 1, 1)

            split = getml.data.split.time(
                population=data_frame,
                time_stamp="ds",
                test=test_begin,
                validation=validation_begin,
                other=other_begin
            )

            # Contains all data before 2010-01-01 (not included)
            train_set = data_frame[split=='train']

            # Contains all data between 2010-01-01 (included) and
            # 2011-01-01 (not included)
            validation_set = data_frame[split=='validation']

            # Contains all data between 2011-01-01 (included) and
            # 2012-01-01 (not included)
            test_set = data_frame[split=='test']

            # Contains all data after 2012-01-01 (included)
            other_set = data_frame[split=='other']
    """
    if not isinstance(population, (DataFrame, View)):
        raise ValueError("'population' must be a DataFrame or a View.")

    if not isinstance(time_stamp, (str, FloatColumn, FloatColumnView)):
        raise ValueError(
            "'time_stamp' must be a string, a FloatColumn, or a FloatColumnView."
        )

    # Collect every starting point that was actually supplied. Filtering on
    # `is not None` (rather than truthiness) keeps a legitimate starting
    # point of 0 from being mistaken for "not given".
    sets = {
        name: value
        for name, value in (("test", test), ("validation", validation))
        if value is not None
    }
    sets.update(kwargs)

    if not sets:
        raise ValueError("You have to supply at least one starting point.")

    # Validate before sorting: values that cannot be ordered (e.g. None
    # mixed with floats) would otherwise surface as an unrelated error from
    # the sort instead of the ValueError below.
    starting_points = list(sets.values())

    if not _is_typed_list(starting_points, numbers.Real):
        raise ValueError("All values must be real numbers.")

    # A stable sort keeps sets with identical starting points in the order
    # they were passed, so the outcome for ties is deterministic.
    values = np.asarray(starting_points)
    index = np.argsort(values, kind="stable")
    values = values[index]
    names = np.asarray(list(sets))[index]

    if isinstance(time_stamp, str):
        time_stamp_col = (
            population[time_stamp] if time_stamp != "rowid" else population.rowid
        )
    else:
        time_stamp_col = time_stamp

    col = from_value("train")

    assert isinstance(col, StringColumnView), "Should be a StringColumnView"

    # Apply the starting points in ascending order: every later (larger)
    # starting point overwrites the label of the rows at or beyond it.
    for name, value in zip(names, values):
        col = col.update(time_stamp_col >= value, name)

    return col