Fix all typing issues

This commit is contained in:
Ayaz Salikhov
2022-01-23 12:44:16 +03:00
parent 013a42fff3
commit 37c510fc8e
25 changed files with 184 additions and 129 deletions

View File

@@ -17,6 +17,13 @@ repos:
- id: black
args: [--target-version=py39]
# Check python code static typing
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.931
hooks:
- id: mypy
additional_dependencies: ["pytest", "types-requests", "types-tabulate"]
# Autoformat: YAML, JSON, Markdown, etc.
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.5.1

View File

@@ -3,7 +3,7 @@
import logging
import pytest
import pytest # type: ignore
from pathlib import Path
from conftest import TrackedContainer

View File

@@ -1,12 +1,12 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
# mypy: ignore-errors
from jupyter_core.paths import jupyter_data_dir
import subprocess
import os
import errno
import stat
c = get_config() # noqa: F821
c.NotebookApp.ip = "0.0.0.0"
c.NotebookApp.port = 8888
@@ -16,28 +16,21 @@ c.NotebookApp.open_browser = False
c.FileContentsManager.delete_to_trash = False
# Generate a self-signed certificate
OPENSSL_CONFIG = """\
[req]
distinguished_name = req_distinguished_name
[req_distinguished_name]
"""
if "GEN_CERT" in os.environ:
dir_name = jupyter_data_dir()
pem_file = os.path.join(dir_name, "notebook.pem")
try:
os.makedirs(dir_name)
except OSError as exc: # Python >2.5
if exc.errno == errno.EEXIST and os.path.isdir(dir_name):
pass
else:
raise
os.makedirs(dir_name, exist_ok=True)
# Generate an openssl.cnf file to set the distinguished name
cnf_file = os.path.join(os.getenv("CONDA_DIR", "/usr/lib"), "ssl", "openssl.cnf")
if not os.path.isfile(cnf_file):
with open(cnf_file, "w") as fh:
fh.write(
"""\
[req]
distinguished_name = req_distinguished_name
[req_distinguished_name]
"""
)
fh.write(OPENSSL_CONFIG)
# Generate a certificate if one doesn't exist on disk
subprocess.check_call(

View File

@@ -4,7 +4,7 @@ import pathlib
import time
import logging
import pytest
import pytest # type: ignore
import requests
from conftest import TrackedContainer
@@ -303,6 +303,6 @@ def test_jupyter_env_vars_to_unset_as_root(
"-c",
"echo I like $FRUIT and ${SECRET_FRUIT:-stuff}, and love ${SECRET_ANIMAL:-to keep secrets}!",
],
**root_args,
**root_args, # type: ignore
)
assert "I like bananas and stuff, and love to keep secrets!" in logs

View File

@@ -2,7 +2,7 @@
# Distributed under the terms of the Modified BSD License.
import logging
import pytest
import pytest # type: ignore
from conftest import TrackedContainer

View File

@@ -2,7 +2,7 @@
# Distributed under the terms of the Modified BSD License.
import logging
from packaging import version
from packaging import version # type: ignore
from conftest import TrackedContainer

View File

@@ -3,7 +3,7 @@
import logging
from typing import Optional
import pytest
import pytest # type: ignore
import requests
import time

View File

