Add locking for multiple reciever

main
Adib Pratama 1 day ago
parent 6a4dabf28f
commit 7175cc5ef3
No known key found for this signature in database
GPG Key ID: 7C855EE276A46D2C
  1. 66
      src/nanobot.py

@ -14,6 +14,7 @@ class Nanobot:
def __init__(self): def __init__(self):
self._ws = None self._ws = None
self._chat_id = None self._chat_id = None
self._lock = asyncio.Lock()
async def __aenter__(self): async def __aenter__(self):
await self.connect() await self.connect()
@ -30,40 +31,41 @@ class Nanobot:
main_logger.info("connected chat_id=%s", self._chat_id) main_logger.info("connected chat_id=%s", self._chat_id)
async def chat(self, message: str, timeout: float = 60): async def chat(self, message: str, timeout: float = 60):
try: async with self._lock:
await self._ws.send(json.dumps({"content": message}))
except websockets.ConnectionClosed:
main_logger.info("ws: reconnecting...")
await self.connect()
await self._ws.send(json.dumps({"content": message}))
parts = []
while True:
try: try:
raw = await asyncio.wait_for(self._ws.recv(), timeout=timeout) await self._ws.send(json.dumps({"content": message}))
except asyncio.TimeoutError: except websockets.ConnectionClosed:
main_logger.info("ws: timeout (no data for %.1fs), reconnecting", timeout) main_logger.info("ws: reconnecting...")
await self.close()
await self.connect() await self.connect()
return await self._ws.send(json.dumps({"content": message}))
frame = json.loads(raw) parts = []
event = frame.get("event") while True:
text = frame.get("text", "") try:
main_logger.debug("ws raw: %s", raw) raw = await asyncio.wait_for(self._ws.recv(), timeout=timeout)
if event == "message": except asyncio.TimeoutError:
if frame.get("kind") == "progress": main_logger.info("ws: timeout (no data for %.1fs), reconnecting", timeout)
main_logger.debug("ws: skipping progress message") await self.close()
continue await self.connect()
yield text return
return frame = json.loads(raw)
if event == "delta": event = frame.get("event")
parts.append(text) text = frame.get("text", "")
elif event == "stream_end": main_logger.debug("ws raw: %s", raw)
t = "".join(parts) if event == "message":
if t: if frame.get("kind") == "progress":
yield t main_logger.debug("ws: skipping progress message")
parts = [] continue
elif event == "error": yield text
raise RuntimeError(frame.get("detail", "unknown error")) return
if event == "delta":
parts.append(text)
elif event == "stream_end":
t = "".join(parts)
if t:
yield t
parts = []
elif event == "error":
raise RuntimeError(frame.get("detail", "unknown error"))
async def close(self): async def close(self):
if self._ws: if self._ws:

Loading…
Cancel
Save