__authors__ = "Johannes Köster"
__copyright__ = "Copyright 2022, Johannes Köster"
__email__ = "johannes.koester@uni-due.de"
__license__ = "MIT"
from pathlib import Path
import posixpath
import re
import os
import shutil
import stat
from typing import Optional
from snakemake import utils
import tempfile
import io
from abc import ABC, abstractmethod
from urllib.parse import unquote
from snakemake_interface_executor_plugins.settings import ExecMode
from snakemake.common import (
ON_WINDOWS,
is_local_file,
get_appdirs,
parse_uri,
smart_join,
)
from snakemake.exceptions import WorkflowError, SourceFileError
from snakemake.common.git import split_git_path
from snakemake.logging import logger
def _check_git_args(tag: str = None, branch: str = None, commit: str = None):
    """Ensure that exactly one git reference is given.

    Args:
        tag: Tag name, or None.
        branch: Branch name (or generic ref), or None.
        commit: Commit identifier, or None.

    Raises:
        SourceFileError: If none or more than one of the arguments is not None.
    """
    given = [ref for ref in (tag, branch, commit) if ref is not None]
    if len(given) == 1:
        return
    raise SourceFileError("exactly one of tag, branch, or commit must be specified.")
[docs]
class SourceFile(ABC):
    """Abstract base for a source file that is identified by a path or URI.

    Equality, hashing and the string representation are all derived from the
    value returned by ``get_path_or_uri``.
    """

    @abstractmethod
    def get_path_or_uri(self):
        """Return the path or URI that identifies this source file."""
        ...

    @abstractmethod
    def is_persistently_cacheable(self):
        """Return whether the file may be kept in the persistent cache."""
        ...

    def get_cache_path(self):
        """Return the relative cache location: URI scheme joined with the unquoted URI path."""
        parsed = parse_uri(self.get_path_or_uri())
        relative = unquote(parsed.uri_path.lstrip("/"))
        return os.path.join(parsed.scheme, relative)

    def get_basedir(self):
        """Return a new instance of the same class pointing at the parent directory."""
        parent = os.path.dirname(self.get_path_or_uri())
        return type(self)(parent)

    @abstractmethod
    def get_filename(self):
        """Return the file name component of this source file."""
        ...

    def join(self, path):
        """Return a new instance of the same class with ``path`` joined onto this one.

        ``path`` may be a string or another ``SourceFile`` (its path or URI is
        used in that case).
        """
        other = path.get_path_or_uri() if isinstance(path, SourceFile) else path
        return type(self)(smart_join(self.get_path_or_uri(), other))

    def mtime(self):
        """If possible, return mtime of the file. Otherwise, return None."""
        return None

    @property
    @abstractmethod
    def is_local(self):
        """Whether the file is a local one."""
        ...

    def __hash__(self):
        return hash(self.get_path_or_uri())

    def __eq__(self, other):
        # Anything that is not a SourceFile never compares equal.
        if not isinstance(other, SourceFile):
            return False
        return self.get_path_or_uri() == other.get_path_or_uri()

    def __str__(self):
        return self.get_path_or_uri()

    def simplify_path(self):
        """Return a simplified representation; the base implementation returns ``self``."""
        return self
[docs]
class GenericSourceFile(SourceFile):
    """Source file given by an arbitrary path or URI.

    Instances report themselves as non-local and are never persistently
    cacheable.
    """

    def __init__(self, path_or_uri):
        self.path_or_uri = path_or_uri

    def get_path_or_uri(self):
        """Return the stored path or URI unchanged."""
        return self.path_or_uri

    def get_filename(self):
        """Return the last component of the stored path or URI."""
        _, filename = os.path.split(self.path_or_uri)
        return filename

    def is_persistently_cacheable(self):
        """Always False for generic source files."""
        return False

    @property
    def is_local(self):
        """Always False for generic source files."""
        return False
[docs]
class LocalSourceFile(SourceFile):
    """Source file addressed by a path on the local filesystem."""

    def __init__(self, path):
        self.path = path

    def get_path_or_uri(self):
        """Return the stored path unchanged."""
        return self.path

    def is_persistently_cacheable(self):
        """Always False for local files."""
        return False

    def get_filename(self):
        """Return the last component of the path."""
        _, filename = os.path.split(self.path)
        return filename

    def abspath(self):
        """Return a new LocalSourceFile for the absolute version of the path."""
        absolute = os.path.abspath(self.path)
        return LocalSourceFile(absolute)

    def isabs(self):
        """Return whether the stored path is absolute."""
        return os.path.isabs(self.path)

    def simplify_path(self):
        """Return the result of ``utils.simplify_path`` for the stored path.

        NOTE(review): this returns whatever the helper returns rather than a
        ``SourceFile`` - presumably a plain string; confirm against callers.
        """
        return utils.simplify_path(self.path)

    def mtime(self):
        """Return the modification time of the file as reported by the OS."""
        return os.path.getmtime(self.path)

    def __fspath__(self):
        return self.path

    @property
    def is_local(self):
        """Always True for local files."""
        return True
