diff --git a/src/nanobot.py b/src/nanobot.py index 194ab75..cce5676 100644 --- a/src/nanobot.py +++ b/src/nanobot.py @@ -14,6 +14,7 @@ class Nanobot: def __init__(self): self._ws = None self._chat_id = None + self._lock = asyncio.Lock() async def __aenter__(self): await self.connect() @@ -30,40 +31,41 @@ class Nanobot: main_logger.info("connected chat_id=%s", self._chat_id) async def chat(self, message: str, timeout: float = 60): - try: - await self._ws.send(json.dumps({"content": message})) - except websockets.ConnectionClosed: - main_logger.info("ws: reconnecting...") - await self.connect() - await self._ws.send(json.dumps({"content": message})) - parts = [] - while True: + async with self._lock: try: - raw = await asyncio.wait_for(self._ws.recv(), timeout=timeout) - except asyncio.TimeoutError: - main_logger.info("ws: timeout (no data for %.1fs), reconnecting", timeout) - await self.close() + await self._ws.send(json.dumps({"content": message})) + except websockets.ConnectionClosed: + main_logger.info("ws: reconnecting...") await self.connect() - return - frame = json.loads(raw) - event = frame.get("event") - text = frame.get("text", "") - main_logger.debug("ws raw: %s", raw) - if event == "message": - if frame.get("kind") == "progress": - main_logger.debug("ws: skipping progress message") - continue - yield text - return - if event == "delta": - parts.append(text) - elif event == "stream_end": - t = "".join(parts) - if t: - yield t - parts = [] - elif event == "error": - raise RuntimeError(frame.get("detail", "unknown error")) + await self._ws.send(json.dumps({"content": message})) + parts = [] + while True: + try: + raw = await asyncio.wait_for(self._ws.recv(), timeout=timeout) + except asyncio.TimeoutError: + main_logger.info("ws: timeout (no data for %.1fs), reconnecting", timeout) + await self.close() + await self.connect() + return + frame = json.loads(raw) + event = frame.get("event") + text = frame.get("text", "") + main_logger.debug("ws raw: %s", raw) + if event == "message": + if frame.get("kind") == "progress": + main_logger.debug("ws: skipping progress message") + continue + yield text + return + if event == "delta": + parts.append(text) + elif event == "stream_end": + t = "".join(parts) + if t: + yield t + parts = [] + elif event == "error": + raise RuntimeError(frame.get("detail", "unknown error")) async def close(self): if self._ws: