Source code for snakemake.common

# Module metadata: author, copyright notice, contact address and license.
__author__ = "Johannes Köster"
__copyright__ = "Copyright 2023, Johannes Köster"
__email__ = "johannes.koester@protonmail.com"
__license__ = "MIT"

import contextlib
import itertools
import math
import operator
import platform
import hashlib
import inspect
import sys
import uuid
import os
import asyncio
import collections
from pathlib import Path

from snakemake._version import get_versions

from snakemake_interface_common.exceptions import WorkflowError

# Version string as reported by snakemake._version.get_versions().
__version__ = get_versions()["version"]
# Remove the helper again so the name is not left in this module's namespace.
del get_versions


# Minimum Python version as a (major, minor) tuple.
MIN_PY_VERSION = (3, 7)

# Namespace UUID from which name-based UUIDs in this module are derived.
UUID_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_URL, "https://snakemake.readthedocs.io")

NOTHING_TO_BE_DONE_MSG = (
    "Nothing to be done (all requested files are present and up to date)."
)

# True when the current platform reports itself as Windows.
ON_WINDOWS = platform.system() == "Windows"

# limit the number of input/output files list in job properties
# see https://github.com/snakemake/snakemake/issues/2097
IO_PROP_LIMIT = 100

# Candidate Snakefile locations as Path objects, in the order listed here.
SNAKEFILE_CHOICES = [
    Path(candidate)
    for candidate in (
        "Snakefile",
        "snakefile",
        "workflow/Snakefile",
        "workflow/snakefile",
    )
]

PIP_DEPLOYMENTS_PATH = ".snakemake/pip-deployments"


def get_snakemake_searchpaths():
    """Return the list of directories used as Snakemake search paths.

    The list starts with the directory three levels above this file,
    followed by every entry of ``sys.path`` that is an existing directory.
    Runs of consecutive identical entries are collapsed to a single entry.
    """
    package_root = str(Path(__file__).parent.parent.parent)
    candidates = [package_root]
    candidates.extend(filter(os.path.isdir, sys.path))
    return list(unique_justseen(candidates))
def mb_to_mib(mb):
    """Convert megabytes to mebibytes, rounded up to the next integer."""
    # 10**6 / 2**20: the number of mebibytes in one megabyte.
    mib_per_mb = 0.95367431640625
    mebibytes = math.ceil(mb * mib_per_mb)
    return int(mebibytes)
def parse_key_value_arg(arg, errmsg, strip_quotes=True):
    """Split a ``key=value`` argument into a ``(key, value)`` tuple.

    Only the first ``=`` separates key and value, so the value may itself
    contain ``=``.

    Args:
        arg: The string to split.
        errmsg: Message prefix used when ``arg`` contains no ``=``.
        strip_quotes: If true, strip leading and trailing single and double
            quote characters from the value.

    Raises:
        ValueError: If ``arg`` cannot be split into key and value.
    """
    try:
        key, val = arg.split("=", 1)
    except ValueError:
        detail = f" (Unparsable value: {arg!r})"
        raise ValueError(errmsg + detail)
    return key, (val.strip("'\"") if strip_quotes else val)
def dict_to_key_value_args(
    some_dict: dict, quote_str: bool = True, repr_obj: bool = False
):
    """Render a dict as a list of ``key=value`` strings.

    Args:
        some_dict: Mapping whose items are rendered in iteration order.
        quote_str: If true, wrap ``str`` values in single quotes.
        repr_obj: If true, render non-``str`` values with ``repr()``;
            otherwise they are formatted as-is by the f-string.

    Returns:
        A list with one ``key=value`` string per item.
    """

    def encode(value):
        # Strings are only affected by quote_str, everything else only by
        # repr_obj.
        if isinstance(value, str):
            return f"'{value}'" if quote_str else value
        return repr(value) if repr_obj else value

    return [f"{key}={encode(value)}" for key, value in some_dict.items()]
