Some typing
This commit is contained in:
parent
defa2f959e
commit
858c58e47d
7 changed files with 58 additions and 36 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from . import k8s
|
||||
from .db import BuildsDb
|
||||
|
|
@ -247,7 +248,7 @@ class Controller:
|
|||
async def start(self) -> None:
|
||||
_logger.info("Starting controller tasks.")
|
||||
|
||||
async def walking_dead(func):
|
||||
async def walking_dead(func: Callable[..., Awaitable[Any]]) -> None:
|
||||
while True:
|
||||
_logger.info(f"(Re)starting {func.__name__}")
|
||||
try:
|
||||
|
|
@ -275,7 +276,7 @@ class Controller:
|
|||
task.cancel()
|
||||
# Wait until all tasks are cancelled.
|
||||
await asyncio.gather(*self._tasks, return_exceptions=True)
|
||||
self._task = []
|
||||
self._tasks.clear()
|
||||
|
||||
|
||||
controller = Controller()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
import sqlite3
|
||||
from typing import Optional
|
||||
from typing import cast
|
||||
|
||||
from .models import Build, BuildInitStatus, BuildStatus
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ class BuildsDb:
|
|||
self, repo: str, target_branch: str, pr: int | None, git_commit: str
|
||||
) -> Build | None:
|
||||
query = "SELECT * FROM builds WHERE repo=? AND target_branch=? AND git_commit=?"
|
||||
params = [repo.lower(), target_branch, git_commit]
|
||||
params: list[str | int] = [repo.lower(), target_branch, git_commit]
|
||||
if pr:
|
||||
query += " AND pr=?"
|
||||
params.append(pr)
|
||||
|
|
@ -131,17 +131,20 @@ class BuildsDb:
|
|||
return True
|
||||
|
||||
def count_by_status(self, status: BuildStatus) -> int:
|
||||
return self._con.execute(
|
||||
count = self._con.execute(
|
||||
"SELECT COUNT(name) FROM builds WHERE status=?", (status,)
|
||||
).fetchone()[0]
|
||||
return cast(int, count)
|
||||
|
||||
def count_by_init_status(self, init_status: BuildInitStatus) -> int:
|
||||
return self._con.execute(
|
||||
count = self._con.execute(
|
||||
"SELECT COUNT(name) FROM builds WHERE init_status=?", (init_status,)
|
||||
).fetchone()[0]
|
||||
return cast(int, count)
|
||||
|
||||
def count_all(self) -> int:
|
||||
return self._con.execute("SELECT COUNT(name) FROM builds").fetchone()[0]
|
||||
count = self._con.execute("SELECT COUNT(name) FROM builds").fetchone()[0]
|
||||
return cast(int, count)
|
||||
|
||||
def to_initialize(self, limit: int) -> list[Build]:
|
||||
"""Return the list of builds to initialize, ordered by creation timestamp."""
|
||||
|
|
@ -168,7 +171,7 @@ class BuildsDb:
|
|||
).fetchall()
|
||||
return [self._build_from_row(row) for row in rows]
|
||||
|
||||
def search(self, repo: Optional[str] = None) -> list[Build]:
|
||||
def search(self, repo: str | None = None) -> list[Build]:
|
||||
query = "SELECT * FROM builds "
|
||||
where = []
|
||||
params = []
|
||||
|
|
|
|||
|
|
@ -9,13 +9,14 @@ from contextlib import contextmanager
|
|||
from enum import Enum
|
||||
from importlib import resources
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Optional
|
||||
from typing import Generator, Optional, TypedDict, cast
|
||||
|
||||
import urllib3
|
||||
from jinja2 import Template
|
||||
from kubernetes import client, config, watch
|
||||
from kubernetes.client.exceptions import ApiException
|
||||
from kubernetes.client.models.v1_deployment import V1Deployment
|
||||
from kubernetes.client.models.v1_job import V1Job
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .settings import settings
|
||||
|
|
@ -24,10 +25,9 @@ from .utils import sync_to_async, sync_to_async_iterator
|
|||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _split_image_name_tag(img: str) -> tuple[str, str]:
|
||||
if ":" in img:
|
||||
return img.split(":", 2)
|
||||
return (img, "latest")
|
||||
def _split_image_name_tag(image: str) -> tuple[str, str]:
|
||||
img, _, tag = image.partition(":")
|
||||
return (img, tag or "latest")
|
||||
|
||||
|
||||
@sync_to_async
|
||||
|
|
@ -39,7 +39,7 @@ def load_kube_config() -> None:
|
|||
|
||||
|
||||
@sync_to_async
|
||||
def read_deployment(name: str) -> Optional[V1Deployment]:
|
||||
def read_deployment(name: str) -> V1Deployment | None:
|
||||
appsv1 = client.AppsV1Api()
|
||||
items = appsv1.list_namespaced_deployment(
|
||||
namespace=settings.build_namespace,
|
||||
|
|
@ -56,9 +56,15 @@ def delete_deployment(deployment_name: str) -> None:
|
|||
)
|
||||
|
||||
|
||||
class PatchOperation(TypedDict, total=False):
|
||||
op: str
|
||||
path: str
|
||||
value: str | int # maybe absent, hence total=False above
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def patch_deployment(
|
||||
deployment_name: str, ops: list[dict["str", Any]], not_found_ok: bool
|
||||
deployment_name: str, ops: list[PatchOperation], not_found_ok: bool
|
||||
) -> None:
|
||||
appsv1 = client.AppsV1Api()
|
||||
try:
|
||||
|
|
@ -113,7 +119,7 @@ def _watch(list_method, *args, **kwargs):
|
|||
|
||||
|
||||
@sync_to_async_iterator
|
||||
def watch_deployments():
|
||||
def watch_deployments() -> Generator[V1Deployment, None, None]:
|
||||
appsv1 = client.AppsV1Api()
|
||||
yield from _watch(
|
||||
appsv1.list_namespaced_deployment, namespace=settings.build_namespace
|
||||
|
|
@ -121,7 +127,7 @@ def watch_deployments():
|
|||
|
||||
|
||||
@sync_to_async_iterator
|
||||
def watch_jobs():
|
||||
def watch_jobs() -> Generator[V1Job, None, None]:
|
||||
batchv1 = client.BatchV1Api()
|
||||
yield from _watch(batchv1.list_namespaced_job, namespace=settings.build_namespace)
|
||||
|
||||
|
|
@ -258,8 +264,8 @@ async def delete_job(build_name: str, job_kind: DeploymentMode) -> None:
|
|||
|
||||
|
||||
@sync_to_async
|
||||
def log(build_name: str, job_kind: DeploymentMode | None) -> str:
|
||||
"""Return the buil log.
|
||||
def log(build_name: str, job_kind: DeploymentMode | None) -> str | None:
|
||||
"""Return the build log.
|
||||
|
||||
The pod for which the log is returned is the first that matches the
|
||||
build_name (via its runboat/build label) and job_kind (via its
|
||||
|
|
@ -269,20 +275,18 @@ def log(build_name: str, job_kind: DeploymentMode | None) -> str:
|
|||
pods = corev1.list_namespaced_pod(
|
||||
namespace=settings.build_namespace, label_selector=f"runboat/build={build_name}"
|
||||
).items
|
||||
pod = None
|
||||
for pod in pods:
|
||||
if job_kind is None:
|
||||
if "runboat/job-kind" not in pod.metadata.labels:
|
||||
break
|
||||
else:
|
||||
if pod.metadata.labels.get("runboat/job-kind") == job_kind:
|
||||
break
|
||||
else:
|
||||
# no matching pod found
|
||||
return
|
||||
return corev1.read_namespaced_pod_log(
|
||||
return None
|
||||
return cast(
|
||||
str,
|
||||
corev1.read_namespaced_pod_log(
|
||||
pod.metadata.name,
|
||||
namespace=settings.build_namespace,
|
||||
tail_lines=None if job_kind else None,
|
||||
follow=False,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -325,8 +325,8 @@ class Build(BaseModel):
|
|||
desired_replicas: int | None = None,
|
||||
remove_finalizers: bool = False,
|
||||
not_found_ok: bool = False,
|
||||
) -> None:
|
||||
ops = []
|
||||
) -> bool:
|
||||
ops: list[k8s.PatchOperation] = []
|
||||
if init_status is not None and init_status != self.init_status:
|
||||
ops.extend(
|
||||
[
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class Settings(BaseSettings):
|
|||
|
||||
@validator("supported_repos")
|
||||
@classmethod
|
||||
def validate_supported_repos(v) -> set[str]:
|
||||
def validate_supported_repos(cls, v: set[str]) -> set[str]:
|
||||
return {item.lower() for item in v}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from .controller import controller
|
|||
from .models import BuildStatus
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory=Path(__file__).parent / "webui")
|
||||
templates = Jinja2Templates(directory=str(Path(__file__).parent / "webui"))
|
||||
|
||||
|
||||
@router.get("/builds/{name}", response_class=HTMLResponse)
|
||||
|
|
|
|||
14
tests/test_k8s.py
Normal file
14
tests/test_k8s.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
import pytest
|
||||
|
||||
from runboat.k8s import _split_image_name_tag
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("image", "expected"),
|
||||
[
|
||||
("postgres", ("postgres", "latest")),
|
||||
("postgres:12", ("postgres", "12")),
|
||||
],
|
||||
)
|
||||
def test_split_image_name_tag(image, expected):
|
||||
assert _split_image_name_tag(image) == expected
|
||||
Loading…
Reference in a new issue