Some typing

This commit is contained in:
Stéphane Bidoul 2021-11-13 18:05:55 +01:00
parent defa2f959e
commit 858c58e47d
No known key found for this signature in database
GPG key ID: BCAB2555446B5B92
7 changed files with 58 additions and 36 deletions

View file

@ -1,5 +1,6 @@
import asyncio
import logging
from typing import Any, Awaitable, Callable
from . import k8s
from .db import BuildsDb
@ -247,7 +248,7 @@ class Controller:
async def start(self) -> None:
_logger.info("Starting controller tasks.")
async def walking_dead(func):
async def walking_dead(func: Callable[..., Awaitable[Any]]) -> None:
while True:
_logger.info(f"(Re)starting {func.__name__}")
try:
@ -275,7 +276,7 @@ class Controller:
task.cancel()
# Wait until all tasks are cancelled.
await asyncio.gather(*self._tasks, return_exceptions=True)
self._task = []
self._tasks.clear()
controller = Controller()

View file

@ -1,6 +1,6 @@
import logging
import sqlite3
from typing import Optional
from typing import cast
from .models import Build, BuildInitStatus, BuildStatus
@ -59,7 +59,7 @@ class BuildsDb:
self, repo: str, target_branch: str, pr: int | None, git_commit: str
) -> Build | None:
query = "SELECT * FROM builds WHERE repo=? AND target_branch=? AND git_commit=?"
params = [repo.lower(), target_branch, git_commit]
params: list[str | int] = [repo.lower(), target_branch, git_commit]
if pr:
query += " AND pr=?"
params.append(pr)
@ -131,17 +131,20 @@ class BuildsDb:
return True
def count_by_status(self, status: BuildStatus) -> int:
return self._con.execute(
count = self._con.execute(
"SELECT COUNT(name) FROM builds WHERE status=?", (status,)
).fetchone()[0]
return cast(int, count)
def count_by_init_status(self, init_status: BuildInitStatus) -> int:
return self._con.execute(
count = self._con.execute(
"SELECT COUNT(name) FROM builds WHERE init_status=?", (init_status,)
).fetchone()[0]
return cast(int, count)
def count_all(self) -> int:
return self._con.execute("SELECT COUNT(name) FROM builds").fetchone()[0]
count = self._con.execute("SELECT COUNT(name) FROM builds").fetchone()[0]
return cast(int, count)
def to_initialize(self, limit: int) -> list[Build]:
"""Return the list of builds to initialize, ordered by creation timestamp."""
@ -168,7 +171,7 @@ class BuildsDb:
).fetchall()
return [self._build_from_row(row) for row in rows]
def search(self, repo: Optional[str] = None) -> list[Build]:
def search(self, repo: str | None = None) -> list[Build]:
query = "SELECT * FROM builds "
where = []
params = []

View file

@ -9,13 +9,14 @@ from contextlib import contextmanager
from enum import Enum
from importlib import resources
from pathlib import Path
from typing import Any, Generator, Optional
from typing import Generator, Optional, TypedDict, cast
import urllib3
from jinja2 import Template
from kubernetes import client, config, watch
from kubernetes.client.exceptions import ApiException
from kubernetes.client.models.v1_deployment import V1Deployment
from kubernetes.client.models.v1_job import V1Job
from pydantic import BaseModel
from .settings import settings
@ -24,10 +25,9 @@ from .utils import sync_to_async, sync_to_async_iterator
_logger = logging.getLogger(__name__)
def _split_image_name_tag(img: str) -> tuple[str, str]:
if ":" in img:
return img.split(":", 2)
return (img, "latest")
def _split_image_name_tag(image: str) -> tuple[str, str]:
img, _, tag = image.partition(":")
return (img, tag or "latest")
@sync_to_async
@ -39,7 +39,7 @@ def load_kube_config() -> None:
@sync_to_async
def read_deployment(name: str) -> Optional[V1Deployment]:
def read_deployment(name: str) -> V1Deployment | None:
appsv1 = client.AppsV1Api()
items = appsv1.list_namespaced_deployment(
namespace=settings.build_namespace,
@ -56,9 +56,15 @@ def delete_deployment(deployment_name: str) -> None:
)
class PatchOperation(TypedDict, total=False):
op: str
path: str
value: str | int # maybe absent, hence total=False above
@sync_to_async
def patch_deployment(
deployment_name: str, ops: list[dict["str", Any]], not_found_ok: bool
deployment_name: str, ops: list[PatchOperation], not_found_ok: bool
) -> None:
appsv1 = client.AppsV1Api()
try:
@ -113,7 +119,7 @@ def _watch(list_method, *args, **kwargs):
@sync_to_async_iterator
def watch_deployments():
def watch_deployments() -> Generator[V1Deployment, None, None]:
appsv1 = client.AppsV1Api()
yield from _watch(
appsv1.list_namespaced_deployment, namespace=settings.build_namespace
@ -121,7 +127,7 @@ def watch_deployments():
@sync_to_async_iterator
def watch_jobs():
def watch_jobs() -> Generator[V1Job, None, None]:
batchv1 = client.BatchV1Api()
yield from _watch(batchv1.list_namespaced_job, namespace=settings.build_namespace)
@ -258,8 +264,8 @@ async def delete_job(build_name: str, job_kind: DeploymentMode) -> None:
@sync_to_async
def log(build_name: str, job_kind: DeploymentMode | None) -> str:
"""Return the buil log.
def log(build_name: str, job_kind: DeploymentMode | None) -> str | None:
"""Return the build log.
The pod for which the log is returned is the first that matches the
build_name (via its runboat/build label) and job_kind (via its
@ -269,20 +275,18 @@ def log(build_name: str, job_kind: DeploymentMode | None) -> str:
pods = corev1.list_namespaced_pod(
namespace=settings.build_namespace, label_selector=f"runboat/build={build_name}"
).items
pod = None
for pod in pods:
if job_kind is None:
if "runboat/job-kind" not in pod.metadata.labels:
break
else:
if pod.metadata.labels.get("runboat/job-kind") == job_kind:
break
if pod.metadata.labels.get("runboat/job-kind") == job_kind:
break
else:
# no matching pod found
return
return corev1.read_namespaced_pod_log(
pod.metadata.name,
namespace=settings.build_namespace,
tail_lines=None if job_kind else None,
follow=False,
return None
return cast(
str,
corev1.read_namespaced_pod_log(
pod.metadata.name,
namespace=settings.build_namespace,
tail_lines=None if job_kind else None,
follow=False,
),
)

View file

@ -325,8 +325,8 @@ class Build(BaseModel):
desired_replicas: int | None = None,
remove_finalizers: bool = False,
not_found_ok: bool = False,
) -> None:
ops = []
) -> bool:
ops: list[k8s.PatchOperation] = []
if init_status is not None and init_status != self.init_status:
ops.extend(
[

View file

@ -50,7 +50,7 @@ class Settings(BaseSettings):
@validator("supported_repos")
@classmethod
def validate_supported_repos(v) -> set[str]:
def validate_supported_repos(cls, v: set[str]) -> set[str]:
return {item.lower() for item in v}

View file

@ -9,7 +9,7 @@ from .controller import controller
from .models import BuildStatus
router = APIRouter()
templates = Jinja2Templates(directory=Path(__file__).parent / "webui")
templates = Jinja2Templates(directory=str(Path(__file__).parent / "webui"))
@router.get("/builds/{name}", response_class=HTMLResponse)

14
tests/test_k8s.py Normal file
View file

@ -0,0 +1,14 @@
import pytest
from runboat.k8s import _split_image_name_tag
@pytest.mark.parametrize(
("image", "expected"),
[
("postgres", ("postgres", "latest")),
("postgres:12", ("postgres", "12")),
],
)
def test_split_image_name_tag(image, expected):
assert _split_image_name_tag(image) == expected