@@ -2,11 +2,11 @@
# Distributed under the terms of the Modified BSD License.
import os
import logging
import typing
from typing import Any, Optional
import docker
from docker.models.containers import Container
import pytest
import pytest # type: ignore
import requests
from requests.packages.urllib3.util.retry import Retry
@@ -35,7 +35,7 @@ def docker_client() -> docker.DockerClient:
@pytest.fixture(scope="session")
def image_name() -> str:
"""Image name to test"""
return os.getenv("TEST_IMAGE")
return os.environ["TEST_IMAGE"]
class TrackedContainer:
@@ -56,14 +56,14 @@ class TrackedContainer:
self,
docker_client: docker.DockerClient,
image_name: str,
**kwargs: typing.Any,
**kwargs: Any,
):
self.container = None
self.docker_client = docker_client
self.image_name = image_name
self.kwargs = kwargs
self.container: Optional[Container] = None
self.docker_client: docker.DockerClient = docker_client
self.image_name: str = image_name
self.kwargs: Any = kwargs
def run_detached(self, **kwargs: typing.Any) -> Container:
def run_detached(self, **kwargs: Any) -> Container:
"""Runs a docker container using the preconfigured image name
and a mix of the preconfigured container options and those passed
to this method.
@@ -94,11 +94,12 @@ class TrackedContainer:
timeout: int,
no_warnings: bool = True,
no_errors: bool = True,
**kwargs: typing.Any,
**kwargs: Any,
) -> str:
running_container = self.run_detached(**kwargs)
rv = running_container.wait(timeout=timeout)
logs = running_container.logs().decode("utf-8")
assert isinstance(logs, str)
LOGGER.debug(logs)
if no_warnings:
assert not self.get_warnings(logs)
@@ -119,14 +120,14 @@ class TrackedContainer:
def _lines_starting_with(logs: str, pattern: str) -> list[str]:
return [line for line in logs.splitlines() if line.startswith(pattern)]
def remove(self):
def remove(self) -> None:
"""Kills and removes the tracked docker container."""
if self.container:
self.container.remove(force=True)
@pytest.fixture(scope="function")
def container(docker_client: docker.DockerClient, image_name: str):
def container(docker_client: docker.DockerClient, image_name: str) -> Container:
"""Notebook container with initial configuration appropriate for testing
(e.g., HTTP port exposed to the host for HTTP calls).

View File

@@ -3,7 +3,7 @@
import logging
import pytest
import pytest # type: ignore
from pathlib import Path
from conftest import TrackedContainer

26
mypy.ini Normal file
View File

@@ -0,0 +1,26 @@
[mypy]
python_version = 3.9
follow_imports = normal
strict = False
no_incremental = True
[mypy-docker.*]
ignore_missing_imports = True
[mypy-matplotlib.*]
ignore_missing_imports = True
[mypy-packaging.*]
ignore_missing_imports = True
[mypy-pandas.*]
ignore_missing_imports = True
[mypy-plumbum.*]
ignore_missing_imports = True
[mypy-pyspark.*]
ignore_missing_imports = True
[mypy-tensorflow.*]
ignore_missing_imports = True

View File

@@ -10,4 +10,5 @@
# Attempt to capture and forward low-level output, e.g. produced by Extension
# libraries.
# Default: True
# type:ignore
c.IPKernelApp.capture_fd_output = False # noqa: F821

View File

@@ -2,7 +2,7 @@
# Distributed under the terms of the Modified BSD License.
import logging
import pytest
import pytest # type: ignore
from conftest import TrackedContainer

View File