[docs]
class LocalGitFile(SourceFile):
    """Source file inside a local git repository at a given tag, ref or commit.

    Exactly one of ``tag``, ``ref`` and ``commit`` has to be given.
    """

    def __init__(
        self, repo_path, path: str, tag: str = None, ref: str = None, commit: str = None
    ):
        _check_git_args(tag, ref, commit)
        self.repo_path = repo_path
        self.path = path
        self.tag = tag
        self._ref = ref
        self.commit = commit

    def get_path_or_uri(self):
        """Return a ``git+file://<abs repo path>/<path>@<ref>`` URI."""
        repo_root = os.path.abspath(self.repo_path)
        return f"git+file://{repo_root}/{self.path}@{self.ref}"

    def join(self, path):
        """Return a LocalGitFile for ``path`` joined onto this file's path.

        The git reference and repository are carried over unchanged.
        """
        joined = os.path.normpath("/".join([self.path, path]))
        if ON_WINDOWS:
            # normpath introduces Windows separators; switch back to URL-style ones
            joined = joined.replace("\\", "/")
        return LocalGitFile(
            self.repo_path, joined, tag=self.tag, ref=self._ref, commit=self.commit
        )

    def get_basedir(self):
        """Return the parent directory as an instance of the same class and reference."""
        parent = os.path.dirname(self.path)
        return type(self)(
            repo_path=self.repo_path,
            path=parent,
            tag=self.tag,
            commit=self.commit,
            ref=self._ref,
        )

    def is_persistently_cacheable(self):
        """Always False for files from a local git repository."""
        return False

    def get_filename(self):
        """Return the last component of the (slash separated) path."""
        return posixpath.basename(self.path)

    @property
    def ref(self):
        """The effective git reference: tag, else commit, else the explicit ref."""
        return self.tag or self.commit or self._ref

    @property
    def is_local(self):
        """Always True, the repository lives on the local filesystem."""
        return True
[docs]
class HostingProviderFile(SourceFile):
    """Marker for denoting source files hosted in a repository of a git hosting provider.

    A file is addressed by the repository (``owner/name``), the path inside
    the repository and exactly one of ``tag``, ``branch`` or ``commit``.
    Subclasses implement ``get_path_or_uri`` and may hook into construction
    via ``__post_init__``.
    """

    # A repo specification needs at least one character on each side of a slash.
    valid_repo = re.compile("^.+/.+$")

    def __init__(
        self,
        repo: str = None,
        path: str = None,
        tag: str = None,
        branch: str = None,
        commit: str = None,
        host: str = None,
    ):
        """Validate and store the file coordinates.

        Args:
            repo: Repository given as ``owner/name``.
            path: Path of the file inside the repository; surrounding slashes
                are stripped.
            tag: Tag to read the file from.
            branch: Branch to read the file from.
            commit: Commit to read the file from.
            host: Optional host name; its interpretation is up to subclasses.

        Raises:
            SourceFileError: If ``repo`` or ``path`` is missing, an argument is
                not a string, ``repo`` is malformed, or not exactly one of
                ``tag``, ``branch`` and ``commit`` is given.
        """
        if repo is None:
            raise SourceFileError("repo must be given")
        # Check types before the regex match below: matching a non-str repo
        # would otherwise fail with a bare TypeError instead of SourceFileError.
        if not all(
            isinstance(item, str)
            for item in (repo, path, tag, branch, commit)
            if item is not None
        ):
            raise SourceFileError("arguments must be given as str.")
        if not self.__class__.valid_repo.match(repo):
            raise SourceFileError(
                f"repo {repo} is not a valid repo specification (must be given as owner/name)."
            )
        _check_git_args(tag, branch, commit)
        if path is None:
            raise SourceFileError("path must be given")

        self.repo = repo
        self.tag = tag
        self.commit = commit
        self.branch = branch
        self.path = path.strip("/")
        self.token = ""
        self.host = host

        # Via __post_init__ implementing subclasses can do additional things without
        # replicating the constructor args.
        self.__post_init__()

    def __post_init__(self):
        """Hook called at the end of ``__init__``; does nothing by default."""
        pass

    def mtime(self) -> Optional[float]:
        """Return None, i.e. no modification time is reported."""
        # Intentionally None, hence causing any caching to generate an updated mtime.
        # Switching commits/branches/refs in the same repo should cause rerun triggers
        # if those files are used as input files for jobs and have changed checksums.
        return None

    def is_persistently_cacheable(self):
        """Return True if the file is addressed by a tag or a commit, False for a branch."""
        return bool(self.tag or self.commit)

    def get_filename(self):
        """Return the last component of the path inside the repository."""
        return os.path.basename(self.path)

    @property
    def ref(self):
        """The effective git reference: tag, else commit, else branch."""
        return self.tag or self.commit or self.branch

    def get_basedir(self):
        """Return the parent directory as an instance of the same class and reference."""
        return self.__class__(
            repo=self.repo,
            path=os.path.dirname(self.path),
            tag=self.tag,
            commit=self.commit,
            branch=self.branch,
            host=self.host,
        )

    def join(self, path):
        """Return an instance of the same class for ``path`` joined onto this file's path.

        Repository, reference and host are carried over unchanged.
        """
        path = os.path.normpath(f"{self.path}/{path}")
        if ON_WINDOWS:
            # convert back to URL separators
            # (win specific separators are introduced by normpath above)
            path = path.replace("\\", "/")
        return self.__class__(
            repo=self.repo,
            path=path,
            tag=self.tag,
            commit=self.commit,
            branch=self.branch,
            host=self.host,
        )

    @property
    def is_local(self):
        """Always False, the file lives at a hosting provider."""
        return False
[docs]
class GithubFile(HostingProviderFile):
    """Hosting provider file that is read from ``raw.githubusercontent.com``."""

    def __post_init__(self):
        """Reject a custom host and read the access token from ``GITHUB_TOKEN``."""
        if self.host is not None:
            raise WorkflowError(
                "host keyword argument is not yet supported by GithubFile."
            )
        self.token = os.getenv("GITHUB_TOKEN", "")

    def get_path_or_uri(self):
        """Return the raw-content URL, including the token as credentials if one is set."""
        if self.token:
            auth = f":{self.token}@"
        else:
            auth = ""
        # TODO find out how this URL looks like with Github enterprise server and support
        # self.host being not none by removing the check in __post_init__
        return f"https://{auth}raw.githubusercontent.com/{self.repo}/{self.ref}/{self.path}"
[docs]
class GitlabFile(HostingProviderFile):
    """Hosting provider file that is read via the GitLab v4 repository files API."""

    def __post_init__(self):
        """Default the host to ``gitlab.com`` and read the token from ``GITLAB_TOKEN``."""
        self.host = "gitlab.com" if self.host is None else self.host
        self.token = os.getenv("GITLAB_TOKEN", "")

    def get_path_or_uri(self):
        """Return the raw-file API URL for this file.

        Repository and file path are percent-encoded (including slashes); a
        non-empty token is appended as ``private_token`` query parameter.
        """
        from urllib.parse import quote

        project = quote(self.repo, safe="")
        file_path = quote(self.path, safe="")
        if self.token:
            auth = f"&private_token={self.token}"
        else:
            auth = ""
        return "https://{}/api/v4/projects/{}/repository/files/{}/raw?ref={}{}".format(
            self.host, project, file_path, self.ref, auth
        )
[docs]
def infer_source_file(path_or_uri, basedir: SourceFile = None):
    """Turn a path, URI or source file marker into a ``SourceFile`` instance.

    Args:
        path_or_uri: A ``str``, a ``pathlib.Path`` or a ``SourceFile``.
        basedir: Optional base directory; local, non-absolute paths are
            resolved against it via ``basedir.join``.

    Returns:
        The given ``SourceFile`` itself (if no basedir is given or it is a
        ``HostingProviderFile``), the result of ``basedir.join``, or a
        ``LocalSourceFile``, ``LocalGitFile`` or ``GenericSourceFile``
        depending on the kind of path or URI.

    Raises:
        SourceFileError: If ``path_or_uri`` is of an unsupported type.
        WorkflowError: If a ``git+file:`` URI cannot be split into its parts.
    """
    if isinstance(path_or_uri, SourceFile):
        keep_as_is = basedir is None or isinstance(path_or_uri, HostingProviderFile)
        if keep_as_is:
            return path_or_uri
        path_or_uri = path_or_uri.get_path_or_uri()

    if isinstance(path_or_uri, Path):
        path_or_uri = str(path_or_uri)

    if not isinstance(path_or_uri, str):
        raise SourceFileError(
            "must be given as Python string or one of the predefined source file marker types (see docs)"
        )

    if is_local_file(path_or_uri):
        # either local file or relative to some remote basedir;
        # strip the longest matching file scheme prefix, if there is one
        if path_or_uri.startswith("file://"):
            path_or_uri = path_or_uri[len("file://") :]
        elif path_or_uri.startswith("file:"):
            path_or_uri = path_or_uri[len("file:") :]

        if basedir is not None and not os.path.isabs(path_or_uri):
            return basedir.join(path_or_uri)
        return LocalSourceFile(path_or_uri)

    if not path_or_uri.startswith("git+file:"):
        # something else
        return GenericSourceFile(path_or_uri)

    try:
        root_path, file_path, ref = split_git_path(path_or_uri)
    except Exception as e:
        raise WorkflowError(f"Failed to read source {path_or_uri} from git repo.", e)
    return LocalGitFile(root_path, file_path, ref=ref)
