# 140 lines, 3.9 KiB, Python
from __future__ import annotations
|
|
|
|
from typing import Callable
|
|
import dataclasses
|
|
import threading
|
|
|
|
import httpx
|
|
|
|
|
|
@dataclasses.dataclass
class Event:
    """A single event parsed from an event stream.

    Every field has a default, so ``Event()`` yields an empty event whose
    type is ``"message"``.
    """

    # Value of the ``id`` field; empty when the stream sent none.
    id: str = dataclasses.field(default="")
    # Event type; listeners are looked up by this name.
    event: str = dataclasses.field(default="message")
    # Payload text of the event.
    data: str = dataclasses.field(default="")
    # Value of the ``retry`` field, or ``None`` when the stream sent none.
    retry: int | None = dataclasses.field(default=None)
|
|
|
|
|
|
class EventLoop(threading.Thread):
    """Thread that reads a server-sent event stream and dispatches events.

    When the thread runs it opens a streaming request through an
    ``httpx.Client``, cuts the response body into event blocks, parses each
    block into an :class:`Event` and calls the listener registered under the
    event's type, if there is one.
    """

    FIELD_SEPARATOR = ":"
    # Byte sequences that mark the blank line ending one event block.
    _BLOCK_TERMINATORS = (b"\r\r", b"\n\n", b"\r\n\r\n")

    def __init__(
        self,
        url: str,
        method: str = "GET",
        headers: dict | None = None,
        payload: dict | None = None,
        encoding="utf-8",
        listeners: dict[str, Callable] | None = None,
        **kwargs,
    ):
        """Store the request parameters and create the HTTP client.

        Args:
            url: Address of the event stream.
            method: HTTP method used for the streaming request.
            headers: Optional request headers.
            payload: Optional request body, passed to httpx as ``data``.
            encoding: Codec used to decode each line of the stream.
            listeners: Mapping of event type to callback. The mapping is
                used as-is (not copied), so later changes made by the caller
                are seen by the running thread.
            **kwargs: Forwarded to ``threading.Thread.__init__``.
        """
        threading.Thread.__init__(self, **kwargs)
        self.kill = False
        self.client = httpx.Client()
        self.url = url
        self.method = method
        self.headers = headers
        self.payload = payload
        self.encoding = encoding
        # ``listeners or {}`` would swap an empty dict supplied by the caller
        # for a fresh one and silently break the sharing described above.
        self.listeners = {} if listeners is None else listeners

    def _read(self):
        """Stream the response body and yield raw event blocks as ``bytes``.

        A block is complete once the buffered bytes end with a blank line
        (``\\r\\r``, ``\\n\\n`` or ``\\r\\n\\r\\n``). Whatever is still
        buffered when the stream ends is yielded as a final block.
        """
        # A bytearray grows in place; ``bytes +=`` would copy the whole
        # buffer for every line received.
        buffer = bytearray()
        with self.client.stream(
            self.method,
            self.url,
            headers=self.headers,
            data=self.payload,
            timeout=None,  # no httpx timeout: the stream may idle between events
        ) as response:
            for chunk in response.iter_bytes():
                # keepends=True so the terminators stay visible in the buffer
                # even when a line break is split across two chunks.
                for line in chunk.splitlines(True):
                    buffer += line
                    if buffer.endswith(self._BLOCK_TERMINATORS):
                        yield bytes(buffer)
                        buffer.clear()
        if buffer:
            yield bytes(buffer)

    def _events(self):
        """Parse the blocks from :meth:`_read` and yield :class:`Event` objects.

        Blank lines, comment lines (starting with ``:``) and fields that are
        not attributes of ``Event`` are skipped. Several ``data`` lines are
        joined with ``\\n``. A ``retry`` value is stored as ``int`` when it
        consists only of ASCII digits and is ignored otherwise. Blocks
        without any ``data`` line produce no event.
        """
        for block in self._read():
            event = Event()
            known_fields = vars(event)
            data_lines = []
            # Split on bytes first: ``str.splitlines`` would also break on
            # separators such as U+2028 that are not line ends here.
            for raw_line in block.splitlines():
                line = raw_line.decode(self.encoding)
                if not line.strip() or line.startswith(self.FIELD_SEPARATOR):
                    continue
                field, _, value = line.partition(self.FIELD_SEPARATOR)
                if field not in known_fields:
                    continue
                # Exactly one leading space after the separator is dropped.
                if value.startswith(" "):
                    value = value[1:]
                if field == "data":
                    data_lines.append(value)
                elif field == "retry":
                    # Keep the attribute consistent with ``int | None``.
                    if value.isascii() and value.isdigit():
                        event.retry = int(value)
                else:
                    setattr(event, field, value)
            if not data_lines:
                continue
            event.data = "\n".join(data_lines)
            # An ``event:`` line with an empty value falls back to the default.
            event.event = event.event or "message"
            yield event

    def run(self):
        """Dispatch incoming events to listeners until stopped or exhausted.

        ``self.kill`` is only checked when an event arrives, so setting it
        does not interrupt a read that is waiting for data. The HTTP client
        is closed when the loop ends, including when a listener raises.
        """
        try:
            for event in self._events():
                if self.kill:
                    break
                # Single lookup: the mapping may be changed by another thread
                # between an ``in`` test and a subsequent index.
                listener = self.listeners.get(event.event)
                if listener is not None:
                    listener(event)
        finally:
            self.client.close()
|
|
|
|
|
|
class SSEClient:
    """Client for a server-sent event stream.

    Constructing an instance creates an ``EventLoop`` daemon thread named
    ``"loop"`` and starts it immediately. Callbacks are kept in a dict
    keyed by event type, so each event type has at most one listener.
    """

    _listeners: dict = {}
    _loop_thread: threading.Thread | None = None

    def __init__(
        self,
        url: str,
        method: str = "GET",
        headers: dict | None = None,
        payload: dict | None = None,
        encoding="utf-8",
    ) -> None:
        """Create the listener registry and start the background thread.

        All arguments are handed unchanged to ``EventLoop``.
        """
        self._listeners = {}
        worker = EventLoop(
            name="loop",
            listeners=self._listeners,
            encoding=encoding,
            payload=payload,
            headers=headers,
            method=method,
            url=url,
        )
        self._loop_thread = worker
        worker.daemon = True
        worker.start()

    def add_event_listener(self, event: str, callback: Callable[[Event], None]) -> None:
        """Register ``callback`` for ``event``, replacing any previous one."""
        registry = self._listeners
        registry[event] = callback
        # Hand the registry to the thread again so both use the same dict.
        self._loop_thread.listeners = registry

    def remove_event_listener(
        self, event: str, callback: Callable[[Event], None]
    ) -> None:
        """Drop the listener registered for ``event``, if there is one.

        ``callback`` is accepted but not inspected: whatever is registered
        under ``event`` is removed.
        """
        if event not in self._listeners:
            return
        del self._listeners[event]
        self._loop_thread.listeners = self._listeners

    def close(self) -> None:
        """Ask the background thread to stop."""
        # TODO: does not work like this -- the flag is only looked at when
        # the thread receives its next event.
        self._loop_thread.kill = True
|