|
|
|
|
@ -14,6 +14,7 @@ class Nanobot: |
|
|
|
|
def __init__(self): |
|
|
|
|
self._ws = None |
|
|
|
|
self._chat_id = None |
|
|
|
|
self._lock = asyncio.Lock() |
|
|
|
|
|
|
|
|
|
async def __aenter__(self): |
|
|
|
|
await self.connect() |
|
|
|
|
@ -30,40 +31,41 @@ class Nanobot: |
|
|
|
|
main_logger.info("connected chat_id=%s", self._chat_id) |
|
|
|
|
|
|
|
|
|
async def chat(self, message: str, timeout: float = 60): |
|
|
|
|
try: |
|
|
|
|
await self._ws.send(json.dumps({"content": message})) |
|
|
|
|
except websockets.ConnectionClosed: |
|
|
|
|
main_logger.info("ws: reconnecting...") |
|
|
|
|
await self.connect() |
|
|
|
|
await self._ws.send(json.dumps({"content": message})) |
|
|
|
|
parts = [] |
|
|
|
|
while True: |
|
|
|
|
async with self._lock: |
|
|
|
|
try: |
|
|
|
|
raw = await asyncio.wait_for(self._ws.recv(), timeout=timeout) |
|
|
|
|
except asyncio.TimeoutError: |
|
|
|
|
main_logger.info("ws: timeout (no data for %.1fs), reconnecting", timeout) |
|
|
|
|
await self.close() |
|
|
|
|
await self._ws.send(json.dumps({"content": message})) |
|
|
|
|
except websockets.ConnectionClosed: |
|
|
|
|
main_logger.info("ws: reconnecting...") |
|
|
|
|
await self.connect() |
|
|
|
|
return |
|
|
|
|
frame = json.loads(raw) |
|
|
|
|
event = frame.get("event") |
|
|
|
|
text = frame.get("text", "") |
|
|
|
|
main_logger.debug("ws raw: %s", raw) |
|
|
|
|
if event == "message": |
|
|
|
|
if frame.get("kind") == "progress": |
|
|
|
|
main_logger.debug("ws: skipping progress message") |
|
|
|
|
continue |
|
|
|
|
yield text |
|
|
|
|
return |
|
|
|
|
if event == "delta": |
|
|
|
|
parts.append(text) |
|
|
|
|
elif event == "stream_end": |
|
|
|
|
t = "".join(parts) |
|
|
|
|
if t: |
|
|
|
|
yield t |
|
|
|
|
parts = [] |
|
|
|
|
elif event == "error": |
|
|
|
|
raise RuntimeError(frame.get("detail", "unknown error")) |
|
|
|
|
await self._ws.send(json.dumps({"content": message})) |
|
|
|
|
parts = [] |
|
|
|
|
while True: |
|
|
|
|
try: |
|
|
|
|
raw = await asyncio.wait_for(self._ws.recv(), timeout=timeout) |
|
|
|
|
except asyncio.TimeoutError: |
|
|
|
|
main_logger.info("ws: timeout (no data for %.1fs), reconnecting", timeout) |
|
|
|
|
await self.close() |
|
|
|
|
await self.connect() |
|
|
|
|
return |
|
|
|
|
frame = json.loads(raw) |
|
|
|
|
event = frame.get("event") |
|
|
|
|
text = frame.get("text", "") |
|
|
|
|
main_logger.debug("ws raw: %s", raw) |
|
|
|
|
if event == "message": |
|
|
|
|
if frame.get("kind") == "progress": |
|
|
|
|
main_logger.debug("ws: skipping progress message") |
|
|
|
|
continue |
|
|
|
|
yield text |
|
|
|
|
return |
|
|
|
|
if event == "delta": |
|
|
|
|
parts.append(text) |
|
|
|
|
elif event == "stream_end": |
|
|
|
|
t = "".join(parts) |
|
|
|
|
if t: |
|
|
|
|
yield t |
|
|
|
|
parts = [] |
|
|
|
|
elif event == "error": |
|
|
|
|
raise RuntimeError(frame.get("detail", "unknown error")) |
|
|
|
|
|
|
|
|
|
async def close(self): |
|
|
|
|
if self._ws: |
|
|
|
|
|