[docs]
class SourceCache:
cache_whitelist = [
r"https://raw.githubusercontent.com/snakemake/snakemake-wrappers/\d+\.\d+.\d+"
] # TODO add more prefixes for uris that are save to be cached
def __init__(self, cache_path: Path, runtime_cache_path: Path = None):
self.cache_path = cache_path
os.makedirs(self.cache_path, exist_ok=True)
if runtime_cache_path is None:
runtime_cache_parent = self.cache_path / "runtime-cache"
os.makedirs(runtime_cache_parent, exist_ok=True)
self.runtime_cache = tempfile.TemporaryDirectory(
dir=runtime_cache_parent, ignore_cleanup_errors=True
)
self._runtime_cache_path = None
else:
self._runtime_cache_path = runtime_cache_path
self.runtime_cache = None
self.cacheable_prefixes = re.compile("|".join(self.cache_whitelist))
@property
def runtime_cache_path(self):
return self._runtime_cache_path or self.runtime_cache.name
[docs]
def open(self, source_file, mode="r"):
cache_entry = self._cache(source_file)
return self._open_local_or_remote(
LocalSourceFile(cache_entry), mode, encoding="utf-8"
)
[docs]
def exists(self, source_file):
try:
self._cache(source_file, retries=1)
except Exception:
return False
return True
[docs]
def get_path(self, source_file):
cache_entry = self._cache(source_file)
return str(cache_entry)
def _cache_entry(self, source_file: SourceFile) -> Path:
file_cache_path = source_file.get_cache_path()
assert file_cache_path
# TODO add git support to smart_open!
if source_file.is_persistently_cacheable():
# check cache
return self.cache_path / file_cache_path
else:
# check runtime cache
return Path(self.runtime_cache_path) / file_cache_path
def _cache(self, source_file: SourceFile, retries: int = 3):
cache_entry = self._cache_entry(source_file)
if not cache_entry.exists():
self._do_cache(source_file, cache_entry, retries=retries)
return cache_entry
def _do_cache(self, source_file, cache_entry: Path, retries: int = 3):
# open from origin
with self._open_local_or_remote(source_file, "rb", retries=retries) as source:
cache_entry.parent.mkdir(parents=True, exist_ok=True)
tmp_source = tempfile.NamedTemporaryFile(
prefix=str(cache_entry),
delete=False, # no need to delete since we move it below
)
tmp_source.write(source.read())
tmp_source.close()
# ensure read and write permissions for owner and group
os.chmod(
tmp_source.name,
stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP,
)
# Atomic move to right name.
# This way we avoid the need to lock.
shutil.move(tmp_source.name, cache_entry)
mtime = source_file.mtime()
if mtime is not None:
# Set to mtime of original file
# In case we don't have that mtime, it is fine
# to just keep the time at the time of caching
# as mtime.
os.utime(cache_entry, times=(mtime, mtime))
def _open_local_or_remote(
self, source_file: SourceFile, mode, encoding=None, retries: int = 3
):
from reretry.api import retry_call
if source_file.is_local:
return self._open(source_file, mode, encoding=encoding)
else:
return retry_call(
self._open,
[source_file, mode, encoding],
tries=retries,
delay=3,
backoff=2,
logger=logger,
)
def _open(self, source_file: SourceFile, mode, encoding=None):
from smart_open import open
if isinstance(source_file, LocalGitFile):
import git
return io.BytesIO(
git.Repo(source_file.repo_path)
.git.show(f"{source_file.ref}:{source_file.path}")
.encode()
)
path_or_uri = source_file.get_path_or_uri()
try:
return open(path_or_uri, mode, encoding=None if "b" in mode else encoding)
except Exception as e:
raise WorkflowError(f"Failed to open source file {path_or_uri}", e)