# Copyright 2021 The SQLNet Company GmbH
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
"""
Custom class for handling the SQL code of the features.
"""
import os
import re
from pathlib import Path
import shutil
from getml.data.helpers import _is_typed_list
class SQLCode:
    """
    Custom class for handling the SQL code of the
    features generated by the pipeline.

    Example:
        .. code-block:: python

            sql_code = my_pipeline.features.to_sql()

            # You can access individual features
            # by index (or a range of them by slice).
            feature_1_1 = sql_code[0]

            # You can also access them by name.
            feature_1_10 = sql_code["FEATURE_1_10"]

            # You can also type the name of
            # a table or column to find all
            # features related to that table
            # or column.
            features = sql_code.find("SOME_TABLE")

            # HINT: The generated SQL code always
            # escapes table and column names using
            # quotation marks. So if you want exact
            # matching, you can do this:
            features = sql_code.find('"SOME_TABLE"')
    """

    # Extracts the table name from the DROP statement when splitting the code
    # into files. `[^"]+` stops at the closing quote, so additional quoted
    # text later on the same line is not swallowed into the name.
    _TABLE_NAME_PATTERN = re.compile(r'DROP TABLE IF EXISTS "([^"]+)"')

    def __init__(self, code):
        """
        Args:
            code (List[str]): The SQL code, one string per feature.

        Raises:
            TypeError: If `code` is not a list of str.
        """
        if not _is_typed_list(code, str):
            raise TypeError("'code' must be a list of str.")
        self.code = code

    def __getitem__(self, key):
        """
        Returns a new SQLCode containing the selected feature(s).

        Args:
            key (int, slice or str): A position, a slice of positions, or
                the name of a feature such as "FEATURE_1_10"
                (case-insensitive).

        Returns:
            SQLCode: The selected feature(s). If `key` is a str that does
            not start with "feature_" (case-insensitive), an empty SQLCode
            is returned.
        """
        if isinstance(key, slice):
            # Slicing already yields a list; wrapping it in another list (as
            # done for a single index below) would fail the type check.
            return SQLCode(self.code[key])
        if not isinstance(key, str):
            return SQLCode([self.code[key]])
        if not key.lower().startswith("feature_"):
            return SQLCode([])
        # The trailing quote keeps "FEATURE_1_1" from also matching
        # "FEATURE_1_10".
        return self.find('CREATE TABLE "' + key.upper() + '"')

    def __len__(self):
        return len(self.code)

    def __repr__(self):
        return "\n\n\n".join(self.code)

    def _repr_markdown_(self):
        # Rendered by Jupyter as a fenced SQL code block.
        return "```sql\n" + self.__repr__() + "\n```"

    def find(self, keyword):
        """
        Returns the SQLCode for all features
        containing the keyword.

        Args:
            keyword (str): The keyword to be found (plain substring match,
                case-sensitive).

        Raises:
            TypeError: If `keyword` is not a str.
        """
        if not isinstance(keyword, str):
            raise TypeError("'keyword' must be a str.")
        return SQLCode([elem for elem in self.code if keyword in elem])

    def save(self, fname, split=False):
        """
        Saves the SQL code to a file.

        Args:
            fname (str): The name of the file or folder (if `split` is True)
                in which you want to save the features.

            split (bool): If True, the code will be split into multiple
                files, one for each feature, and saved into a folder `fname`.
                Any existing folder `fname` is removed first. Each file is
                named `<index>_<table name>.sql`. The table name is taken
                from the feature's `DROP TABLE IF EXISTS "..."` statement;
                features without such a statement are saved as
                `<index>_feature.sql`.
        """
        if not split:
            with open(fname, "w", encoding="utf-8") as sql_file:
                sql_file.write(self.__repr__())
            return

        # Start from a clean folder so that files left over from a previous
        # save with more features do not linger.
        if os.path.exists(fname):
            shutil.rmtree(fname, ignore_errors=True)

        folder = Path(fname)
        folder.mkdir(parents=True, exist_ok=True)

        for index, code in enumerate(self.code, 1):
            match = self._TABLE_NAME_PATTERN.search(code)
            name = match.group(1).lower() if match else "feature"
            file_path = folder / f"{index:04d}_{name}.sql"
            with open(file_path, "w", encoding="utf-8") as sql_file:
                sql_file.write(code)

    def to_str(self):
        """
        Returns a raw string representation of the SQL code.
        """
        return str(self)