diff --git a/src/runboat/api.py b/src/runboat/api.py index 794410e..925cbfa 100644 --- a/src/runboat/api.py +++ b/src/runboat/api.py @@ -174,11 +174,16 @@ async def delete(name: str) -> None: class BuildEventSource: def __init__( - self, request: Request, repo: str | None = None, build_name: str | None = None + self, + request: Request, + repo: str | None = None, + target_branch: str | None = None, + build_name: str | None = None, ): self.queue: asyncio.Queue[str] = asyncio.Queue() self.request = request self.repo = repo + self.target_branch = target_branch self.build_name = build_name controller.db.register_listener(self) @@ -189,12 +194,16 @@ class BuildEventSource: def on_build_event(self, event: models.BuildEvent, build: models.Build) -> None: if self.repo and build.repo != self.repo: return + if self.target_branch and build.target_branch != self.target_branch: + return if self.build_name and build.name != self.build_name: return self.queue.put_nowait(self._serialize(event, build)) async def events(self) -> AsyncGenerator[str, None]: - for build in controller.db.search(self.repo, self.build_name): + for build in controller.db.search( + self.repo, self.target_branch, self.build_name + ): yield self._serialize(models.BuildEvent.modified, build) while True: try: @@ -212,7 +221,8 @@ class BuildEventSource: async def eventsource_endpoint( request: Request, repo: Optional[str] = None, + target_branch: Optional[str] = None, build_name: Optional[str] = None, ) -> EventSourceResponse: - event_source = BuildEventSource(request, repo, build_name) + event_source = BuildEventSource(request, repo, target_branch, build_name) return EventSourceResponse(event_source.events()) diff --git a/src/runboat/db.py b/src/runboat/db.py index 249117f..f3db7c6 100644 --- a/src/runboat/db.py +++ b/src/runboat/db.py @@ -185,7 +185,10 @@ class BuildsDb: return [self._build_from_row(row) for row in rows] def search( - self, repo: str | None = None, name: str | None = None + self, + repo: str | None = None, + target_branch: str | None = None, + name: str | None = None, ) -> Iterator[Build]: query = "SELECT * FROM builds " where = [] @@ -193,6 +196,9 @@ class BuildsDb: if repo: where.append("repo=?") params.append(repo.lower()) + if target_branch: + where.append("target_branch=?") + params.append(target_branch) if name: where.append("name=?") params.append(name)