__author__ = "Johannes Köster"
__copyright__ = "Copyright 2023, Johannes Köster"
__email__ = "johannes.koester@protonmail.com"
__license__ = "MIT"
import contextlib
import itertools
import math
import operator
import platform
import hashlib
import inspect
import shutil
import sys
from typing import Callable, List
import uuid
import os
import asyncio
import collections
from pathlib import Path
from typing import Union
from snakemake import __version__
from snakemake_interface_common.exceptions import WorkflowError
MIN_PY_VERSION = (3, 7)
UUID_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_URL, "https://snakemake.readthedocs.io")
NOTHING_TO_BE_DONE_MSG = (
"Nothing to be done (all requested files are present and up to date)."
)
ON_WINDOWS = platform.system() == "Windows"
# limit the number of input/output files list in job properties
# see https://github.com/snakemake/snakemake/issues/2097
IO_PROP_LIMIT = 100
SNAKEFILE_CHOICES = list(
map(
Path,
(
"Snakefile",
"snakefile",
"workflow/Snakefile",
"workflow/snakefile",
),
)
)
[docs]
def get_snakemake_searchpaths():
paths = [str(Path(__file__).parent.parent.parent)] + [
path for path in sys.path if os.path.isdir(path)
]
return list(unique_justseen(paths))
[docs]
def get_report_id(path: Union[str, Path]) -> str:
h = hashlib.sha256()
h.update(str(path).encode())
return h.hexdigest()
[docs]
def mb_to_mib(mb):
return int(math.ceil(mb * 0.95367431640625))
[docs]
def parse_key_value_arg(arg, errmsg, strip_quotes=True):
try:
key, val = arg.split("=", 1)
except ValueError:
raise ValueError(errmsg + f" (Unparsable value: {repr(arg)})")
if strip_quotes:
val = val.strip("'\"")
return key, val
[docs]
def dict_to_key_value_args(
some_dict: dict, quote_str: bool = True, repr_obj: bool = False
):
items = []
for key, value in some_dict.items():
if repr_obj and not isinstance(value, str):
encoded = repr(value)
else:
encoded = f"'{value}'" if quote_str and isinstance(value, str) else value
items.append(f"{key}={encoded}")
return items
[docs]
def async_run(coroutine):
"""Attaches to running event loop or creates a new one to execute a
coroutine.
.. seealso::
https://github.com/snakemake/snakemake/issues/1105
https://stackoverflow.com/a/65696398
"""
try:
return asyncio.run(coroutine)
except RuntimeError as e:
coroutine.close()
raise WorkflowError(
"Error running coroutine in event loop. Snakemake currently does not "
"support being executed from an already running event loop. "
"If you run Snakemake e.g. from a Jupyter notebook, make sure to spawn a "
"separate process for Snakemake.",
e,
)
APPDIRS = None
RULEFUNC_CONTEXT_MARKER = "__is_snakemake_rule_func"
[docs]
def get_appdirs():
global APPDIRS
if APPDIRS is None:
from appdirs import AppDirs
APPDIRS = AppDirs("snakemake", "snakemake")
return APPDIRS
[docs]
def is_local_file(path_or_uri):
return parse_uri(path_or_uri).scheme == "file"
[docs]
def parse_uri(path_or_uri):
from smart_open import parse_uri
try:
return parse_uri(path_or_uri)
except NotImplementedError as e:
# Snakemake sees a lot of URIs which are not supported by smart_open yet
# "docker", "git+file", "shub", "ncbi","root","roots","rootk", "gsiftp",
# "srm","ega","ab","dropbox"
# Fall back to a simple split if we encounter something which isn't supported.
scheme, _, uri_path = path_or_uri.partition("://")
if scheme and uri_path:
uri = collections.namedtuple("Uri", ["scheme", "uri_path"])
return uri(scheme, uri_path)
else:
raise e
[docs]
def smart_join(base, path, abspath=False):
if is_local_file(base):
full = os.path.join(base, path)
if abspath:
return os.path.abspath(full)
return full
else:
from smart_open import parse_uri
uri = parse_uri(f"{base}/{path}")
if not ON_WINDOWS:
# Norm the path such that it does not contain any ../,
# which is invalid in an URL.
assert uri.uri_path[0] == "/"
uri_path = os.path.normpath(uri.uri_path)
else:
uri_path = uri.uri_path
return f"{uri.scheme}:/{uri_path}"
[docs]
def num_if_possible(s):
"""Convert string to number if possible, otherwise return string."""
try:
return int(s)
except ValueError:
try:
return float(s)
except ValueError:
return s
[docs]
def get_last_stable_version():
return __version__.split("+")[0]
[docs]
def get_container_image():
return f"snakemake/snakemake:v{get_last_stable_version()}"
[docs]
def get_uuid(name):
return uuid.uuid5(UUID_NAMESPACE, name)
[docs]
def get_file_hash(filename, algorithm="sha256"):
"""find the SHA256 hash string of a file. We use this so that the
user can choose to cache working directories in storage.
"""
from snakemake.logging import logger
# The algorithm must be available
try:
hasher = hashlib.new(algorithm)
except ValueError as ex:
logger.error("%s is not an available algorithm." % algorithm)
raise ex
with open(filename, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hasher.update(chunk)
return hasher.hexdigest()
[docs]
def bytesto(bytes, to, bsize=1024):
"""convert bytes to megabytes.
bytes to mb: bytesto(bytes, 'm')
bytes to gb: bytesto(bytes, 'g' etc.
From https://gist.github.com/shawnbutts/3906915
"""
levels = {"k": 1, "m": 2, "g": 3, "t": 4, "p": 5, "e": 6}
answer = float(bytes)
for _ in range(levels[to]):
answer = answer / bsize
return answer
[docs]
def strip_prefix(text, prefix):
if text.startswith(prefix):
return text[len(prefix) :]
return text
[docs]
def log_location(msg):
from snakemake.logging import logger
callerframerecord = inspect.stack()[1]
frame = callerframerecord[0]
info = inspect.getframeinfo(frame)
logger.debug(
"{}: {info.filename}, {info.function}, {info.lineno}".format(msg, info=info)
)
[docs]
def group_into_chunks(n, iterable):
"""Group iterable into chunks of size at most n.
See https://stackoverflow.com/a/8998040.
"""
it = iter(iterable)
while True:
chunk = tuple(itertools.islice(it, n))
if not chunk:
return
yield chunk
[docs]
class Rules:
"""A namespace for rules so that they can be accessed via dot notation."""
def __init__(self):
self._rules = dict()
def _register_rule(self, name, rule):
self._rules[name] = rule
def __getattr__(self, name):
from snakemake.exceptions import WorkflowError
try:
return self._rules[name]
except KeyError:
raise WorkflowError(
f"Rule {name} is not defined in this workflow. "
f"Available rules: {', '.join(self._rules)}"
)
[docs]
class Scatter:
"""A namespace for scatter to allow items to be accessed via dot notation."""
pass
[docs]
class Gather:
"""A namespace for gather to allow items to be accessed via dot notation."""
pass
FUNC_OVERWRITE_PARAMS_ATTR = "_overwrite_params"
[docs]
def get_function_params(func: Callable):
if hasattr(func, FUNC_OVERWRITE_PARAMS_ATTR):
return getattr(func, FUNC_OVERWRITE_PARAMS_ATTR)
else:
return inspect.signature(func).parameters
[docs]
def overwrite_function_params(func: Callable, params: List[str]):
"""Force function params to be the given list. Useful for functions that
use *args to get all parameters in dynamically created cases like in
snakemake.ioutils.subpath.subpath.
"""
setattr(func, FUNC_OVERWRITE_PARAMS_ATTR, params)
[docs]
def unique_justseen(iterable, key=None):
"""
List unique elements, preserving order. Remember only the element just seen.
From https://docs.python.org/3/library/itertools.html#itertools-recipes
"""
# unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
# unique_justseen('ABBcCAD', str.lower) --> A B c A D
return map(next, map(operator.itemgetter(1), itertools.groupby(iterable, key)))
# Taken from https://stackoverflow.com/a/34333710/7070491.
# Thanks to Laurent Laporte.
[docs]
@contextlib.contextmanager
def set_env(**environ):
"""
Temporarily set the process environment variables.
>>> with set_env(PLUGINS_DIR='test/plugins'):
... "PLUGINS_DIR" in os.environ
True
>>> "PLUGINS_DIR" in os.environ
False
:type environ: Dict[str, unicode]
:param environ: Environment variables to set
"""
old_environ = dict(os.environ)
os.environ.update(environ)
try:
yield
finally:
os.environ.clear()
os.environ.update(old_environ)
[docs]
def expand_vars_and_user(value):
if value is not None:
return os.path.expanduser(os.path.expandvars(value))
# Taken from https://stackoverflow.com/a/2166841/7070491
# Thanks to Alex Martelli.
[docs]
def is_namedtuple_instance(x):
t = type(x)
b = t.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(t, "_fields", None)
if not isinstance(f, tuple):
return False
return all(type(n) is str for n in f)
[docs]
def copy_permission_safe(src: str, dst: str):
"""Copy a file to a given destination.
If destination exists, it is removed first in order to avoid permission issues when
the destination permissions are tried to be applied to an already existing
destination.
"""
if os.path.exists(dst):
os.unlink(dst)
shutil.copy(src, dst)