@@ -3,7 +3,7 @@
import logging
import pytest
import pytest # type: ignore
from pathlib import Path
from conftest import TrackedContainer
@@ -29,7 +29,7 @@ THIS_DIR = Path(__file__).parent.resolve()
)
def test_matplotlib(
container: TrackedContainer, test_file: str, expected_file: str, description: str
):
) -> None:
"""Various tests performed on matplotlib
- Test that matplotlib is able to plot a graph and write it as an image

View File

@@ -5,6 +5,7 @@ import argparse
import datetime
import logging
import os
from docker.models.containers import Container
from .docker_runner import DockerRunner
from .get_taggers_and_manifests import get_taggers_and_manifests
from .git_helper import GitHelper
@@ -55,9 +56,9 @@ def create_manifest_file(
owner: str,
wiki_path: str,
manifests: list[ManifestInterface],
container,
container: Container,
) -> None:
manifest_names = [manifest.__name__ for manifest in manifests]
manifest_names = [manifest.__class__.__name__ for manifest in manifests]
LOGGER.info(f"Using manifests: {manifest_names}")
commit_hash_tag = GitHelper.commit_hash_tag()

View File

@@ -1,6 +1,7 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from typing import Optional
from types import TracebackType
import docker
from docker.models.containers import Container
import logging
@@ -13,7 +14,7 @@ class DockerRunner:
def __init__(
self,
image_name: str,
docker_client=docker.from_env(),
docker_client: docker.DockerClient = docker.from_env(),
command: str = "sleep infinity",
):
self.container: Optional[Container] = None
@@ -31,7 +32,13 @@ class DockerRunner:
LOGGER.info(f"Container {self.container.name} created")
return self.container
def __exit__(self, exc_type, exc_value, traceback) -> None:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
assert self.container is not None
LOGGER.info(f"Removing container {self.container.name} ...")
if self.container:
self.container.remove(force=True)
@@ -44,6 +51,7 @@ class DockerRunner:
LOGGER.info(f"Running cmd: '{cmd}' on container: {container}")
out = container.exec_run(cmd)
result = out.output.decode("utf-8").rstrip()
assert isinstance(result, str)
if print_result:
LOGGER.info(f"Command result: {result}")
assert out.exit_code == 0, f"Command: {cmd} failed"

View File

@@ -1,20 +1,22 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from typing import Optional
from .images_hierarchy import ALL_IMAGES
from .manifests import ManifestInterface
from .taggers import TaggerInterface
def get_taggers_and_manifests(
short_image_name: str,
short_image_name: Optional[str],
) -> tuple[list[TaggerInterface], list[ManifestInterface]]:
taggers: list[TaggerInterface] = []
manifests: list[ManifestInterface] = []
while short_image_name is not None:
if short_image_name is None:
return [[], []] # type: ignore
image_description = ALL_IMAGES[short_image_name]
taggers = image_description.taggers + taggers
manifests = image_description.manifests + manifests
short_image_name = image_description.parent_image
return taggers, manifests
parent_taggers, parent_manifests = get_taggers_and_manifests(
image_description.parent_image
)
return (
parent_taggers + image_description.taggers,
parent_manifests + image_description.manifests,
)

View File

@@ -7,7 +7,7 @@ from plumbum.cmd import git
class GitHelper:
@staticmethod
def commit_hash() -> str:
return git["rev-parse", "HEAD"]().strip()
return git["rev-parse", "HEAD"]().strip() # type: ignore
@staticmethod
def commit_hash_tag() -> str:
@@ -15,7 +15,7 @@ class GitHelper:
@staticmethod
def commit_message() -> str:
return git["log", -1, "--pretty=%B"]().strip()
return git["log", -1, "--pretty=%B"]().strip() # type: ignore
if __name__ == "__main__":

View File

@@ -3,7 +3,7 @@
import os
def github_set_env(env_name, env_value):
def github_set_env(env_name: str, env_value: str) -> None:
if not os.environ.get("GITHUB_ACTIONS") or not os.environ.get("GITHUB_ENV"):
return

View File

@@ -39,39 +39,39 @@ ALL_IMAGES = {
"base-notebook": ImageDescription(
parent_image=None,
taggers=[
SHATagger,
DateTagger,
UbuntuVersionTagger,
PythonVersionTagger,
JupyterNotebookVersionTagger,
JupyterLabVersionTagger,
JupyterHubVersionTagger,
SHATagger(),
DateTagger(),
UbuntuVersionTagger(),
PythonVersionTagger(),
JupyterNotebookVersionTagger(),
JupyterLabVersionTagger(),
JupyterHubVersionTagger(),
],
manifests=[CondaEnvironmentManifest, AptPackagesManifest],
manifests=[CondaEnvironmentManifest(), AptPackagesManifest()],
),
"minimal-notebook": ImageDescription(parent_image="base-notebook"),
"scipy-notebook": ImageDescription(parent_image="minimal-notebook"),
"r-notebook": ImageDescription(
parent_image="minimal-notebook",
taggers=[RVersionTagger],
manifests=[RPackagesManifest],
taggers=[RVersionTagger()],
manifests=[RPackagesManifest()],
),
"tensorflow-notebook": ImageDescription(
parent_image="scipy-notebook", taggers=[TensorflowVersionTagger]
parent_image="scipy-notebook", taggers=[TensorflowVersionTagger()]
),
"datascience-notebook": ImageDescription(
parent_image="scipy-notebook",
taggers=[RVersionTagger, JuliaVersionTagger],
manifests=[RPackagesManifest, JuliaPackagesManifest],
taggers=[RVersionTagger(), JuliaVersionTagger()],
manifests=[RPackagesManifest(), JuliaPackagesManifest()],
),
"pyspark-notebook": ImageDescription(
parent_image="scipy-notebook",
taggers=[SparkVersionTagger, HadoopVersionTagger, JavaVersionTagger],
manifests=[SparkInfoManifest],
taggers=[SparkVersionTagger(), HadoopVersionTagger(), JavaVersionTagger()],
manifests=[SparkInfoManifest()],
),
"all-spark-notebook": ImageDescription(
parent_image="pyspark-notebook",
taggers=[RVersionTagger],
manifests=[RPackagesManifest],
taggers=[RVersionTagger()],
manifests=[RPackagesManifest()],
),
}

View File

@@ -1,11 +1,12 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from plumbum.cmd import docker
from docker.models.containers import Container
from .docker_runner import DockerRunner
from .git_helper import GitHelper
def quoted_output(container, cmd: str) -> str:
def quoted_output(container: Container, cmd: str) -> str:
return "\n".join(
[
"```",
@@ -50,13 +51,13 @@ class ManifestInterface:
"""Common interface for all manifests"""
@staticmethod
def markdown_piece(container) -> str:
def markdown_piece(container: Container) -> str:
raise NotImplementedError
class CondaEnvironmentManifest(ManifestInterface):
@staticmethod
def markdown_piece(container) -> str:
def markdown_piece(container: Container) -> str:
return "\n".join(
[
"## Python Packages",
@@ -72,7 +73,7 @@ class CondaEnvironmentManifest(ManifestInterface):
class AptPackagesManifest(ManifestInterface):
@staticmethod
def markdown_piece(container) -> str:
def markdown_piece(container: Container) -> str:
return "\n".join(
[
"## Apt Packages",
@@ -84,7 +85,7 @@ class AptPackagesManifest(ManifestInterface):
class RPackagesManifest(ManifestInterface):
@staticmethod
def markdown_piece(container) -> str:
def markdown_piece(container: Container) -> str:
return "\n".join(
[
"## R Packages",
@@ -101,7 +102,7 @@ class RPackagesManifest(ManifestInterface):
class JuliaPackagesManifest(ManifestInterface):
@staticmethod
def markdown_piece(container) -> str:
def markdown_piece(container: Container) -> str:
return "\n".join(
[
"## Julia Packages",
@@ -118,7 +119,7 @@ class JuliaPackagesManifest(ManifestInterface):
class SparkInfoManifest(ManifestInterface):
@staticmethod
def markdown_piece(container) -> str:
def markdown_piece(container: Container) -> str:
return "\n".join(
[
"## Apache Spark",

View File

@@ -28,11 +28,11 @@ def tag_image(short_image_name: str, owner: str) -> None:
with DockerRunner(image) as container:
tags = []
for tagger in taggers:
tagger_name = tagger.__name__
tagger_name = tagger.__class__.__name__
tag_value = tagger.tag_value(container)
tags.append(tag_value)
LOGGER.info(
f"Applying tag tagger_name: {tagger_name} tag_value: {tag_value}"
f"Applying tag, tagger_name: {tagger_name} tag_value: {tag_value}"
)
docker["tag", image, f"{owner}/{short_image_name}:{tag_value}"]()

View File

@@ -1,15 +1,16 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from datetime import datetime
from docker.models.containers import Container
from .git_helper import GitHelper
from .docker_runner import DockerRunner
def _get_program_version(container, program: str) -> str:
def _get_program_version(container: Container, program: str) -> str:
return DockerRunner.run_simple_command(container, cmd=f"{program} --version")
def _get_env_variable(container, variable: str) -> str:
def _get_env_variable(container: Container, variable: str) -> str:
env = DockerRunner.run_simple_command(
container,
cmd="env",
@@ -21,7 +22,7 @@ def _get_env_variable(container, variable: str) -> str:
raise KeyError(variable)
def _get_pip_package_version(container, package: str) -> str:
def _get_pip_package_version(container: Container, package: str) -> str:
VERSION_PREFIX = "Version: "
package_info = DockerRunner.run_simple_command(
container,
@@ -37,25 +38,25 @@ class TaggerInterface:
"""Common interface for all taggers"""
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
raise NotImplementedError
class SHATagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return GitHelper.commit_hash_tag()
class DateTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return datetime.utcnow().strftime("%Y-%m-%d")
class UbuntuVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
os_release = DockerRunner.run_simple_command(
container,
"cat /etc/os-release",
@@ -63,63 +64,64 @@ class UbuntuVersionTagger(TaggerInterface):
for line in os_release:
if line.startswith("VERSION_ID"):
return "ubuntu-" + line.split("=")[1].strip('"')
raise RuntimeError(f"did not find ubuntu version in: {os_release}")
class PythonVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return "python-" + _get_program_version(container, "python").split()[1]
class JupyterNotebookVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return "notebook-" + _get_program_version(container, "jupyter-notebook")
class JupyterLabVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return "lab-" + _get_program_version(container, "jupyter-lab")
class JupyterHubVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return "hub-" + _get_program_version(container, "jupyterhub")
class RVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return "r-" + _get_program_version(container, "R").split()[2]
class TensorflowVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return "tensorflow-" + _get_pip_package_version(container, "tensorflow")
class JuliaVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return "julia-" + _get_program_version(container, "julia").split()[2]
class SparkVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return "spark-" + _get_env_variable(container, "APACHE_SPARK_VERSION")
class HadoopVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return "hadoop-" + _get_env_variable(container, "HADOOP_VERSION")
class JavaVersionTagger(TaggerInterface):
@staticmethod
def tag_value(container) -> str:
def tag_value(container: Container) -> str:
return "java-" + _get_program_version(container, "java").split()[1]

View File

@@ -27,7 +27,8 @@ from collections import defaultdict
from itertools import chain
import logging
import json
from typing import Optional
from typing import Any, Optional
from docker.models.containers import Container
from tabulate import tabulate
@@ -40,14 +41,16 @@ class CondaPackageHelper:
"""Conda package helper permitting to get information about packages"""
def __init__(self, container: TrackedContainer):
self.running_container = CondaPackageHelper.start_container(container)
self.running_container: Container = CondaPackageHelper.start_container(
container
)
self.requested: Optional[dict[str, set[str]]] = None
self.installed: Optional[dict[str, set[str]]] = None
self.available: Optional[dict[str, set[str]]] = None
self.comparison: list[dict[str, str]] = []
@staticmethod
def start_container(container: TrackedContainer):
def start_container(container: TrackedContainer) -> Container:
"""Start the TrackedContainer and return an instance of a running container"""
LOGGER.info(f"Starting container {container.image_name} ...")
return container.run_detached(
@@ -85,13 +88,13 @@ class CondaPackageHelper:
)
return self.requested
def _execute_command(self, command):
def _execute_command(self, command: list[str]) -> str:
"""Execute a command on a running container"""
rc = self.running_container.exec_run(command)
return rc.output.decode("utf-8")
return rc.output.decode("utf-8") # type: ignore
@staticmethod
def _packages_from_json(env_export) -> dict[str, set[str]]:
def _packages_from_json(env_export: str) -> dict[str, set[str]]:
"""Extract packages and versions from the lines returned by the list of specifications"""
# dependencies = filter(lambda x: isinstance(x, str), json.loads(env_export).get("dependencies"))
dependencies = json.loads(env_export).get("dependencies")
@@ -114,7 +117,7 @@ class CondaPackageHelper:
packages_dict[package] = version
return packages_dict
def available_packages(self):
def available_packages(self) -> dict[str, set[str]]:
"""Return the available packages"""
if self.available is None:
LOGGER.info("Grabing the list of available packages (can take a while) ...")
@@ -125,11 +128,13 @@ class CondaPackageHelper:
return self.available
@staticmethod
def _extract_available(lines):
def _extract_available(lines: str) -> dict[str, set[str]]:
"""Extract packages and versions from the lines returned by the list of packages"""
ddict = defaultdict(set)
for line in lines.splitlines()[2:]:
pkg, version = re.match(r"^(\S+)\s+(\S+)", line, re.MULTILINE).groups()
match = re.match(r"^(\S+)\s+(\S+)", line, re.MULTILINE)
assert match is not None
pkg, version = match.groups()
ddict[pkg].add(version)
return ddict
@@ -162,11 +167,11 @@ class CondaPackageHelper:
return self.comparison
@staticmethod
def semantic_cmp(version_string: str):
def semantic_cmp(version_string: str) -> Any:
"""Manage semantic versioning for comparison"""
def mysplit(string):
def version_substrs(x):
def mysplit(string: str) -> list[Any]:
def version_substrs(x: str) -> list[str]:
return re.findall(r"([A-z]+|\d+)", x)
return list(chain(map(version_substrs, string.split("."))))
@@ -189,7 +194,9 @@ class CondaPackageHelper:
def get_outdated_summary(self, requested_only: bool = True) -> str:
"""Return a summary of outdated packages"""
nb_packages = len(self.requested if requested_only else self.installed)
packages = self.requested if requested_only else self.installed
assert packages is not None
nb_packages = len(packages)
nb_updatable = len(self.comparison)
updatable_ratio = nb_updatable / nb_packages
return f"{nb_updatable}/{nb_packages} ({updatable_ratio:.0%}) packages could be updated"

View File

@@ -3,7 +3,7 @@
import logging
import pytest
import pytest # type: ignore
from conftest import TrackedContainer
from package_helper import CondaPackageHelper
@@ -12,7 +12,9 @@ LOGGER = logging.getLogger(__name__)
@pytest.mark.info
def test_outdated_packages(container: TrackedContainer, requested_only: bool = True):
def test_outdated_packages(
container: TrackedContainer, requested_only: bool = True
) -> None:
"""Getting the list of updatable packages"""
LOGGER.info(f"Checking outdated packages in {container.image_name} ...")
pkg_helper = CondaPackageHelper(container)

View File

@@ -37,8 +37,9 @@ Example:
import logging
import pytest
import pytest # type: ignore
from conftest import TrackedContainer
from typing import Callable, Iterable
from package_helper import CondaPackageHelper
@@ -87,7 +88,7 @@ def packages(package_helper: CondaPackageHelper) -> dict[str, set[str]]:
return package_helper.requested_packages()
def package_map(package: str) -> str:
def get_package_import_name(package: str) -> str:
"""Perform a mapping between the python package name and the name used for the import"""
return PACKAGE_MAPPING.get(package, package)
@@ -113,7 +114,7 @@ def _check_import_package(
"""Generic function executing a command"""
LOGGER.debug(f"Trying to import a package with [{command}] ...")
rc = package_helper.running_container.exec_run(command)
return rc.exit_code
return rc.exit_code # type: ignore
def check_import_python_package(
@@ -130,10 +131,10 @@ def check_import_r_package(package_helper: CondaPackageHelper, package: str) ->
)
def _import_packages(
def _check_import_packages(
package_helper: CondaPackageHelper,
filtered_packages: dict[str, set[str]],
check_function,
filtered_packages: Iterable[str],
check_function: Callable[[CondaPackageHelper, str], int],
max_failures: int,
) -> None:
"""Test if packages can be imported
@@ -157,33 +158,36 @@ def _import_packages(
@pytest.fixture(scope="function")
def r_packages(packages: dict[str, set[str]]):
def r_packages(packages: dict[str, set[str]]) -> Iterable[str]:
"""Return an iterable of R packages"""
# package[2:] is to remove the leading "r-" appended on R packages
return map(
lambda package: package_map(package[2:]), filter(r_package_predicate, packages)
lambda package: get_package_import_name(package[2:]),
filter(r_package_predicate, packages),
)
def test_r_packages(
package_helper: CondaPackageHelper, r_packages, max_failures: int = 0
):
package_helper: CondaPackageHelper, r_packages: Iterable[str], max_failures: int = 0
) -> None:
"""Test the import of specified R packages"""
return _import_packages(
_check_import_packages(
package_helper, r_packages, check_import_r_package, max_failures
)
@pytest.fixture(scope="function")
def python_packages(packages: dict[str, set[str]]):
def python_packages(packages: dict[str, set[str]]) -> Iterable[str]:
"""Return an iterable of Python packages"""
return map(package_map, filter(python_package_predicate, packages))
return map(get_package_import_name, filter(python_package_predicate, packages))
def test_python_packages(
package_helper: CondaPackageHelper, python_packages, max_failures: int = 0
):
package_helper: CondaPackageHelper,
python_packages: Iterable[str],
max_failures: int = 0,
) -> None:
"""Test the import of specified python packages"""
return _import_packages(
_check_import_packages(
package_helper, python_packages, check_import_python_package, max_failures
)