Fix a few typing issues and bug revealed by mypy

This commit is contained in:
Stéphane Bidoul 2021-10-28 19:04:31 +02:00
parent 5ee536076f
commit b13e0dc8b1
No known key found for this signature in database
GPG key ID: BCAB2555446B5B92
6 changed files with 22 additions and 22 deletions

View file

@ -118,10 +118,10 @@ async def trigger_pull(org: str, repo: str, pr: int):
def _build_by_name(name: str) -> models.Build: def _build_by_name(name: str) -> models.Build:
try: build = controller.db.get(name)
return controller.db.get(name) if build is None:
except KeyError:
raise HTTPException(status.HTTP_404_NOT_FOUND) raise HTTPException(status.HTTP_404_NOT_FOUND)
return build
@router.get("/builds/{name}", response_model=Build) @router.get("/builds/{name}", response_model=Build)

View file

@ -27,17 +27,17 @@ class Controller:
_tasks: list[asyncio.Task] _tasks: list[asyncio.Task]
_wakeup_event: asyncio.Event _wakeup_event: asyncio.Event
def __init__(self): def __init__(self) -> None:
self._tasks = [] self._tasks = []
self._wakeup_event = asyncio.Event() self._wakeup_event = asyncio.Event()
self.reset() self.reset()
def reset(self): def reset(self) -> None:
self.db = BuildsDb() self.db = BuildsDb()
@property @property
def running(self) -> int: def running(self) -> int:
return self.db.count_by_statuses([BuildStatus.started, BuildStatus.starting]) return self.db.count_by_statuses((BuildStatus.started, BuildStatus.starting))
@property @property
def max_running(self) -> int: def max_running(self) -> int:
@ -45,7 +45,7 @@ class Controller:
@property @property
def starting(self) -> int: def starting(self) -> int:
return self.db.count_by_statuses([BuildStatus.starting]) return self.db.count_by_statuses((BuildStatus.starting,))
@property @property
def max_starting(self) -> int: def max_starting(self) -> int:
@ -131,7 +131,7 @@ class Controller:
if len(to_undeploy) < can_undeploy: if len(to_undeploy) < can_undeploy:
break # back to sleep break # back to sleep
async def start(self): async def start(self) -> None:
_logger.info("Starting controller tasks.") _logger.info("Starting controller tasks.")
async def walking_dead(func): async def walking_dead(func):
@ -148,7 +148,7 @@ class Controller:
for f in (self.watcher, self.starter, self.stopper, self.undeployer): for f in (self.watcher, self.starter, self.stopper, self.undeployer):
self._tasks.append(asyncio.create_task(walking_dead(f))) self._tasks.append(asyncio.create_task(walking_dead(f)))
async def stop(self): async def stop(self) -> None:
_logger.info("Stopping controller tasks.") _logger.info("Stopping controller tasks.")
for task in self._tasks: for task in self._tasks:
task.cancel() task.cancel()

View file

