2022-09-22 18:00:49 -04:00

129 lines
4.2 KiB
Python

from __future__ import annotations
from typing import Callable
import dataclasses
import json
from pocketbase.services.utils.base_service import BaseService
from pocketbase.services.utils.sse import Event, SSEClient
from pocketbase.models.record import Record
@dataclasses.dataclass
class MessageData:
    """Payload handed to a realtime subscription callback.

    Built by the realtime listener from a decoded event whose JSON body
    contains both an ``"action"`` and a ``"record"`` key.
    """

    # Value of the event's "action" key, passed through unchanged.
    action: str
    # Record model constructed from the event's "record" key.
    record: Record
class Realtime(BaseService):
    """Service that manages realtime subscriptions over one SSE connection.

    Subscriptions are kept in ``self.subscriptions`` (name -> listener).
    The SSE connection is opened by the first ``subscribe`` call and closed
    again once no subscriptions remain. The subscription list is sent to
    ``/api/realtime`` once a ``PB_CONNECT`` event has supplied a client id.
    """

    # Maps a subscription name to the listener wrapper registered for it.
    subscriptions: dict
    # Id taken from the "PB_CONNECT" event; empty while not connected.
    client_id: str = ""
    # Active SSE client, or None when there is no open connection.
    event_source: SSEClient | None = None

    def __init__(self, client) -> None:
        """Create the service with no subscriptions and no open connection."""
        super().__init__(client)
        self.subscriptions = {}
        self.client_id = ""
        self.event_source = None

    def subscribe(
        self, subscription: str, callback: Callable[[MessageData], None]
    ) -> None:
        """Init the sse connection (if not already) and register the subscription.

        Args:
            subscription: Name of the subscription to listen on.
            callback: Called with a ``MessageData`` for every matching event.
        """
        # Unsubscribe the existing listener. What was registered on the event
        # source is the wrapper stored in `self.subscriptions`, not the user's
        # callback, so the stored wrapper is the object that must be removed.
        previous = self.subscriptions.get(subscription)
        if previous is not None and self.event_source:
            self.event_source.remove_event_listener(subscription, previous)

        # Register the subscription.
        self.subscriptions[subscription] = self._make_subscription(callback)

        if not self.event_source:
            # The listeners get attached by `_connect_handler` once the
            # "PB_CONNECT" event arrives.
            self._connect()
        elif self.client_id:
            self._submit_subscriptions()

    def unsubscribe(self, subscription: str | None = None) -> None:
        """
        Unsubscribe from a subscription.

        If the `subscription` argument is not set,
        then the client will unsubscribe from all registered subscriptions.

        The related sse connection will be autoclosed if after the
        unsubscribe operations there are no active subscriptions left.

        Unknown subscription names are ignored.
        """
        if not subscription:
            self._remove_subscription_listeners()
            self.subscriptions = {}
        elif subscription in self.subscriptions:
            listener = self.subscriptions.pop(subscription)
            # `event_source` may be None; guard it like every other method.
            if self.event_source:
                self.event_source.remove_event_listener(subscription, listener)
        else:
            return

        if self.client_id:
            self._submit_subscriptions()

        if not self.subscriptions:
            self._disconnect()

    def _make_subscription(
        self, callback: Callable[[MessageData], None]
    ) -> Callable[[Event], None]:
        """Wrap `callback` in a listener that decodes the raw SSE event.

        The listener parses ``event.data`` as JSON and calls `callback` only
        when the decoded value is a mapping containing both ``"record"`` and
        ``"action"``; any other payload is ignored.
        """

        def listener(event: Event) -> None:
            data = json.loads(event.data)
            # A non-mapping payload (string, number, ...) cannot carry a
            # record; skip it instead of failing on the key lookups below.
            if not isinstance(data, dict):
                return
            if "record" in data and "action" in data:
                callback(
                    MessageData(
                        action=data["action"],
                        record=Record(
                            data=data["record"],
                        ),
                    )
                )

        return listener

    def _submit_subscriptions(self) -> bool:
        """Re-attach the listeners and send the subscription list to the server.

        Returns:
            Always ``True``.
        """
        self._add_subscription_listeners()
        self.client.send(
            "/api/realtime",
            {
                "method": "POST",
                "body": {
                    "clientId": self.client_id,
                    "subscriptions": list(self.subscriptions),
                },
            },
        )
        return True

    def _add_subscription_listeners(self) -> None:
        """Attach every stored listener to the event source, if there is one."""
        if not self.event_source:
            return
        # Remove first so that a listener is not registered twice.
        self._remove_subscription_listeners()
        for subscription, callback in self.subscriptions.items():
            self.event_source.add_event_listener(subscription, callback)

    def _remove_subscription_listeners(self) -> None:
        """Detach every stored listener from the event source, if there is one."""
        if not self.event_source:
            return
        for subscription, callback in self.subscriptions.items():
            self.event_source.remove_event_listener(subscription, callback)

    def _connect_handler(self, event: Event) -> None:
        """Handle "PB_CONNECT": store the event id as client id and submit."""
        self.client_id = event.id
        self._submit_subscriptions()

    def _connect(self) -> None:
        """Drop any existing connection and open a new SSE client."""
        self._disconnect()
        self.event_source = SSEClient(self.client.build_url("/api/realtime"))
        self.event_source.add_event_listener("PB_CONNECT", self._connect_handler)

    def _disconnect(self) -> None:
        """Detach all listeners, reset the client id and close the connection."""
        self._remove_subscription_listeners()
        self.client_id = ""
        if not self.event_source:
            return
        self.event_source.remove_event_listener("PB_CONNECT", self._connect_handler)
        self.event_source.close()
        self.event_source = None