Add branch filter to event source

This commit is contained in:
Stéphane Bidoul 2021-11-18 22:54:31 +01:00
parent 5cc316fee1
commit cc701d5930
No known key found for this signature in database
GPG key ID: BCAB2555446B5B92
2 changed files with 20 additions and 4 deletions

View file

@ -174,11 +174,16 @@ async def delete(name: str) -> None:
class BuildEventSource: class BuildEventSource:
def __init__( def __init__(
self, request: Request, repo: str | None = None, build_name: str | None = None self,
request: Request,
repo: str | None = None,
target_branch: str | None = None,
build_name: str | None = None,
): ):
self.queue: asyncio.Queue[str] = asyncio.Queue() self.queue: asyncio.Queue[str] = asyncio.Queue()
self.request = request self.request = request
self.repo = repo self.repo = repo
self.target_branch = target_branch
self.build_name = build_name self.build_name = build_name
controller.db.register_listener(self) controller.db.register_listener(self)
@ -189,12 +194,16 @@ class BuildEventSource:
def on_build_event(self, event: models.BuildEvent, build: models.Build) -> None: def on_build_event(self, event: models.BuildEvent, build: models.Build) -> None:
if self.repo and build.repo != self.repo: if self.repo and build.repo != self.repo:
return return
if self.target_branch and build.target_branch != self.target_branch:
return
if self.build_name and build.name != self.build_name: if self.build_name and build.name != self.build_name:
return return
self.queue.put_nowait(self._serialize(event, build)) self.queue.put_nowait(self._serialize(event, build))
async def events(self) -> AsyncGenerator[str, None]: async def events(self) -> AsyncGenerator[str, None]:
for build in controller.db.search(self.repo, self.build_name): for build in controller.db.search(
self.repo, self.target_branch, self.build_name
):
yield self._serialize(models.BuildEvent.modified, build) yield self._serialize(models.BuildEvent.modified, build)
while True: while True:
try: try:
@ -212,7 +221,8 @@ class BuildEventSource:
async def eventsource_endpoint( async def eventsource_endpoint(
request: Request, request: Request,
repo: Optional[str] = None, repo: Optional[str] = None,
target_branch: Optional[str] = None,
build_name: Optional[str] = None, build_name: Optional[str] = None,
) -> EventSourceResponse: ) -> EventSourceResponse:
event_source = BuildEventSource(request, repo, build_name) event_source = BuildEventSource(request, repo, target_branch, build_name)
return EventSourceResponse(event_source.events()) return EventSourceResponse(event_source.events())

View file

@ -185,7 +185,10 @@ class BuildsDb:
return [self._build_from_row(row) for row in rows] return [self._build_from_row(row) for row in rows]
def search( def search(
self, repo: str | None = None, name: str | None = None self,
repo: str | None = None,
target_branch: str | None = None,
name: str | None = None,
) -> Iterator[Build]: ) -> Iterator[Build]:
query = "SELECT * FROM builds " query = "SELECT * FROM builds "
where = [] where = []
@ -193,6 +196,9 @@ class BuildsDb:
if repo: if repo:
where.append("repo=?") where.append("repo=?")
params.append(repo.lower()) params.append(repo.lower())
if target_branch:
where.append("target_branch=?")
params.append(target_branch)
if name: if name:
where.append("name=?") where.append("name=?")
params.append(name) params.append(name)