@ -16,7 +16,7 @@ class BuildsDb:
_con: sqlite3.Connection _con: sqlite3.Connection
def __init__(self): def __init__(self) -> None:
self.reset() self.reset()
@classmethod @classmethod
@ -88,7 +88,7 @@ class BuildsDb:
), ),
) )
def count_by_statuses(self, statuses: tuple[BuildStatus]) -> int: def count_by_statuses(self, statuses: tuple[BuildStatus, ...]) -> int:
q = ",".join(["?"] * len(statuses)) q = ",".join(["?"] * len(statuses))
return self._con.execute( return self._con.execute(
f"SELECT COUNT(name) FROM builds WHERE status IN ({q})", statuses f"SELECT COUNT(name) FROM builds WHERE status IN ({q})", statuses
@ -123,7 +123,7 @@ class BuildsDb:
def branches_and_pulls(self, repo: str) -> list[BranchOrPull]: def branches_and_pulls(self, repo: str) -> list[BranchOrPull]:
res = [] res = []
branch_or_pull: BranchOrPull = None branch_or_pull: BranchOrPull | None = None
for row in self._con.execute( for row in self._con.execute(
"SELECT * FROM builds WHERE repo=?" "SELECT * FROM builds WHERE repo=?"
"ORDER BY target_branch, pr, created DESC", "ORDER BY target_branch, pr, created DESC",

View file

@ -5,7 +5,7 @@ import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from importlib import resources from importlib import resources
from pathlib import Path from pathlib import Path
from typing import Any, AsyncGenerator, ContextManager, Dict, List, Optional, Tuple from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple
from jinja2 import Template from jinja2 import Template
from kubernetes_asyncio import client, config, watch from kubernetes_asyncio import client, config, watch
@ -95,7 +95,7 @@ def make_deployment_vars(
@contextmanager @contextmanager
def _render_kubefiles(deployment_vars: DeploymentVars) -> ContextManager[Path]: def _render_kubefiles(deployment_vars: DeploymentVars) -> Generator[Path, None, None]:
with resources.path( with resources.path(
__package__, "kubefiles" __package__, "kubefiles"
) as kubefiles_path, tempfile.TemporaryDirectory() as tmp_dir: ) as kubefiles_path, tempfile.TemporaryDirectory() as tmp_dir:
@ -111,9 +111,9 @@ def _render_kubefiles(deployment_vars: DeploymentVars) -> ContextManager[Path]:
async def _kubectl(args: List[str]) -> None: async def _kubectl(args: List[str]) -> None:
proc = await asyncio.create_subprocess_exec("kubectl", *args) proc = await asyncio.create_subprocess_exec("kubectl", *args)
await proc.wait() return_code = await proc.wait()
if proc.returncode != 0: if return_code != 0:
raise subprocess.CalledProcessError(proc.returncode, ["kubectl"] + args) raise subprocess.CalledProcessError(return_code, ["kubectl"] + args)
async def deploy(deployment_vars: DeploymentVars) -> None: async def deploy(deployment_vars: DeploymentVars) -> None:

View file

@ -1,5 +1,5 @@
import re import re
def slugify(s: str) -> str: def slugify(s: str | int) -> str:
return re.sub(r"[^a-z0-9]", "-", str(s).lower()) return re.sub(r"[^a-z0-9]", "-", str(s).lower())

View file

@ -2,7 +2,7 @@ import logging
from fastapi import APIRouter, BackgroundTasks, Header, Request from fastapi import APIRouter, BackgroundTasks, Header, Request
from . import controller from . import models
from .settings import settings from .settings import settings
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -15,7 +15,7 @@ async def receive_payload(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
request: Request, request: Request,
x_github_event: str = Header(...), x_github_event: str = Header(...),
): ) -> None:
# TODO check x-hub-signature # TODO check x-hub-signature
payload = await request.json() payload = await request.json()
repo = payload["repository"]["full_name"] repo = payload["repository"]["full_name"]
@ -29,7 +29,7 @@ async def receive_payload(
if x_github_event == "pull_request": if x_github_event == "pull_request":
if action in ("opened", "synchronize"): if action in ("opened", "synchronize"):
background_tasks.add_task( background_tasks.add_task(
controller.Build.deploy, models.Build.deploy,
repo=repo, repo=repo,
target_branch=payload["pull_request"]["base"]["ref"], target_branch=payload["pull_request"]["base"]["ref"],
pr=payload["pull_request"]["number"], pr=payload["pull_request"]["number"],
@ -37,7 +37,7 @@ async def receive_payload(
) )
elif x_github_event == "push": elif x_github_event == "push":
background_tasks.add_task( background_tasks.add_task(
controller.Build.deploy, models.Build.deploy,
repo=repo, repo=repo,
target_branch=payload["ref"].split("/")[-1], target_branch=payload["ref"].split("/")[-1],
pr=None, pr=None,