Module nowcastlib.pipeline.utils
Shared functionality across the Nowcast Library Pipeline submodule
Expand source code
"""Shared functionality across the Nowcast Library Pipeline submodule"""
from typing import Union, Type, Any
import pandas as pd
import numpy as np
import attr
import cattr
from nowcastlib.pipeline.structs import config
cattr_cnvrtr = cattr.GenConverter(forbid_extra_keys=True)
def build_field_name(options: config.ProcessingOptions, field_name: str):
"""
Builds the appropriate field name depending on whether
the user wishes to overwrite the current field or not
Parameters
----------
options : nowcastlib.pipeline.structs.config.ProcessingOptions
field_name : str
the name of the current field we are acting on
Returns
-------
str
the resulting string
"""
if options.overwrite:
computed_field_name = field_name
else:
computed_field_name = "processed_{}".format(field_name)
return computed_field_name
def rename_protected_field(field: config.RawField) -> config.RawField:
"""
Renames overwrite-protected fields so to obtain a list of fields that
are overwrite-able
"""
if field.preprocessing_options is not None:
if field.preprocessing_options.overwrite is False:
correct_name = build_field_name(
field.preprocessing_options, field.field_name
)
return cattr.structure(
{
"field_name": correct_name,
**attr.asdict(
field,
filter=lambda attrib, _: attrib.name != "field_name",
),
},
config.RawField,
)
else:
return field
else:
return field
def disambiguate_intfloatstr(train_split: Any, _klass: Type) -> Union[int, float, str]:
"""Disambiguates Union of int, float and str for cattrs"""
if isinstance(train_split, int):
return int(train_split)
elif isinstance(train_split, float):
return float(train_split)
elif isinstance(train_split, str):
return str(train_split)
else:
raise ValueError("Cannot disambiguate Union[int, flaot, str]")
def handle_serialization(
data: Union[pd.core.frame.DataFrame, np.ndarray],
options: config.SerializationOptions,
):
"""
Serializes a given dataframe or numpy array
to disk in the appropriate format
"""
if isinstance(data, pd.core.frame.DataFrame):
if options.output_format == "csv":
data.to_csv(options.output_path, float_format="%g")
elif options.output_format == "pickle":
data.to_pickle(options.output_path)
else:
if options.output_format == "npy":
np.save(options.output_path, data)
def yes_or_no(question):
"""
Asks the user a yes or no question, parsing the answer
accordingly.
"""
while "the answer is invalid":
reply = str(input(question + " (y/n): ")).lower().strip()
if reply[:1] == "y":
return True
if reply[:1] == "n":
return False
else:
print("Please enter either 'y'/'Y' or 'n'/'N'")
Functions
def build_field_name(options: ProcessingOptions, field_name: str)
-
Builds the appropriate field name depending on whether the user wishes to overwrite the current field or not
Parameters
options
:ProcessingOptions
field_name
:str
- the name of the current field we are acting on
Returns
str
- the resulting string
Expand source code
def build_field_name(options: config.ProcessingOptions, field_name: str): """ Builds the appropriate field name depending on whether the user wishes to overwrite the current field or not Parameters ---------- options : nowcastlib.pipeline.structs.config.ProcessingOptions field_name : str the name of the current field we are acting on Returns ------- str the resulting string """ if options.overwrite: computed_field_name = field_name else: computed_field_name = "processed_{}".format(field_name) return computed_field_name
def rename_protected_field(field: RawField) ‑> RawField
-
Renames overwrite-protected fields so to obtain a list of fields that are overwrite-able
Expand source code
def rename_protected_field(field: config.RawField) -> config.RawField: """ Renames overwrite-protected fields so to obtain a list of fields that are overwrite-able """ if field.preprocessing_options is not None: if field.preprocessing_options.overwrite is False: correct_name = build_field_name( field.preprocessing_options, field.field_name ) return cattr.structure( { "field_name": correct_name, **attr.asdict( field, filter=lambda attrib, _: attrib.name != "field_name", ), }, config.RawField, ) else: return field else: return field
def disambiguate_intfloatstr(train_split: Any, _klass: Type) ‑> Union[int, float, str]
-
Disambiguates Union of int, float and str for cattrs
Expand source code
def disambiguate_intfloatstr(train_split: Any, _klass: Type) -> Union[int, float, str]: """Disambiguates Union of int, float and str for cattrs""" if isinstance(train_split, int): return int(train_split) elif isinstance(train_split, float): return float(train_split) elif isinstance(train_split, str): return str(train_split) else: raise ValueError("Cannot disambiguate Union[int, flaot, str]")
def handle_serialization(data: Union[pandas.core.frame.DataFrame, numpy.ndarray], options: SerializationOptions)
-
Serializes a given dataframe or numpy array to disk in the appropriate format
Expand source code
def handle_serialization( data: Union[pd.core.frame.DataFrame, np.ndarray], options: config.SerializationOptions, ): """ Serializes a given dataframe or numpy array to disk in the appropriate format """ if isinstance(data, pd.core.frame.DataFrame): if options.output_format == "csv": data.to_csv(options.output_path, float_format="%g") elif options.output_format == "pickle": data.to_pickle(options.output_path) else: if options.output_format == "npy": np.save(options.output_path, data)
def yes_or_no(question)
-
Asks the user a yes or no question, parsing the answer accordingly.
Expand source code
def yes_or_no(question): """ Asks the user a yes or no question, parsing the answer accordingly. """ while "the answer is invalid": reply = str(input(question + " (y/n): ")).lower().strip() if reply[:1] == "y": return True if reply[:1] == "n": return False else: print("Please enter either 'y'/'Y' or 'n'/'N'")