# SPDX-FileCopyrightText: 2021 - 2023 Constantine Evans <qslib@mb.costi.net>
#
# SPDX-License-Identifier: EUPL-1.2
from __future__ import annotations
import asyncio
import io
import logging
import re
import time
from asyncio.futures import Future
from dataclasses import dataclass
from typing import Any, Coroutine, Optional, Protocol, Type
from .scpi_commands import AccessLevel, SCPICommand, _arglist
# Matches either a newline or a "<tag>" / "</tag>" quote marker in received
# bytes.  Group 1 is b"/" for a closing tag (empty otherwise) and group 2 is
# the tag name.  The trailing ">" is optional so that a tag cut off at the end
# of a chunk still matches; data_received checks for that case.
NL_OR_Q = re.compile(rb"(?:\n|<(/?)([\w.]+)[ *]*>?)")
# Matches a number with at least eight integer digits and exactly three
# decimal places; handle_sub_message uses it to pick a timestamp out of a
# subscription message.
TIMESTAMP = re.compile(rb"(\d{8,}\.\d{3})")
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def _validate_command_format(commandstring: bytes) -> None:
# This is meant to validate that the command will not mess up comms
# The command may be completely malformed otherwise
tagseq: list[tuple[bytes, bytes, bytes]] = re.findall(
rb"<(/?)([\w.]+)[ *]*>|(\n)", commandstring.rstrip()
) # tuple of close?,tag,newline?
tagstack: list[bytes] = []
for c, t, n in tagseq:
if n and not tagstack:
raise ValueError("newline outside of quotation")
elif t and not c:
tagstack.append(t)
elif c:
if not tagstack:
raise ValueError(f"unbalanced tag <{c.decode()}{t.decode()}>")
opentag = tagstack.pop()
if opentag != t:
raise ValueError(
f"unbalanced tags <{opentag.decode()}> <{c.decode()}{t.decode()}>"
)
elif n:
continue
else:
raise ValueError("Unknown")
if tagstack:
raise ValueError("Unclosed tags")
[docs]
class Error(Exception):
    """Root exception type of this module; CommandError derives from it."""

    pass
[docs]
class CommandError(Error):
    """Base class for errors returned in response to a command."""

    @staticmethod
    def parse(command: str, ref_index: str, response: str) -> CommandError:
        """Build the most specific error object available for *response*.

        The response is expected to look like ``[Kind] details``.  When
        ``Kind`` is a key of ``COM_ERRORS`` the matching class parses the
        details; if the response has another shape, the kind is unknown, or
        the specific parser raises ``ValueError``, an
        ``UnparsedCommandError`` carrying the raw response is returned.

        :param command: the command that was sent.
        :param ref_index: the reference the command was sent under.
        :param response: the error text received.
        :return: a ``CommandError`` instance (returned, not raised).
        """
        parsed = re.match(r"\[(\w+)\] (.*)", response)
        if parsed is None:
            return UnparsedCommandError(command, ref_index, response)

        kind, detail = parsed.groups()
        error_type = COM_ERRORS.get(kind)
        if error_type is None:
            return UnparsedCommandError(command, ref_index, response)

        try:
            return error_type.parse(command, ref_index, detail)
        except ValueError:
            # The details did not have the layout the specific class expects.
            return UnparsedCommandError(command, ref_index, response)
[docs]
@dataclass
class UnparsedCommandError(CommandError):
    """The machine has returned an error that we are not familiar with,
    and that we haven't parsed."""

    # The command that produced the error, if known.
    command: Optional[str]
    # The reference the command was sent under, if known.
    ref_index: Optional[str]
    # The raw, unparsed error text.
    response: str