def async_run(coroutine):
    """Run a coroutine to completion via ``asyncio.run`` and return its result.

    If ``asyncio.run`` raises a ``RuntimeError``, the coroutine is closed and
    a ``WorkflowError`` is raised with a hint about already running event
    loops; the original error is passed along as an additional argument.

    .. seealso::

         https://github.com/snakemake/snakemake/issues/1105
         https://stackoverflow.com/a/65696398
    """
    try:
        result = asyncio.run(coroutine)
    except RuntimeError as e:
        # Close the coroutine explicitly before raising.
        coroutine.close()
        message = (
            "Error running coroutine in event loop. Snakemake currently does not "
            "support being executed from an already running event loop. "
            "If you run Snakemake e.g. from a Jupyter notebook, make sure to spawn a "
            "separate process for Snakemake."
        )
        raise WorkflowError(message, e)
    return result
# Cache for the AppDirs instance; filled on the first get_appdirs() call.
APPDIRS = None
RULEFUNC_CONTEXT_MARKER = "__is_snakemake_rule_func"


def get_appdirs():
    """Return the module-wide ``AppDirs`` instance, creating it on first use.

    ``appdirs`` is imported only when the instance actually has to be
    created.
    """
    global APPDIRS
    if APPDIRS is not None:
        return APPDIRS

    from appdirs import AppDirs

    APPDIRS = AppDirs("snakemake", "snakemake")
    return APPDIRS
def is_local_file(path_or_uri):
    """Return True if the parsed scheme of ``path_or_uri`` is ``"file"``."""
    scheme = parse_uri(path_or_uri).scheme
    return scheme == "file"
def parse_uri(path_or_uri):
    """Parse a path or URI with ``smart_open``, with a simple fallback.

    If ``smart_open.parse_uri`` raises ``NotImplementedError``, the input is
    split at the first ``"://"``. When both parts are non-empty, a namedtuple
    with the fields ``scheme`` and ``uri_path`` is returned; otherwise the
    original ``NotImplementedError`` is raised again.
    """
    from smart_open import parse_uri as smart_open_parse_uri

    try:
        parsed = smart_open_parse_uri(path_or_uri)
    except NotImplementedError as e:
        # Snakemake sees a lot of URIs which are not supported by smart_open yet
        # "docker", "git+file", "shub", "ncbi","root","roots","rootk", "gsiftp",
        # "srm","ega","ab","dropbox"
        # Fall back to a simple split if we encounter something which isn't supported.
        scheme, _, uri_path = path_or_uri.partition("://")
        if not (scheme and uri_path):
            raise e
        uri = collections.namedtuple("Uri", ["scheme", "uri_path"])
        return uri(scheme, uri_path)
    return parsed
def smart_join(base, path, abspath=False):
    """Join ``base`` and ``path``, where ``base`` is a local path or a URI.

    Args:
        base: Local directory or URI prefix.
        path: Path to append to ``base``.
        abspath: If true and ``base`` is local, return an absolute path.
            Not used for non-local bases.

    Returns:
        The joined local path, or the joined URI as ``<scheme>:/<uri_path>``.
    """
    if not is_local_file(base):
        from smart_open import parse_uri

        uri = parse_uri(f"{base}/{path}")
        uri_path = uri.uri_path
        if not ON_WINDOWS:
            # Norm the path such that it does not contain any ../,
            # which is invalid in an URL.
            assert uri_path[0] == "/"
            uri_path = os.path.normpath(uri_path)
        return f"{uri.scheme}:/{uri_path}"

    full = os.path.join(base, path)
    return os.path.abspath(full) if abspath else full
def num_if_possible(s):
    """Convert string to number if possible, otherwise return string.

    ``int`` is tried first, then ``float``; if both raise ``ValueError`` the
    input is returned unchanged.
    """
    for convert in (int, float):
        try:
            return convert(s)
        except ValueError:
            continue
    return s
