Source code for pycmor.core.utils

"""
Various utility functions needed around the package
"""

import importlib
import inspect
import os
import tempfile
import time
from functools import partial

import pkg_resources
import requests

from .logging import logger


[docs] def get_callable(name): """Get a callable from a string First, tries standard import, then tries entry points, then from script """ try: return get_callable_by_name(name) except (ImportError, AttributeError): pass try: return get_entrypoint_by_name(name) except ValueError: pass try: return get_callable_by_script(name) except ValueError: pass raise ValueError(f"Callable '{name}' not found")
[docs] def get_callable_by_name(name): """ Get a callable by its name. This function takes a string that represents the fully qualified name of a callable object (i.e., a function or a method), and returns the actual callable object. The name should be in the format 'module.submodule.callable'. If the callable does not exist, this function will raise an AttributeError. Parameters ---------- name : str The fully qualified name of the callable to be retrieved. It should be in the format 'module.submodule.callable'. Returns ------- callable The callable object that corresponds to the given name. Raises ------ ImportError If the module or submodule specified in the name does not exist. AttributeError If the callable specified in the name does not exist in the given module or submodule. """ if "." not in name: raise ValueError( f"Name '{name}' is not a fully qualified name. It should be in the format 'module.submodule.callable'." ) module_name, callable_name = name.rsplit(".", 1) logger.debug(f"Importing module '{module_name}' to get callable '{callable_name}'") module = __import__(module_name, fromlist=[callable_name]) return getattr(module, callable_name)
[docs] def get_entrypoint_by_name(name, group="pycmor.steps"): """ Get an entry point by its name. This function takes a string that represents the name of an entry point in a given group, and returns the actual entry point object. If the entry point does not exist, this function will raise a ValueError. Parameters ---------- name : str The name of the entry point to be retrieved. group : str The group that the entry point belongs to. Returns ------- EntryPoint The entry point object that corresponds to the given name. Raises ------ ValueError If the entry point specified by the name does not exist in the given group. """ logger.debug(f"Getting entry point '{name}' from group '{group}'") groups_to_try = [group] if group == "pycmor.steps": groups_to_try.append("pymor.steps") # legacy fallback for grp in groups_to_try: for entry_point in pkg_resources.iter_entry_points(group=grp): if entry_point.name == name: return entry_point.load() raise ValueError(f"Entry point '{name}' not found in groups {groups_to_try}")
[docs] def generate_partial_function(func: callable, open_arg: str, *args, **kwargs): """ Reduces func to a partial function by fixing all but the argument named by open_arg. Parameters ---------- func : callable The function to be partially applied. open_arg : str The name of the argument that should remain open in the partial function. *args Positional arguments to be passed to the partial function. **kwargs Keyword arguments to be passed to the partial function. Returns ------- callable The partial function with the specified arguments fixed. """ if not can_be_partialized(func, open_arg, args, kwargs): raise ValueError( f"Function '{func.__name__}' cannot be partially applied with open " f"argument '{open_arg}' by using the provided arguments {args=} and " f"keyword arguments {kwargs=}." ) logger.debug( f"Generating partial function for '{func.__name__}' with open argument '{open_arg}'" ) # Get the signature of the function signature = inspect.signature(func) # Get the parameter names param_names = list(signature.parameters.keys()) # Get the index of the open argument open_arg_index = param_names.index(open_arg) # Get the names of the arguments to be fixed fixed_args = ( param_names[:open_arg_index] + param_names[open_arg_index + 1 :] # noqa: E203 ) # Get the values of the arguments to be fixed fixed_values = [kwargs[arg] for arg in fixed_args if arg in kwargs] # Remove the fixed arguments from the keyword arguments for arg in fixed_args: kwargs.pop(arg, None) # Create the partial function return partial(func, *fixed_values, *args, **kwargs)
[docs] def can_be_partialized( func: callable, open_arg: str, arg_list: list, kwargs_dict: dict ) -> bool: """ Checks if a function can be reasonably partialized with a single argument open. Parameters ---------- func : callable The function to be partially applied. open_arg : str The name of the argument that should remain open in the partial function. arg_list : list The list of arguments that will be passed to the partial function. kwargs_dict : dict The dictionary of keyword arguments that will be passed to the partial function. Returns ------- bool True if the function can be partially applied with a single argument open, False otherwise. """ signature = inspect.signature(func) param_names = list(signature.parameters.keys()) # Check that all arguments in arg_list are in the function signature for arg in arg_list: if arg in param_names: param_names.remove(arg) for kwarg in kwargs_dict: if kwarg in param_names: param_names.remove(kwarg) # Check that there is only one argument left and that it is open_arg return len(param_names) == 1 and param_names[0] == open_arg
[docs] def get_function_from_script(script_path: str, function_name: str): """ Get a function from a Python script. This function takes the path to a Python script and the name of a function defined in that script, and returns the actual function object. If the script does not exist or the function is not defined in the script, this function will raise an ImportError. Parameters ---------- script_path : str The path to the Python script where the function is defined. function_name : str The name of the function to be retrieved. Returns ------- callable The function object that corresponds to the given name in the specified script. Raises ------ ImportError If the script does not exist or the function is not defined in the script. """ logger.debug(f"Importing function '{function_name}' from script '{script_path}'") spec = importlib.util.spec_from_file_location("script", script_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return getattr(module, function_name)
[docs] def get_callable_by_script(step_signature): if not step_signature.startswith("script://"): raise ValueError(f"Step signature '{step_signature}' is not a script step") script_spec = step_signature.split("script://")[1] script_path = script_spec.split(":")[0] function_name = script_spec.split(":")[1] return get_function_from_script(script_path, function_name)
[docs] def wait_for_workers(client, n_workers, timeout=600): """ Wait for a specific number of workers to be available. Args: client (distributed.Client): The Dask client n_workers (int): The number of workers to wait for timeout (int): Maximum time to wait in seconds Returns: bool: True if the required number of workers are available, False if timeout occurred """ start_time = time.time() while len(client.scheduler_info()["workers"]) < n_workers: if time.time() - start_time > timeout: logger.critical( f"Timeout reached. Only {len(client.scheduler_info()['workers'])} workers available." ) return False time.sleep(1) # Wait for 1 second before checking again logger.info(f"{n_workers} workers are now available.") return True
[docs] def git_url_to_api_url(git_url, path="", branch="main"): """ Convert a GitHub URL to the GitHub API URL for accessing directory contents. Parameters --------- git_url : str the original GitHub repository URL. path : str the path to the directory within the repository (default: ""). branch : str the branch or commit hash to target (default: main). Returns ------- str : the API URL. """ if not git_url.startswith("https://github.com/"): raise ValueError("Invalid GitHub URL. Must start with 'https://github.com/'.") # Extract repo owner and name parts = git_url.replace("https://github.com/", "").strip("/").split("/") if len(parts) < 2: raise ValueError( "Invalid GitHub URL. Must include both owner and repository name." ) repo_owner, repo_name = parts[:2] # Build the API URL api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/contents/{path}?ref={branch}" return api_url
[docs] def list_files_in_directory(git_url, directory_path, branch="main"): """ Get a list of file names in a directory from a GitHub repository. Parameters: - git_url: str, the GitHub repository URL. - directory_path: str, the path to the directory in the repository. - branch: str, the branch or commit hash to target (default: main). Returns: - list of str, filenames in the directory. """ api_url = git_url_to_api_url(git_url, path=directory_path, branch=branch) response = requests.get(api_url) if response.status_code == 200: contents = response.json() filenames = [item["name"] for item in contents if item["type"] == "file"] return filenames else: raise ValueError( f"Failed to fetch directory contents. Status code: {response.status_code}" )
[docs] def download_json_tables_from_url(url: str, filenames: list): """ Downloads JSON tables from a raw git URL Parameters ---------- url : str The URL to download the JSON tables from. Returns ------- str : The directory where the JSON tables were downloaded. """ directory = tempfile.mkdtemp() logger.debug(f"Downloading JSON tables from '{url}' to '{directory}'") for filename in filenames: response = requests.get(f"{url}/{filename}") response.raise_for_status() with open(os.path.join(directory, filename), "w") as file: file.write(response.text) logger.debug(f"Loaded file {filename}") return directory