[docs]
@dataclass
class QS_IOError(CommandError):
    """Error registered under the ``IOError`` key of ``COM_ERRORS``."""

    # The command that produced the error.
    command: str
    # The text following "-->" in the error details.
    message: str
    # Options parsed from the text preceding "-->".
    data: dict[str, str]

    @classmethod
    def parse(cls, command: str, ref_index: str, message: str) -> QS_IOError:
        """Split ``<options> --> <text>`` and parse the options part.

        :param command: the command that was sent.
        :param ref_index: accepted for signature compatibility; not used.
        :param message: the error details.
        :raises ValueError: if *message* does not contain ``" --> "``.
        """
        split = re.match(r"(.*) --> (.*)", message)
        if split is None:
            raise ValueError
        options_text, description = split.groups()
        options = _arglist.parse_string(options_text)[0].opts
        return cls(command=command, message=description, data=options)
[docs]
@dataclass
class InsufficientAccess(CommandError):
    """Error registered under the ``InsufficientAccess`` key of ``COM_ERRORS``."""

    # The command that produced the error.
    command: str
    # Value of the -requiredAccess option in the error details.
    requiredAccess: AccessLevel
    # Value of the -currentAccess option in the error details.
    currentAccess: AccessLevel
    # The text following "-->" in the error details.
    message: str

    @classmethod
    def parse(cls, command: str, ref_index: str, message: str) -> InsufficientAccess:
        """Parse ``-requiredAccess="X" -currentAccess="Y" --> <text>``.

        :param command: the command that was sent.
        :param ref_index: accepted for signature compatibility; not used.
        :param message: the error details.
        :raises ValueError: if *message* does not have the expected layout.
        """
        found = re.match(
            r'-requiredAccess="(\w+)" -currentAccess="(\w+)" --> (.*)', message
        )
        if found is None:
            raise ValueError
        required, current, text = found.groups()
        return cls(
            command=command,
            requiredAccess=AccessLevel(required),
            currentAccess=AccessLevel(current),
            message=text,
        )
[docs]
@dataclass
class AuthError(CommandError):
    """Error registered under the ``AuthError`` key of ``COM_ERRORS``."""

    # The command that produced the error.
    command: str
    # The text following "--> " in the error details.
    message: str

    @classmethod
    def parse(cls, command: str, ref_index: str, message: str) -> AuthError:
        """Parse error details of the form ``--> <text>``.

        :param command: the command that was sent.
        :param ref_index: accepted for signature compatibility; not used.
        :param message: the error details.
        :raises ValueError: if *message* does not start with ``"--> "``.
        """
        found = re.match(r"--> (.*)", message)
        if found is None:
            raise ValueError
        (text,) = found.groups()
        return cls(command=command, message=text)
[docs]
@dataclass
class AccessLevelExceeded(CommandError):
    """Error registered under the ``AccessLevelExceeded`` key of ``COM_ERRORS``."""

    # The command that produced the error.
    command: str
    # Value of the -accessLimit option in the error details.
    accessLimit: AccessLevel
    # The text following "-->" in the error details.
    message: str

    @classmethod
    def parse(cls, command: str, ref_index: str, message: str) -> AccessLevelExceeded:
        """Parse error details of the form ``-accessLimit="X" --> <text>``.

        :param command: the command that was sent.
        :param ref_index: accepted for signature compatibility; not used.
        :param message: the error details.
        :raises ValueError: if *message* does not have the expected layout.
        """
        found = re.match(r'-accessLimit="(\w+)" --> (.*)', message)
        if found is None:
            raise ValueError
        limit, text = found.groups()
        return cls(command=command, accessLimit=AccessLevel(limit), message=text)
[docs]
@dataclass
class InvocationError(CommandError):
    """Error registered under the ``InvocationError`` key of ``COM_ERRORS``."""

    # The command that produced the error.
    command: str
    # The error details, stored unmodified.
    message: str

    @classmethod
    def parse(cls, command: str, ref_index: str, message: str) -> InvocationError:
        """Wrap *command* and *message* as they are; *ref_index* is not used."""
        return cls(command=command, message=message)
[docs]
@dataclass
class NoMatch(CommandError):
    """Error registered under the ``NoMatch`` key of ``COM_ERRORS``."""

    # The command that produced the error.
    command: str
    # The error details, stored unmodified.
    message: str

    @classmethod
    def parse(cls, command: str, ref_index: str, message: str) -> NoMatch:
        """Wrap *command* and *message* as they are; *ref_index* is not used."""
        return cls(command=command, message=message)
