Source code for badcrossbar.utils

import logging
import os
import pickle

import numpy as np
import numpy.typing as npt
from pathvalidate import sanitize_filepath

# Module-level logger named after this module; save_pickle reports the path it
# wrote through it.
logger = logging.getLogger(__name__)


def unique_path(path: str, extension: str = "pdf", sanitize: bool = True) -> str:
    """Return a path that does not clash with an existing file.

    The candidate ``{path}.{extension}`` is used when nothing exists at that
    location. Otherwise a numeric suffix is appended, starting at 2
    (``{path}-2.{extension}``, ``{path}-3.{extension}``, ...), until an
    unused name is found.

    Args:
        path: Path of the filename without the extension.
        extension: File extension.
        sanitize: If True, sanitizes the filename by removing illegal
            characters and making the path compatible with the operating
            system.

    Returns:
        The first candidate path for which ``os.path.exists`` is False.
    """
    if sanitize:
        path = sanitize_filepath(path, platform="auto")

    candidate = f"{path}.{extension}"
    # Numbering deliberately starts at 2: the unnumbered file counts as "1".
    suffix = 2
    while os.path.exists(candidate):
        candidate = f"{path}-{suffix}.{extension}"
        suffix += 1
    return candidate
def squeeze_third_axis(array: npt.NDArray) -> npt.NDArray:
    """Drop the third axis of an array when that axis has length 1.

    Args:
        array: Array to inspect. Only arrays with exactly three dimensions
            whose third axis has length 1 are modified.

    Returns:
        The array with the third axis removed if the condition above holds;
        otherwise the input array itself, unchanged.
    """
    has_singleton_third_axis = array.ndim == 3 and array.shape[2] == 1
    if not has_singleton_third_axis:
        return array
    return np.squeeze(array, axis=2)
def average_if_3D(array: npt.NDArray) -> npt.NDArray:
    """Collapse a 3D array to 2D by taking the mean over its third axis.

    Args:
        array: 2D or 3D array.

    Returns:
        The mean along ``axis=2`` when the array has exactly three
        dimensions; otherwise the input array itself, unchanged.
    """
    if array.ndim != 3:
        return array
    return np.mean(array, axis=2)
def arrays_shape(*arrays: npt.NDArray):
    """Return the shape of the first array that is not None.

    Args:
        arrays: Arrays (or None placeholders), checked in the order given.

    Returns:
        Shape of the first argument that is not None, or None if no
        arguments were given or all of them are None.
    """
    for array in arrays:
        if array is not None:
            # Stop at the first usable array; later arguments are ignored.
            return array.shape
    # Made explicit so that "nothing to inspect" is a documented result
    # rather than an implicit fall-through.
    return None
def save_pickle(variable, path: str, allow_overwrite: bool = False, sanitize: bool = True):
    """Saves variable to a pickle file.

    Args:
        variable: Variable to be saved.
        path: Path to the pickle file, excluding extension.
        allow_overwrite: If True, will not check for existing files with the
            same name and will overwrite if such files exist. If False, a
            unique path is chosen with `unique_path` so that existing files
            are left untouched.
        sanitize: If True, sanitizes the filename by removing illegal
            characters and making the path compatible with the operating
            system.
    """
    if sanitize:
        path = sanitize_filepath(path, platform="auto")

    if allow_overwrite:
        path = f"{path}.pickle"
    else:
        # The path has already been sanitized above when requested; passing
        # sanitize=False keeps unique_path from sanitizing it regardless of
        # the caller's choice.
        path = unique_path(path, "pickle", sanitize=False)

    with open(path, "wb") as handle:
        pickle.dump(variable, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Lazy %-formatting so the actual path is logged (the previous message
    # was a plain string and logged the literal text "{path}").
    logger.info("Saved %s.", path)
def load_pickle(path: str, sanitize: bool = True):
    """Read a pickle file and return the object stored in it.

    Note:
        Unpickling can execute arbitrary code, so only load files that come
        from a trusted source.

    Args:
        path: Path to the pickle file, including extension.
        sanitize: If True, sanitizes the filename by removing illegal
            characters and making the path compatible with the operating
            system.

    Returns:
        Extracted contents.
    """
    source = sanitize_filepath(path, platform="auto") if sanitize else path
    with open(source, "rb") as handle:
        return pickle.load(handle)
def distributed_array(flattened_array: npt.NDArray, model_array: npt.NDArray) -> npt.NDArray:
    """Reshape a flattened array to the first two dimensions of a model array.

    Args:
        flattened_array: An array whose each column contains a flattened
            array.
        model_array: An array whose shape is used for reshaping; only its
            first two dimensions are read.

    Returns:
        Array of shape ``(model_array.shape[0], model_array.shape[1], k)``,
        where ``k`` is the number of columns of `flattened_array`, passed
        through `squeeze_third_axis` before being returned.
    """
    num_rows = model_array.shape[0]
    num_cols = model_array.shape[1]
    num_layers = flattened_array.shape[1]
    stacked = flattened_array.reshape((num_rows, num_cols, num_layers))
    return squeeze_third_axis(stacked)