Fix a few typing issues and bug revealed by mypy

This commit is contained in:
Stéphane Bidoul 2021-10-28 19:04:31 +02:00
parent 5ee536076f
commit b13e0dc8b1
No known key found for this signature in database
GPG key ID: BCAB2555446B5B92
6 changed files with 22 additions and 22 deletions

View file

@ -118,10 +118,10 @@ async def trigger_pull(org: str, repo: str, pr: int):
def _build_by_name(name: str) -> models.Build:
try:
return controller.db.get(name)
except KeyError:
build = controller.db.get(name)
if build is None:
raise HTTPException(status.HTTP_404_NOT_FOUND)
return build
@router.get("/builds/{name}", response_model=Build)

View file

@ -27,17 +27,17 @@ class Controller:
_tasks: list[asyncio.Task]
_wakeup_event: asyncio.Event
def __init__(self):
def __init__(self) -> None:
self._tasks = []
self._wakeup_event = asyncio.Event()
self.reset()
def reset(self):
def reset(self) -> None:
self.db = BuildsDb()
@property
def running(self) -> int:
return self.db.count_by_statuses([BuildStatus.started, BuildStatus.starting])
return self.db.count_by_statuses((BuildStatus.started, BuildStatus.starting))
@property
def max_running(self) -> int:
@ -45,7 +45,7 @@ class Controller:
@property
def starting(self) -> int:
return self.db.count_by_statuses([BuildStatus.starting])
return self.db.count_by_statuses((BuildStatus.starting,))
@property
def max_starting(self) -> int:
@ -131,7 +131,7 @@ class Controller:
if len(to_undeploy) < can_undeploy:
break # back to sleep
async def start(self):
async def start(self) -> None:
_logger.info("Starting controller tasks.")
async def walking_dead(func):
@ -148,7 +148,7 @@ class Controller:
for f in (self.watcher, self.starter, self.stopper, self.undeployer):
self._tasks.append(asyncio.create_task(walking_dead(f)))
async def stop(self):
async def stop(self) -> None:
_logger.info("Stopping controller tasks.")
for task in self._tasks:
task.cancel()

View file

@ -16,7 +16,7 @@ class BuildsDb:
_con: sqlite3.Connection
def __init__(self):
def __init__(self) -> None:
self.reset()
@classmethod
@ -88,7 +88,7 @@ class BuildsDb:
),
)
def count_by_statuses(self, statuses: tuple[BuildStatus]) -> int:
def count_by_statuses(self, statuses: tuple[BuildStatus, ...]) -> int:
q = ",".join(["?"] * len(statuses))
return self._con.execute(
f"SELECT COUNT(name) FROM builds WHERE status IN ({q})", statuses
@ -123,7 +123,7 @@ class BuildsDb:
def branches_and_pulls(self, repo: str) -> list[BranchOrPull]:
res = []
branch_or_pull: BranchOrPull = None
branch_or_pull: BranchOrPull | None = None
for row in self._con.execute(
"SELECT * FROM builds WHERE repo=?"
"ORDER BY target_branch, pr, created DESC",

View file

@ -5,7 +5,7 @@ import tempfile
from contextlib import contextmanager
from importlib import resources
from pathlib import Path
from typing import Any, AsyncGenerator, ContextManager, Dict, List, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple
from jinja2 import Template
from kubernetes_asyncio import client, config, watch
@ -95,7 +95,7 @@ def make_deployment_vars(
@contextmanager
def _render_kubefiles(deployment_vars: DeploymentVars) -> ContextManager[Path]:
def _render_kubefiles(deployment_vars: DeploymentVars) -> Generator[Path, None, None]:
with resources.path(
__package__, "kubefiles"
) as kubefiles_path, tempfile.TemporaryDirectory() as tmp_dir:
@ -111,9 +111,9 @@ def _render_kubefiles(deployment_vars: DeploymentVars) -> ContextManager[Path]:
async def _kubectl(args: List[str]) -> None:
proc = await asyncio.create_subprocess_exec("kubectl", *args)
await proc.wait()
if proc.returncode != 0:
raise subprocess.CalledProcessError(proc.returncode, ["kubectl"] + args)
return_code = await proc.wait()
if return_code != 0:
raise subprocess.CalledProcessError(return_code, ["kubectl"] + args)
async def deploy(deployment_vars: DeploymentVars) -> None:

View file

@ -1,5 +1,5 @@
import re
def slugify(s: str) -> str:
def slugify(s: str | int) -> str:
return re.sub(r"[^a-z0-9]", "-", str(s).lower())

View file

@ -2,7 +2,7 @@ import logging
from fastapi import APIRouter, BackgroundTasks, Header, Request
from . import controller
from . import models
from .settings import settings
_logger = logging.getLogger(__name__)
@ -15,7 +15,7 @@ async def receive_payload(
background_tasks: BackgroundTasks,
request: Request,
x_github_event: str = Header(...),
):
) -> None:
# TODO check x-hub-signature
payload = await request.json()
repo = payload["repository"]["full_name"]
@ -29,7 +29,7 @@ async def receive_payload(
if x_github_event == "pull_request":
if action in ("opened", "synchronize"):
background_tasks.add_task(
controller.Build.deploy,
models.Build.deploy,
repo=repo,
target_branch=payload["pull_request"]["base"]["ref"],
pr=payload["pull_request"]["number"],
@ -37,7 +37,7 @@ async def receive_payload(
)
elif x_github_event == "push":
background_tasks.add_task(
controller.Build.deploy,
models.Build.deploy,
repo=repo,
target_branch=payload["ref"].split("/")[-1],
pr=None,