def get_last_stable_version():
    """Return the part of ``__version__`` that precedes the first ``+``."""
    stable, _, _ = __version__.partition("+")
    return stable
def get_container_image():
    """Return the container image name tagged with the last stable version."""
    tag = get_last_stable_version()
    return f"snakemake/snakemake:v{tag}"
def get_uuid(name):
    """Return the name-based (version 5) UUID of ``name`` in ``UUID_NAMESPACE``."""
    return uuid.uuid5(namespace=UUID_NAMESPACE, name=name)
def get_file_hash(filename, algorithm="sha256"):
    """Return the hex digest of a file's content.

    We use this so that the user can choose to cache working directories
    in storage.

    Args:
        filename: Path of the file to hash; it is read in binary mode.
        algorithm: Name passed to ``hashlib.new`` (default ``"sha256"``).

    Raises:
        ValueError: If ``hashlib.new`` rejects ``algorithm``; the problem is
            logged before the exception is raised again.
    """
    from snakemake.logging import logger

    # The algorithm must be available
    try:
        digest = hashlib.new(algorithm)
    except ValueError as ex:
        logger.error("%s is not an available algorithm." % algorithm)
        raise ex

    with open(filename, "rb") as stream:
        # Read in 4096-byte pieces until read() returns an empty bytes object.
        while True:
            block = stream.read(4096)
            if block == b"":
                break
            digest.update(block)
    return digest.hexdigest()
def bytesto(bytes, to, bsize=1024):
    """Convert a number of bytes to a larger unit.

    bytes to mb: bytesto(bytes, 'm')
    bytes to gb: bytesto(bytes, 'g')
    etc.

    ``to`` is one of ``k``, ``m``, ``g``, ``t``, ``p``, ``e``; ``bsize`` is the
    factor between two neighbouring units. The result is a float.

    From https://gist.github.com/shawnbutts/3906915
    """
    divisions = {"k": 1, "m": 2, "g": 3, "t": 4, "p": 5, "e": 6}
    value = float(bytes)
    remaining = divisions[to]
    # Divide once per unit step between bytes and the target unit.
    while remaining > 0:
        value = value / bsize
        remaining -= 1
    return value
def strip_prefix(text, prefix):
    """Return ``text`` without a leading ``prefix``; unchanged if it is absent."""
    return text[len(prefix) :] if text.startswith(prefix) else text
def log_location(msg):
    """Log ``msg`` at debug level with the caller's file, function and line."""
    from snakemake.logging import logger

    # Stack entry 0 is this function; entry 1 is whoever called it.
    caller = inspect.stack()[1]
    info = inspect.getframeinfo(caller[0])
    template = "{}: {info.filename}, {info.function}, {info.lineno}"
    logger.debug(template.format(msg, info=info))
def group_into_chunks(n, iterable):
    """Group iterable into chunks of size at most n.

    Yields tuples; the last one may hold fewer than ``n`` items.

    See https://stackoverflow.com/a/8998040.
    """
    source = iter(iterable)

    def next_chunk():
        return tuple(itertools.islice(source, n))

    # The two-argument form of iter() stops as soon as an empty tuple comes back.
    yield from iter(next_chunk, ())
class Rules:
    """A namespace for rules so that they can be accessed via dot notation.

    Rules are stored in the ``_rules`` dict by name; attribute access with a
    registered name returns the corresponding rule.
    """

    def __init__(self):
        self._rules = dict()

    def _register_rule(self, name, rule):
        """Store ``rule`` under ``name``, replacing any earlier entry."""
        self._rules[name] = rule

    def __getattr__(self, name):
        """Return the rule registered under ``name``.

        Raises:
            WorkflowError: If no rule with that name is registered.
            AttributeError: If the instance has no ``_rules`` dict yet.
        """
        # __getattr__ is only called when normal lookup fails. If "_rules"
        # itself is missing (e.g. on an instance created without __init__,
        # as happens while copying or unpickling), looking it up below would
        # call __getattr__ again and recurse until RecursionError.
        if name == "_rules":
            raise AttributeError(name)
        try:
            return self._rules[name]
        except KeyError:
            # Imported lazily; only needed on the error path.
            from snakemake.exceptions import WorkflowError

            raise WorkflowError(
                f"Rule {name} is not defined in this workflow. "
                f"Available rules: {', '.join(self._rules)}"
            )
