diff --git a/pocketbase/__init__.py b/pocketbase/__init__.py index 6eb6243..4e7cd7c 100644 --- a/pocketbase/__init__.py +++ b/pocketbase/__init__.py @@ -1,6 +1,6 @@ __title__ = "pocketbase" __description__ = "PocketBase client SDK for python." -__version__ = "0.1.2" +__version__ = "0.1.3" from .client import Client, ClientResponseError diff --git a/pocketbase/client.py b/pocketbase/client.py index 58fa3f4..10f9e16 100644 --- a/pocketbase/client.py +++ b/pocketbase/client.py @@ -19,7 +19,7 @@ class ClientResponseError(Exception): status: int = 0 data: dict = {} is_abort: bool = False - original_error: Any = None + original_error: Any | None = None def __init__(self, *args, **kwargs) -> None: super().__init__(*args) @@ -46,7 +46,7 @@ class Client: self, base_url: str = "/", lang: str = "en-US", - auth_store: BaseAuthStore = None, + auth_store: BaseAuthStore | None = None, ) -> None: self.base_url = base_url self.lang = lang @@ -60,12 +60,6 @@ class Client: self.settings = Settings(self) self.realtime = Realtime(self) - def cancel_request(self, cancel_key: str): - return self - - def cancel_all_requests(self): - return self - def send(self, path: str, req_config: dict[str:Any]) -> Any: """Sends an api http request.""" config = {"method": "GET"} diff --git a/pocketbase/services/realtime.py b/pocketbase/services/realtime.py index 44d2aea..a24c6ca 100644 --- a/pocketbase/services/realtime.py +++ b/pocketbase/services/realtime.py @@ -1,19 +1,46 @@ from __future__ import annotations -from typing import Callable, Optional +from typing import Callable +import dataclasses +import json + from pocketbase.services.utils.base_service import BaseService +from pocketbase.services.utils.sse import Event, SSEClient from pocketbase.models.record import Record +@dataclasses.dataclass +class MessageData: + action: str + record: Record + + class Realtime(BaseService): - client_id: str subscriptions: dict + client_id: str = "" + event_source: SSEClient | None = None - def subscribe(self, subscription: str, callback: Callable) -> None: + def __init__(self, client) -> None: + super().__init__(client) + self.subscriptions = {} + self.client_id = "" + self.event_source = None + + def subscribe( + self, subscription: str, callback: Callable[[MessageData], None] + ) -> None: """Inits the sse connection (if not already) and register the subscription.""" - self.subscriptions[subscription] = callback + # unsubscribe existing + if subscription in self.subscriptions and self.event_source: + self.event_source.remove_event_listener(subscription, callback) + # register subscription + self.subscriptions[subscription] = self._make_subscription(callback) + if not self.event_source: + self._connect() + elif self.client_id: + self._submit_subscriptions() - def unsubscribe(self, subscription: Optional[str] = None) -> None: + def unsubscribe(self, subscription: str | None = None) -> None: """ Unsubscribe from a subscription. @@ -23,29 +50,79 @@ class Realtime(BaseService): The related sse connection will be autoclosed if after the unsubscribe operations there are no active subscriptions left. """ - pass + if not subscription: + self._remove_subscription_listeners() + self.subscriptions = {} + elif subscription in self.subscriptions: + self.event_source.remove_event_listener( + subscription, self.subscriptions[subscription] + ) + self.subscriptions.pop(subscription) + else: + return + if self.client_id: + self._submit_subscriptions() + if not self.subscriptions: + self._disconnect() + + def _make_subscription( + self, callback: Callable[[MessageData], None] + ) -> Callable[[Event], None]: + def listener(event: Event) -> None: + data = json.loads(event.data) + if "record" in data and "action" in data: + callback( + MessageData( + action=data["action"], + record=Record( + data=data["record"], + ), + ) + ) + + return listener def _submit_subscriptions(self) -> bool: + self._add_subscription_listeners() self.client.send( "/api/realtime", { "method": "POST", "body": { "clientId": self.client_id, - "subscriptions": self.subscriptions.keys(), + "subscriptions": list(self.subscriptions.keys()), }, }, ) return True def _add_subscription_listeners(self) -> None: - pass + if not self.event_source: + return + self._remove_subscription_listeners() + for subscription, callback in self.subscriptions.items(): + self.event_source.add_event_listener(subscription, callback) def _remove_subscription_listeners(self) -> None: - pass + if not self.event_source: + return + for subscription, callback in self.subscriptions.items(): + self.event_source.remove_event_listener(subscription, callback) + + def _connect_handler(self, event: Event) -> None: + self.client_id = event.id + self._submit_subscriptions() def _connect(self) -> None: - pass + self._disconnect() + self.event_source = SSEClient(self.client.build_url("/api/realtime")) + self.event_source.add_event_listener("PB_CONNECT", self._connect_handler) def _disconnect(self) -> None: - pass + self._remove_subscription_listeners() + self.client_id = "" + if not self.event_source: + return + self.event_source.remove_event_listener("PB_CONNECT", self._connect_handler) + self.event_source.close() + self.event_source = None diff --git a/pocketbase/services/utils/sse.py b/pocketbase/services/utils/sse.py index bc4fbf8..56e0b8c 100644 --- a/pocketbase/services/utils/sse.py +++ b/pocketbase/services/utils/sse.py @@ -1,11 +1,13 @@ from __future__ import annotations -from dataclasses import dataclass +from typing import Callable +import dataclasses +import asyncio import httpx -@dataclass +@dataclasses.dataclass class Event: """Representation of an event""" @@ -19,6 +21,8 @@ class SSEClient: """Implementation of a server side event client""" FIELD_SEPARATOR = ":" + _listeners: dict = {} + _loop_running: bool = False def __init__( self, @@ -33,14 +37,15 @@ class SSEClient: self.headers = headers self.payload = payload self.encoding = encoding + self.client = httpx.AsyncClient() - def _read(self): + async def _read(self): """Read the incoming event source stream and yield event chunks""" data = b"" - with httpx.stream( + async with self.client.stream( self.method, self.url, headers=self.headers, data=self.payload, timeout=None ) as r: - for chunk in r.iter_bytes(): + async for chunk in r.aiter_bytes(): for line in chunk.splitlines(True): data += line if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): @@ -49,8 +54,8 @@ class SSEClient: if data: yield data - def events(self): - for chunk in self._read(): + async def _events(self): + async for chunk in self._read(): event = Event() for line in chunk.splitlines(): line = line.decode(self.encoding) @@ -77,3 +82,23 @@ class SSEClient: event.data = event.data[0:-1] event.event = event.event or "message" yield event + + async def _loop(self): + self._loop_running = True + async for event in self._events(): + if event.event in self._listeners: + self._listeners[event.event](event) + + def add_event_listener(self, event: str, callback: Callable[[Event], None]) -> None: + self._listeners[event] = callback + if not self._loop_running: + asyncio.run(self._loop()) + + def remove_event_listener( + self, event: str, callback: Callable[[Event], None] + ) -> None: + if event in self._listeners: + self._listeners.pop(event) + + def close(self) -> None: + pass diff --git a/pyproject.toml b/pyproject.toml index 049a201..92e25a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dynamic = ["readme", "version"] [tool.poetry] name = "pocketbase" -version = "0.1.2" +version = "0.1.3" description = "PocketBase SDK for python." authors = ["Vithor Jaeger "] readme = "README.md"