Add locking for multiple reciever

main
Adib Pratama 1 day ago
parent 6a4dabf28f
commit 7175cc5ef3
No known key found for this signature in database
GPG Key ID: 7C855EE276A46D2C
  1. 66
      src/nanobot.py

@ -14,6 +14,7 @@ class Nanobot:
def __init__(self):
self._ws = None
self._chat_id = None
self._lock = asyncio.Lock()
async def __aenter__(self):
await self.connect()
@ -30,40 +31,41 @@ class Nanobot:
main_logger.info("connected chat_id=%s", self._chat_id)
async def chat(self, message: str, timeout: float = 60):
try:
await self._ws.send(json.dumps({"content": message}))
except websockets.ConnectionClosed:
main_logger.info("ws: reconnecting...")
await self.connect()
await self._ws.send(json.dumps({"content": message}))
parts = []
while True:
async with self._lock:
try:
raw = await asyncio.wait_for(self._ws.recv(), timeout=timeout)
except asyncio.TimeoutError:
main_logger.info("ws: timeout (no data for %.1fs), reconnecting", timeout)
await self.close()
await self._ws.send(json.dumps({"content": message}))
except websockets.ConnectionClosed:
main_logger.info("ws: reconnecting...")
await self.connect()
return
frame = json.loads(raw)
event = frame.get("event")
text = frame.get("text", "")
main_logger.debug("ws raw: %s", raw)
if event == "message":
if frame.get("kind") == "progress":
main_logger.debug("ws: skipping progress message")
continue
yield text
return
if event == "delta":
parts.append(text)
elif event == "stream_end":
t = "".join(parts)
if t:
yield t
parts = []
elif event == "error":
raise RuntimeError(frame.get("detail", "unknown error"))
await self._ws.send(json.dumps({"content": message}))
parts = []
while True:
try:
raw = await asyncio.wait_for(self._ws.recv(), timeout=timeout)
except asyncio.TimeoutError:
main_logger.info("ws: timeout (no data for %.1fs), reconnecting", timeout)
await self.close()
await self.connect()
return
frame = json.loads(raw)
event = frame.get("event")
text = frame.get("text", "")
main_logger.debug("ws raw: %s", raw)
if event == "message":
if frame.get("kind") == "progress":
main_logger.debug("ws: skipping progress message")
continue
yield text
return
if event == "delta":
parts.append(text)
elif event == "stream_end":
t = "".join(parts)
if t:
yield t
parts = []
elif event == "error":
raise RuntimeError(frame.get("detail", "unknown error"))
async def close(self):
if self._ws:

Loading…
Cancel
Save