class Scatter:
    """A namespace for scatter to allow items to be accessed via dot notation.

    The class defines no members of its own.
    """
class Gather:
    """A namespace for gather to allow items to be accessed via dot notation.

    The class defines no members of its own.
    """
def get_function_params(func):
    """Return the parameters mapping of ``func``'s signature."""
    signature = inspect.signature(func)
    return signature.parameters
def get_input_function_aux_params(func, candidate_params):
    """Select the candidate parameters that ``func`` is able to accept.

    Args:
        func: The function whose signature is inspected.
        candidate_params: Mapping of parameter name to value.

    Returns:
        ``candidate_params`` itself if ``func`` has a ``**kwargs`` parameter,
        otherwise a new dict with only the entries named in its signature.
    """
    accepted = get_function_params(func)
    for param in accepted.values():
        if param.kind == param.VAR_KEYWORD:
            # If the function has a **kwargs parameter, we assume that it can
            # take any parameter, so we return all candidate parameters.
            return candidate_params
    return {
        name: value for name, value in candidate_params.items() if name in accepted
    }
def unique_justseen(iterable, key=None):
    """
    List unique elements, preserving order. Remember only the element just seen.

    Only consecutive duplicates are dropped; the first element of each run is
    kept. ``key`` is passed on to ``itertools.groupby``.

    From https://docs.python.org/3/library/itertools.html#itertools-recipes
    """
    # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
    # unique_justseen('ABBcCAD', str.lower) --> A B c A D
    runs = itertools.groupby(iterable, key)
    run_members = map(operator.itemgetter(1), runs)
    return map(next, run_members)
# Taken from https://stackoverflow.com/a/34333710/7070491.
# Thanks to Laurent Laporte.
@contextlib.contextmanager
def set_env(**environ):
    """
    Temporarily set the process environment variables.

    On exit, ``os.environ`` is cleared and refilled from a snapshot taken on
    entry, so changes made inside the ``with`` block are reverted as well.

    >>> with set_env(PLUGINS_DIR='test/plugins'):
    ...   "PLUGINS_DIR" in os.environ
    True

    >>> "PLUGINS_DIR" in os.environ
    False

    :type environ: dict[str, str]
    :param environ: Environment variables to set
    """
    old_environ = dict(os.environ)
    try:
        # The update sits inside the try block: if it fails partway through
        # (e.g. because of a non-string value), the variables already set are
        # rolled back instead of being left in the environment.
        os.environ.update(environ)
        yield
    finally:
        os.environ.clear()
        os.environ.update(old_environ)
def expand_vars_and_user(value):
    """Expand environment variables and a leading ``~`` in ``value``.

    Environment variables are expanded first, then the user home directory.
    Returns ``None`` if ``value`` is ``None``.
    """
    if value is None:
        return None
    with_vars = os.path.expandvars(value)
    return os.path.expanduser(with_vars)
# Taken from https://stackoverflow.com/a/2166841/7070491
# Thanks to Alex Martelli.
def is_namedtuple_instance(x):
    """Return True if ``x`` looks like an instance of a namedtuple class.

    The check requires that the type of ``x`` has ``tuple`` as its only direct
    base class and a ``_fields`` attribute that is a tuple of ``str``.
    """
    cls = type(x)
    if cls.__bases__ != (tuple,):
        return False
    fields = getattr(cls, "_fields", None)
    if not isinstance(fields, tuple):
        return False
    return all(type(name) is str for name in fields)