diff --git a/src/runboat/controller.py b/src/runboat/controller.py index 8347096..7180bd0 100644 --- a/src/runboat/controller.py +++ b/src/runboat/controller.py @@ -1,5 +1,6 @@ import asyncio import logging +from typing import Any, Awaitable, Callable from . import k8s from .db import BuildsDb @@ -247,7 +248,7 @@ class Controller: async def start(self) -> None: _logger.info("Starting controller tasks.") - async def walking_dead(func): + async def walking_dead(func: Callable[..., Awaitable[Any]]) -> None: while True: _logger.info(f"(Re)starting {func.__name__}") try: @@ -275,7 +276,7 @@ class Controller: task.cancel() # Wait until all tasks are cancelled. await asyncio.gather(*self._tasks, return_exceptions=True) - self._task = [] + self._tasks.clear() controller = Controller() diff --git a/src/runboat/db.py b/src/runboat/db.py index 4573d22..09092c7 100644 --- a/src/runboat/db.py +++ b/src/runboat/db.py @@ -1,6 +1,6 @@ import logging import sqlite3 -from typing import Optional +from typing import cast from .models import Build, BuildInitStatus, BuildStatus @@ -59,7 +59,7 @@ class BuildsDb: self, repo: str, target_branch: str, pr: int | None, git_commit: str ) -> Build | None: query = "SELECT * FROM builds WHERE repo=? AND target_branch=? AND git_commit=?" - params = [repo.lower(), target_branch, git_commit] + params: list[str | int] = [repo.lower(), target_branch, git_commit] if pr: query += " AND pr=?" params.append(pr) @@ -131,17 +131,20 @@ class BuildsDb: return True def count_by_status(self, status: BuildStatus) -> int: - return self._con.execute( + count = self._con.execute( "SELECT COUNT(name) FROM builds WHERE status=?", (status,) ).fetchone()[0] + return cast(int, count) def count_by_init_status(self, init_status: BuildInitStatus) -> int: - return self._con.execute( + count = self._con.execute( "SELECT COUNT(name) FROM builds WHERE init_status=?", (init_status,) ).fetchone()[0] + return cast(int, count) def count_all(self) -> int: - return self._con.execute("SELECT COUNT(name) FROM builds").fetchone()[0] + count = self._con.execute("SELECT COUNT(name) FROM builds").fetchone()[0] + return cast(int, count) def to_initialize(self, limit: int) -> list[Build]: """Return the list of builds to initialize, ordered by creation timestamp.""" @@ -168,7 +171,7 @@ class BuildsDb: ).fetchall() return [self._build_from_row(row) for row in rows] - def search(self, repo: Optional[str] = None) -> list[Build]: + def search(self, repo: str | None = None) -> list[Build]: query = "SELECT * FROM builds " where = [] params = [] diff --git a/src/runboat/k8s.py b/src/runboat/k8s.py index 31c2666..391e855 100644 --- a/src/runboat/k8s.py +++ b/src/runboat/k8s.py @@ -9,13 +9,14 @@ from contextlib import contextmanager from enum import Enum from importlib import resources from pathlib import Path -from typing import Any, Generator, Optional +from typing import Generator, Optional, TypedDict, cast import urllib3 from jinja2 import Template from kubernetes import client, config, watch from kubernetes.client.exceptions import ApiException from kubernetes.client.models.v1_deployment import V1Deployment +from kubernetes.client.models.v1_job import V1Job from pydantic import BaseModel from .settings import settings @@ -24,10 +25,9 @@ from .utils import sync_to_async, sync_to_async_iterator _logger = logging.getLogger(__name__) -def _split_image_name_tag(img: str) -> tuple[str, str]: - if ":" in img: - return img.split(":", 2) - return (img, "latest") +def _split_image_name_tag(image: str) -> tuple[str, str]: + img, _, tag = image.partition(":") + return (img, tag or "latest") @sync_to_async @@ -39,7 +39,7 @@ def load_kube_config() -> None: @sync_to_async -def read_deployment(name: str) -> Optional[V1Deployment]: +def read_deployment(name: str) -> V1Deployment | None: appsv1 = client.AppsV1Api() items = appsv1.list_namespaced_deployment( namespace=settings.build_namespace, @@ -56,9 +56,15 @@ def delete_deployment(deployment_name: str) -> None: ) +class PatchOperation(TypedDict, total=False): + op: str + path: str + value: str | int # maybe absent, hence total=False above + + @sync_to_async def patch_deployment( - deployment_name: str, ops: list[dict["str", Any]], not_found_ok: bool + deployment_name: str, ops: list[PatchOperation], not_found_ok: bool ) -> None: appsv1 = client.AppsV1Api() try: @@ -113,7 +119,7 @@ def _watch(list_method, *args, **kwargs): @sync_to_async_iterator -def watch_deployments(): +def watch_deployments() -> Generator[V1Deployment, None, None]: appsv1 = client.AppsV1Api() yield from _watch( appsv1.list_namespaced_deployment, namespace=settings.build_namespace @@ -121,7 +127,7 @@ def watch_deployments(): @sync_to_async_iterator -def watch_jobs(): +def watch_jobs() -> Generator[V1Job, None, None]: batchv1 = client.BatchV1Api() yield from _watch(batchv1.list_namespaced_job, namespace=settings.build_namespace) @@ -258,8 +264,8 @@ async def delete_job(build_name: str, job_kind: DeploymentMode) -> None: @sync_to_async -def log(build_name: str, job_kind: DeploymentMode | None) -> str: - """Return the buil log. +def log(build_name: str, job_kind: DeploymentMode | None) -> str | None: + """Return the build log. The pod for which the log is returned is the first that matches the build_name (via its runboat/build label) and job_kind (via its @@ -269,20 +275,18 @@ def log(build_name: str, job_kind: DeploymentMode | None) -> str: pods = corev1.list_namespaced_pod( namespace=settings.build_namespace, label_selector=f"runboat/build={build_name}" ).items - pod = None for pod in pods: - if job_kind is None: - if "runboat/job-kind" not in pod.metadata.labels: - break - else: - if pod.metadata.labels.get("runboat/job-kind") == job_kind: - break + if pod.metadata.labels.get("runboat/job-kind") == job_kind: + break else: # no matching pod found - return - return corev1.read_namespaced_pod_log( - pod.metadata.name, - namespace=settings.build_namespace, - tail_lines=None if job_kind else None, - follow=False, + return None + return cast( + str, + corev1.read_namespaced_pod_log( + pod.metadata.name, + namespace=settings.build_namespace, + tail_lines=None if job_kind else None, + follow=False, + ), ) diff --git a/src/runboat/models.py b/src/runboat/models.py index 3166775..9c5d0fe 100644 --- a/src/runboat/models.py +++ b/src/runboat/models.py @@ -325,8 +325,8 @@ class Build(BaseModel): desired_replicas: int | None = None, remove_finalizers: bool = False, not_found_ok: bool = False, - ) -> None: - ops = [] + ) -> bool: + ops: list[k8s.PatchOperation] = [] if init_status is not None and init_status != self.init_status: ops.extend( [ diff --git a/src/runboat/settings.py b/src/runboat/settings.py index bff321e..17d5443 100644 --- a/src/runboat/settings.py +++ b/src/runboat/settings.py @@ -50,7 +50,7 @@ class Settings(BaseSettings): @validator("supported_repos") @classmethod - def validate_supported_repos(v) -> set[str]: + def validate_supported_repos(cls, v: set[str]) -> set[str]: return {item.lower() for item in v} diff --git a/src/runboat/webui.py b/src/runboat/webui.py index 5034d5d..7910ac7 100644 --- a/src/runboat/webui.py +++ b/src/runboat/webui.py @@ -9,7 +9,7 @@ from .controller import controller from .models import BuildStatus router = APIRouter() -templates = Jinja2Templates(directory=Path(__file__).parent / "webui") +templates = Jinja2Templates(directory=str(Path(__file__).parent / "webui")) @router.get("/builds/{name}", response_class=HTMLResponse) diff --git a/tests/test_k8s.py b/tests/test_k8s.py new file mode 100644 index 0000000..68be498 --- /dev/null +++ b/tests/test_k8s.py @@ -0,0 +1,14 @@ +import pytest + +from runboat.k8s import _split_image_name_tag + + +@pytest.mark.parametrize( + ("image", "expected"), + [ + ("postgres", ("postgres", "latest")), + ("postgres:12", ("postgres", "12")), + ], +) +def test_split_image_name_tag(image, expected): + assert _split_image_name_tag(image) == expected