# Maps the bracketed error kind at the start of an error response (the "Kind"
# in "[Kind] details") to the class whose parse() handles the details.
# CommandError.parse falls back to UnparsedCommandError for kinds not listed.
COM_ERRORS: dict[str, Type[CommandError]] = {
    "InsufficientAccess": InsufficientAccess,
    "AuthError": AuthError,
    "AccessLevelExceeded": AccessLevelExceeded,
    "InvocationError": InvocationError,
    "NoMatch": NoMatch,
    "IOError": QS_IOError,
}
[docs]
class ReplyError(IOError):
    """IOError subclass for reply problems.

    NOTE(review): nothing in the code visible here raises it; presumably it
    is used by callers elsewhere -- confirm.
    """

    pass
[docs]
class SubHandler(Protocol):
    """Structural type of a topic handler.

    A handler is called with the topic, the message body and the timestamp
    found in the message (``None`` if there was none), and returns a
    coroutine that QS_IS_Protocol.handle_sub_message schedules as a task.
    """

    def __call__(
        self, topic: bytes, message: bytes, timestamp: float | None
    ) -> Coroutine[None, None, None]:  # pragma: no cover
        ...
[docs]
class QS_IS_Protocol(asyncio.Protocol):
    """asyncio protocol that frames instrument messages and pairs replies with commands.

    Received bytes are split into messages at newlines that occur outside of
    ``<tag>...</tag>`` quotations.  Messages starting with ``OK``, ``ERRor``
    or ``NEXT`` resolve the future of the matching entry in
    ``waiting_commands``; ``MESSage`` messages are passed to a topic handler;
    a ``READy`` message resolves ``readymsg``.
    """

    # Resolved by connection_lost with the exception it was given (or None).
    lostconnection: Future[Any]
    # time.time() of the most recent received data / handled default message.
    last_received: float
    # (command reference, future) pairs for commands still awaiting a reply.
    # The future resolves to (state, message, follow-up future or None).
    waiting_commands: list[
        tuple[
            bytes,
            None
            | Future[tuple[bytes, bytes, None | Future[tuple[bytes, bytes, None]]]],
        ]
    ]

    def __init__(self) -> None:
        """Create the protocol; must be called while an event loop is running."""
        self.default_topic_handler = self._default_topic_handler
        self.readymsg: Future[str] = asyncio.get_running_loop().create_future()
        self.lostconnection = asyncio.get_running_loop().create_future()
        self.should_be_connected = False
        self.last_received = time.time()
        # Strong references to tasks started by this protocol.  The event loop
        # keeps only weak references, so an unreferenced task could be
        # garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task[Any]] = set()

    def _spawn(self, coro: Coroutine[Any, Any, None]) -> None:
        """Schedule *coro* as a task and hold a reference until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """Record the loss of the connection and cancel all pending replies.

        :param exc: the exception that closed the connection, or ``None``.
        """
        if self.should_be_connected:
            log.warning("Lost connection")
        else:
            log.info("Connection closed")
        # data_received may call this method itself before the transport
        # reports the real close, so the future can already be resolved.
        if not self.lostconnection.done():
            self.lostconnection.set_result(exc)
        self.should_be_connected = False

        # Cancel all futures; we'll never recover them.
        for _, future in self.waiting_commands:
            if future is not None:
                future.cancel()
        self.waiting_commands = []

    async def disconnect(self) -> None:
        """Mark the coming close as intentional and send ``QUIT``."""
        self.should_be_connected = False
        await self.run_command("QUIT")

    def connection_made(self, transport: Any) -> None:
        """Store the transport and reset all per-connection parsing state.

        :param transport: the transport commands are written to.
        """
        log.info("Made connection")
        self.should_be_connected = True

        # setup connection.
        self.transport = transport
        self.waiting_commands = []
        self.buffer = io.BytesIO()
        self.quote_stack: list[bytes] = []
        self.topic_handlers: dict[bytes, SubHandler] = {}
        self.last_received = time.time()
        # Buffer offset of a tag that was cut off at the end of a chunk.
        self.unclosed_quote_pos: int | None = None

    async def _default_topic_handler(
        self, topic: bytes, message: bytes, timestamp: Optional[float] = None
    ) -> None:
        """Log a topic message that has no dedicated handler."""
        log.info(f"{topic.decode()} at {timestamp}: {message.decode()}")
        self.last_received = time.time()

    async def handle_sub_message(self, message: bytes) -> None:
        """Dispatch one topic message to its handler as a background task.

        The topic is everything before the first space.  If a timestamp
        follows it, it is parsed as a float and removed from the body.

        :param message: the message with its ``MESSage`` prefix removed.
        :raises ValueError: if *message* contains no space.
        """
        i = message.index(b" ")
        topic = message[0:i]
        if m := TIMESTAMP.match(message, i + 1):
            timestamp: float | None = float(m[1])
            i = m.end()
        else:
            timestamp = None
        handler = self.topic_handlers.get(topic, self.default_topic_handler)
        self._spawn(handler(topic, message[i + 1 :], timestamp=timestamp))

    async def parse_message(self, ds: bytes) -> None:
        """Handle one complete message from the instrument.

        :param ds: the message, including its terminating newline.
        """
        if ds.startswith((b"ERRor", b"OK", b"NEXT")):
            ms = ds.index(b" ")
            r = None
            comfut_new = None
            if ds.startswith(b"NEXT"):
                # A NEXT reply is followed by a later reply for the same
                # command; hand the waiter a future for that one.
                loop = asyncio.get_running_loop()
                comfut_new = loop.create_future()
            # NOTE(review): this is a prefix match, so reference b"12" would
            # also match a reply addressed to b"123" -- confirm whether
            # references can collide this way.
            for i, (commref, comfut) in enumerate(self.waiting_commands):
                if ds.startswith(commref, ms + 1):
                    if comfut is not None:
                        comfut.set_result(
                            (ds[:ms], ds[ms + len(commref) + 2 :], comfut_new)
                        )
                    else:
                        log.info(f"{commref!r} complete: {ds!r}")
                    r = i
                    break
            if r is None:
                log.error(f"received unexpected command response: {ds!r}")
            elif ds.startswith(b"NEXT"):
                # Keep waiting under the same reference for the final reply.
                self.waiting_commands.append((self.waiting_commands[r][0], comfut_new))
                del self.waiting_commands[r]
            else:
                del self.waiting_commands[r]
        elif ds.startswith(b"MESSage"):
            await self.handle_sub_message(ds[8:])
        elif ds.startswith(b"READy"):
            self.readymsg.set_result(ds.decode())
        else:
            log.error(f"Unknown message: {ds!r}")

    def data_received(self, data: bytes) -> None:
        """Process received data packet from instrument, keeping track of quotes. If
        a newline occurs when the quote stack is empty, create a task to process
        the message, but continue processing. (TODO: consider threads/processes here.)

        :param data: bytes:
        :raises ValueError: if a tag without its closing ``>`` is found
            anywhere but at the very end of the data.
        """
        log.debug(f"Received {data!r}")

        # If we have an unclosed tag opener (<) in the buffer, add it to the data
        if self.unclosed_quote_pos is not None:
            self.buffer.write(data)
            self.buffer.seek(self.unclosed_quote_pos)
            data = self.buffer.read()
            self.buffer.truncate(self.unclosed_quote_pos)
            self.buffer.seek(self.unclosed_quote_pos)
            self.unclosed_quote_pos = None

        lastwrite = 0
        for m in NL_OR_Q.finditer(data):
            if m[0] == b"\n":
                if len(self.quote_stack) == 0:
                    # End of a message: hand the buffered bytes to the parser.
                    self.buffer.write(data[lastwrite : m.end()])
                    lastwrite = m.end()
                    self._spawn(self.parse_message(self.buffer.getvalue()))
                    self.buffer = io.BytesIO()
                # A newline inside a quotation is just part of the message.
            else:
                if m[0][-1] != ord(">"):
                    if m.end() != len(data):
                        raise ValueError(data, m[0])
                    # We have an unclosed tag opener (<) at the end of the data
                    log.debug(f"Unclosed tag opener: {m[0]!r}")
                    self.buffer.write(data[lastwrite : m.start()])
                    self.unclosed_quote_pos = self.buffer.tell()
                    self.buffer.write(m[0])
                    lastwrite = m.end()
                elif not m[1]:
                    self.quote_stack.append(m[2])
                else:
                    try:
                        i = self.quote_stack.index(m[2])
                    except ValueError:
                        log.error(
                            f"Close quote {m[2]!r} did not have open in stack {self.quote_stack}. "
                            "Disconnecting to avoid corruption."
                        )
                        self.quote_stack = []
                        # NOTE(review): this only marks the connection as
                        # lost; the transport itself is not closed here.
                        self.connection_lost(ConnectionError())
                    else:
                        self.quote_stack = self.quote_stack[:i]
                self.buffer.write(data[lastwrite : m.end()])
                lastwrite = m.end()
        self.buffer.write(data[lastwrite:])
        self.last_received = time.time()

    async def run_command(
        self,
        comm: str | bytes | SCPICommand,
        ack_timeout: int = 300,
        just_ack: bool = True,
        uid: bool = True,
    ) -> bytes:
        """Send a command and wait for the instrument's reply.

        :param comm: the command, as text, bytes or an ``SCPICommand``.
        :param ack_timeout: seconds to wait for the first reply.
        :param just_ack: if true, return ``b""`` as soon as a ``NEXT`` reply
            arrives instead of waiting for the final reply.
        :param uid: if true, prefix the command with a random numeric
            reference used to match the reply.
        :return: the message of an ``OK`` reply, or ``b""`` (see *just_ack*).
        :raises ValueError: if the command fails ``_validate_command_format``.
        :raises ConnectionError: if the pending reply was cancelled, which
            ``connection_lost`` does when the connection goes away.
        :raises asyncio.TimeoutError: if no reply arrives within *ack_timeout*.
        :raises CommandError: if the reply is ``ERRor`` or has an unknown state.
        """
        if isinstance(comm, str):
            comm = comm.encode()
        elif isinstance(comm, SCPICommand):
            comm = comm.to_string().encode()
        comm = comm.rstrip()
        _validate_command_format(comm)
        log.debug(f"Running command {comm.decode()}")
        loop = asyncio.get_running_loop()

        comfut: Future[
            tuple[bytes, bytes, None | Future[tuple[bytes, bytes, None]]]
        ] = loop.create_future()
        if uid:
            import random

            commref = str(random.randint(1, 2**30)).encode()
            comm_with_ref = commref + b" " + comm
        else:
            comm_with_ref = comm
        self.transport.write((comm_with_ref + b"\n"))
        log.debug(f"Sent command {comm_with_ref!r}")

        # A leading number is the reference the reply will carry; without
        # one, the whole command is used to match the reply.
        if m := re.match(rb"^(\d+) ", comm_with_ref):
            commref = m[1]
        else:
            commref = comm
        self.waiting_commands.append((commref, comfut))

        try:
            # shield() keeps comfut alive on timeout so that a late reply can
            # still be matched by parse_message.
            await asyncio.wait_for(asyncio.shield(comfut), ack_timeout)
        except asyncio.CancelledError:
            if comfut.cancelled():
                # connection_lost cancelled the pending reply.
                raise ConnectionError
            # Our own task is being cancelled; let the cancellation through.
            raise

        state, msg, comnext = comfut.result()
        log.debug(f"Received ({state!r}, {msg!r})")

        if state == b"NEXT":
            if just_ack:
                return b""
            else:
                assert comnext is not None
                await comnext
                state, msg, comnext2 = comnext.result()
                assert comnext2 is None
                log.debug(f"Received ({state!r}, {msg!r})")

        if state == b"OK":
            return msg
        elif state == b"ERRor":
            raise CommandError.parse(
                comm.decode(), commref.decode(), msg.decode().rstrip()
            ) from None
        else:  # pragma: no cover
            raise CommandError.parse(
                comm.decode(), commref.decode(), (state + b" " + msg).decode()
            )